@@ -176,7 +176,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Train keypoints network')
     # general
     parser.add_argument('--cfg', type=str, required=True)
-    parser.add_argument('--video_dir', type=str, required=True)
+    parser.add_argument('--video_root', type=str, required=True)
     parser.add_argument('--output_dir', type=str, default='/output/')
     parser.add_argument('--writeBoxFrames', action='store_true')

@@ -201,10 +201,13 @@ def get_image_paths(video_path):
     images_list = full_image_paths
     return images_list

+def get_keypoints_for_video(video_path):
+    return
+
 def main():
+    print("change from pycharm")
     # transformation
     pose_transform = transforms.Compose([
-        transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
@@ -217,7 +220,7 @@ def main():

     args = parse_args()
     update_config(cfg, args)
-    video_path = args.video_dir
+    video_root = args.video_root
     pose_dir = prepare_output_dirs(args.output_dir)
     csv_output_rows = []

@@ -237,90 +240,106 @@ def main():
     pose_model.to(CTX)
     pose_model.eval()

-    # Loading an video
-    # vidcap = cv2.VideoCapture(args.videoFile)
-    count = 0
-    for image_path in get_image_paths(video_path):
-        print(f"reading image {image_path}")
-        image_bgr = cv2.imread(image_path)
-        total_now = time.time()
-        count += 1
-
-        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-
-        # Clone 2 image for person detection and pose estimation
-        if cfg.DATASET.COLOR_RGB:
-            image_per = image_rgb.copy()
-            image_pose = image_rgb.copy()
-        else:
-            image_per = image_bgr.copy()
-            image_pose = image_bgr.copy()
-
-        # Clone 1 image for debugging purpose
-        image_debug = image_bgr.copy()
-
-        # object detection box
-        now = time.time()
-        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
-        then = time.time()
-        print("Find person bbox in: {} sec".format(then - now))
-
-        # Can not find people. Move to next frame
-        if not pred_boxes:
+    video_names = os.listdir(video_root)
+    video_path_list = [os.path.join(video_root, x) for x in video_names]
+
+    for video_path in video_path_list:
+        # Loading an video
+        # vidcap = cv2.VideoCapture(args.videoFile)
+        count = 0
+        tensor_list = []
+
+        for image_path in get_image_paths(video_path):
+            # print(f"reading image {image_path}")
+            image_bgr = cv2.imread(image_path)
+            # print(f"image has shape {image_bgr.shape}")
+            total_now = time.time()
             count += 1
-            continue

-        if args.writeBoxFrames:
+            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+            # Clone 2 image for person detection and pose estimation
+            if cfg.DATASET.COLOR_RGB:
+                image_per = image_rgb.copy()
+                image_pose = image_rgb.copy()
+            else:
+                image_per = image_bgr.copy()
+                image_pose = image_bgr.copy()
+
+            # Clone 1 image for debugging purpose
+            image_debug = image_bgr.copy()
+
+            # object detection box
+            now = time.time()
+            pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+            then = time.time()
+            # print("Find person bbox in: {} sec".format(then - now))
+
+            # Can not find people. Move to next frame
+            if not pred_boxes:
+                count += 1
+                continue
+
+            if args.writeBoxFrames:
+                for box in pred_boxes:
+                    cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                                  thickness=3)  # Draw Rectangle with the coordinates
+
+            # pose estimation : for multiple people
+            centers = []
+            scales = []
             for box in pred_boxes:
-                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
-                              thickness=3)  # Draw Rectangle with the coordinates
-
-        # pose estimation : for multiple people
-        centers = []
-        scales = []
-        for box in pred_boxes:
-            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-            centers.append(center)
-            scales.append(scale)
-
-        now = time.time()
-        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
-        then = time.time()
-        print("Find person pose in: {} sec".format(then - now))
-
-        new_csv_row = []
-        for coords in pose_preds:
-            # Draw each point on image
-            for coord in coords:
-                x_coord, y_coord = int(coord[0]), int(coord[1])
-                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
-                new_csv_row.extend([x_coord, y_coord])
-
-        total_then = time.time()
-
-        text = "{:03.2f} sec".format(total_then - total_now)
-        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
-                    1, (0, 0, 255), 2, cv2.LINE_AA)
-
-        cv2.imshow("pos", image_debug)
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-
-        csv_output_rows.append(new_csv_row)
-        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
-        cv2.imwrite(img_file, image_debug)
-
-
-    # write csv
-    csv_headers = ['frame']
-    for keypoint in COCO_KEYPOINT_INDEXES.values():
-        csv_headers.extend([keypoint + '_x', keypoint + '_y'])
-
-    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
-    with open(csv_output_filename, 'w', newline='') as csvfile:
-        csvwriter = csv.writer(csvfile)
-        csvwriter.writerow(csv_headers)
-        csvwriter.writerows(csv_output_rows)
+                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                centers.append(center)
+                scales.append(scale)
+
+            now = time.time()
+            pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+            then = time.time()
+            # print("Find person pose in: {} sec".format(then - now))
+
+            new_csv_row = []
+            needed_indices = [0, 5, 6, 7, 8, 9, 10]
+            needed_points = torch.from_numpy(pose_preds[:, needed_indices, :])
+            # print(f"pose pred is {pose_preds.shape}")
+            tensor_list.append(needed_points)
+            for coords in needed_points:
+                # Draw each point on image
+                for coord in coords:
+                    x_coord, y_coord = int(coord[0]), int(coord[1])
+                    cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                    new_csv_row.extend([x_coord, y_coord])
+
+            total_then = time.time()
+
+            # text = "{:03.2f} sec".format(total_then - total_now)
+            # cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+            #             1, (0, 0, 255), 2, cv2.LINE_AA)
+
+            # cv2.imshow("pos", image_debug)
+            # if cv2.waitKey(1) & 0xFF == ord('q'):
+            #     break
+
+            csv_output_rows.append(new_csv_row)
+            # img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+            # cv2.imwrite(img_file, image_debug)
+
+
+        all_point_tensor = torch.cat(tensor_list)
+        print(f"all point tensor has shape {all_point_tensor.shape}")
+        video_name = video_path.split("/")[-1]
+        torch.save(all_point_tensor, os.path.join(args.output_dir, f"{video_name}.npy"))
+        # write csv
+        csv_headers = ['frame']
+        for keypoint in COCO_KEYPOINT_INDEXES.values():
+            csv_headers.extend([keypoint + '_x', keypoint + '_y'])
+
+        csv_output_filename = os.path.join(args.output_dir, f'{video_name}.csv')
+        with open(csv_output_filename, 'w', newline='') as csvfile:
+            csvwriter = csv.writer(csvfile)
+            csvwriter.writerow(csv_headers)
+            csvwriter.writerows(csv_output_rows)
+        print(f"finished getting keypoints of video {video_name}")


 if __name__ == '__main__':
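
The per-video tensor above is written with torch.save but given a ".npy" suffix, so a downstream reader must use torch.load rather than numpy.load. Below is a minimal, hypothetical loader sketch, not part of this commit: the helper name load_video_keypoints, the keypoint-name list, and the assumed shape (num_detections, 7, 2) are illustrations, assuming get_pose_estimation_prediction returns an array of shape (num_people, 17, 2) in standard COCO keypoint order, which would make needed_indices = [0, 5, 6, 7, 8, 9, 10] the nose, shoulders, elbows, and wrists.

    # Hypothetical downstream loader -- a sketch, not part of this commit.
    import os
    import torch

    # Assumed mapping for needed_indices = [0, 5, 6, 7, 8, 9, 10] in COCO order.
    SELECTED_KEYPOINTS = ["nose", "left_shoulder", "right_shoulder",
                          "left_elbow", "right_elbow", "left_wrist", "right_wrist"]

    def load_video_keypoints(output_dir, video_name):
        # The file was produced by torch.save, so torch.load is required even
        # though the filename ends in ".npy".
        path = os.path.join(output_dir, f"{video_name}.npy")
        points = torch.load(path)  # assumed shape: (num_detections, 7, 2)
        return points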