 import cv2
 import numpy as np

+import sys
+sys.path.append("../lib")
+import time

-import _init_paths
+# import _init_paths
 import models
 from config import cfg
 from config import update_config
-from core.function import get_final_preds
+from core.inference import get_final_preds
 from utils.transforms import get_affine_transform

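+# Run inference on the GPU when available, otherwise fall back to the CPU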
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
 COCO_KEYPOINT_INDEXES = {
     0: 'nose',
     1: 'left_eye',
@@ -67,57 +73,53 @@ def get_person_detection_boxes(model, img, threshold=0.5):
     pil_image = Image.fromarray(img)  # Convert the array to a PIL image
     transform = transforms.Compose([transforms.ToTensor()])  # Define the PyTorch transform
     transformed_img = transform(pil_image)  # Apply the transform to the image
-    pred = model([transformed_img])  # Pass the image to the model
+    pred = model([transformed_img.to(CTX)])  # Pass the image to the model
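+    # pred[0] holds the 'boxes', 'labels' and 'scores' tensors for the single input image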
     pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
-                    for i in list(pred[0]['labels'].numpy())]  # Predicted class names
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Predicted class names
     pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
-                  for i in list(pred[0]['boxes'].detach().numpy())]  # Bounding boxes
-    pred_score = list(pred[0]['scores'].detach().numpy())
-    if not pred_score:
-        return []
-    # Get list of index with score greater than threshold
-    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:pred_t + 1]
-    pred_classes = pred_classes[:pred_t + 1]
+                  for i in list(pred[0]['boxes'].cpu().detach().numpy())]  # Bounding boxes
+    pred_scores = list(pred[0]['scores'].cpu().detach().numpy())

     person_boxes = []
-    for idx, box in enumerate(pred_boxes):
-        if pred_classes[idx] == 'person':
-            person_boxes.append(box)
+    # Keep boxes whose score exceeds the threshold and whose class is 'person'
+    for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores):
+        if (pred_score > threshold) and (pred_class == 'person'):
+            person_boxes.append(pred_box)

     return person_boxes


-def get_pose_estimation_prediction(pose_model, image, center, scale):
+def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
     rotation = 0

     # pose estimation transformation
-    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
-    model_input = cv2.warpAffine(
-        image,
-        trans,
-        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
-        flags=cv2.INTER_LINEAR)
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
-
-    # pose estimation inference
-    model_input = transform(model_input).unsqueeze(0)
-    # switch to evaluate mode
-    pose_model.eval()
-    with torch.no_grad():
-        # compute output heatmap
-        output = pose_model(model_input)
-        preds, _ = get_final_preds(
-            cfg,
-            output.clone().cpu().numpy(),
-            np.asarray([center]),
-            np.asarray([scale]))
-
-        return preds
+    model_inputs = []
+    for center, scale in zip(centers, scales):
+        trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+        # Crop a patch around each detected person
+        model_input = cv2.warpAffine(
+            image,
+            trans,
+            (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+            flags=cv2.INTER_LINEAR)
+
+        # HWC -> CHW
+        model_input = transform(model_input)
+        model_inputs.append(model_input)
+
+    # n * CHW -> NCHW
+    model_inputs = torch.stack(model_inputs)
+
+    # compute output heatmaps
+    output = pose_model(model_inputs.to(CTX))
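+    # get_final_preds decodes the heatmap peaks back into original-image coordinates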
+    coords, _ = get_final_preds(
+        cfg,
+        output.cpu().detach().numpy(),
+        np.asarray(centers),
+        np.asarray(scales))
+
+    return coords


 def box_to_center_scale(box, model_image_width, model_image_height):
@@ -163,15 +165,11 @@ def box_to_center_scale(box, model_image_width, model_image_height):


 def prepare_output_dirs(prefix='/output/'):
-    pose_dir = prefix + 'poses/'
-    box_dir = prefix + 'boxes/'
+    pose_dir = os.path.join(prefix, "pose")
     if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
         shutil.rmtree(pose_dir)
-    if os.path.exists(box_dir) and os.path.isdir(box_dir):
-        shutil.rmtree(box_dir)
     os.makedirs(pose_dir, exist_ok=True)
-    os.makedirs(box_dir, exist_ok=True)
-    return pose_dir, box_dir
+    return pose_dir


 def parse_args():
@@ -199,20 +197,26 @@ def parse_args():


 def main():
+    # transformation
+    pose_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
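+    # ImageNet mean/std above, the normalization the pose network was trained with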
+
     # cudnn related setting
     cudnn.benchmark = cfg.CUDNN.BENCHMARK
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

     args = parse_args()
     update_config(cfg, args)
-    pose_dir, box_dir = prepare_output_dirs(args.outputDir)
-    csv_output_filename = args.outputDir + 'pose-data.csv'
+    pose_dir = prepare_output_dirs(args.outputDir)
     csv_output_rows = []

     box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+    box_model.to(CTX)
     box_model.eval()
-
     pose_model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
         cfg, is_train=False
     )
@@ -223,76 +227,114 @@ def main():
     else:
         print('expected model defined in config at TEST.MODEL_FILE')

-    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS).cuda()
+    pose_model.to(CTX)
+    pose_model.eval()
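+    # eval() puts batchnorm and dropout layers into inference mode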

     # Load the video
     vidcap = cv2.VideoCapture(args.videoFile)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     if fps < args.inferenceFps:
         print('desired inference fps is ' + str(args.inferenceFps) + ' but video fps is ' + str(fps))
         exit()
-    every_nth_frame = round(fps / args.inferenceFps)
+    skip_frame_cnt = round(fps / args.inferenceFps)
+    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]),
+                             cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(args.inferenceFps), (frame_width, frame_height))
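+    # MJPG-encoded .avi of the annotated frames; the writer fps is the inference fps,
+    # since only every skip_frame_cnt-th frame is written, keeping playback in real time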

-    success, image_bgr = vidcap.read()
     count = 0
+    while vidcap.isOpened():
+        total_now = time.time()
+        ret, image_bgr = vidcap.read()
+        count += 1

-    while success:
-        if count % every_nth_frame != 0:
-            success, image_bgr = vidcap.read()
-            count += 1
-            continue
+        # Stop once the video is exhausted
+        if not ret:
+            break

-        image = image_bgr[:, :, [2, 1, 0]]
-        count_str = str(count).zfill(32)
+        if count % skip_frame_cnt != 0:
+            continue
+
+        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+        # Clone two images, one for person detection and one for pose estimation
+        if cfg.DATASET.COLOR_RGB:
+            image_per = image_rgb.copy()
+            image_pose = image_rgb.copy()
+        else:
+            image_per = image_bgr.copy()
+            image_pose = image_bgr.copy()
+
+        # Clone one image for drawing debug output
+        image_debug = image_bgr.copy()

         # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
-        if args.writeBoxFrames:
-            image_bgr_box = image_bgr.copy()
-            for box in pred_boxes:
-                cv2.rectangle(image_bgr_box, box[0], box[1], color=(0, 255, 0),
-                              thickness=3)  # Draw Rectangle with the coordinates
-            cv2.imwrite(box_dir + 'box%s.jpg' % count_str, image_bgr_box)
+        now = time.time()
+        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+        then = time.time()
+        print("Found person bboxes in: {} sec".format(then - now))
+
+        # No people detected; move on to the next frame
         if not pred_boxes:
-            success, image_bgr = vidcap.read()
-            count += 1
             continue

-        # pose estimation
-        box = pred_boxes[0]  # assume there is only 1 person
-        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+        if args.writeBoxFrames:
+            for box in pred_boxes:
+                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                              thickness=3)  # Draw a rectangle at the box coordinates
+
+        # pose estimation: one crop per detected person
+        centers = []
+        scales = []
+        for box in pred_boxes:
+            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+            centers.append(center)
+            scales.append(scale)
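+        # Batch all person crops so the pose network runs once per frame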
+
+        now = time.time()
+        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+        then = time.time()
+        print("Found person poses in: {} sec".format(then - now))

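+        # One CSV row per processed frame: the frame number, then x,y for every keypoint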
-        new_csv_row = []
-        for _, mat in enumerate(pose_preds[0]):
-            x_coord, y_coord = int(mat[0]), int(mat[1])
-            cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 2)
-            new_csv_row.extend([x_coord, y_coord])
+        new_csv_row = [count]
+        for coords in pose_preds:
+            # Draw each keypoint on the image
+            for coord in coords:
+                x_coord, y_coord = int(coord[0]), int(coord[1])
+                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                new_csv_row.extend([x_coord, y_coord])
+
+        total_then = time.time()
+
+        text = "{:03.2f} sec".format(total_then - total_now)
+        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                    1, (0, 0, 255), 2, cv2.LINE_AA)
+
+        cv2.imshow("pose", image_debug)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break

         csv_output_rows.append(new_csv_row)
-        cv2.imwrite(pose_dir + 'pose%s.jpg' % count_str, image_bgr)
+        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+        cv2.imwrite(img_file, image_debug)
+        outcap.write(image_debug)

-        # get next frame
-        success, image_bgr = vidcap.read()
-        count += 1

     # write csv
     csv_headers = ['frame']
     for keypoint in COCO_KEYPOINT_INDEXES.values():
         csv_headers.extend([keypoint + '_x', keypoint + '_y'])

+    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
     with open(csv_output_filename, 'w', newline='') as csvfile:
         csvwriter = csv.writer(csvfile)
         csvwriter.writerow(csv_headers)
         csvwriter.writerows(csv_output_rows)
-    os.system("ffmpeg -y -r "
-              + str(args.inferenceFps)
-              + " -pattern_type glob -i '"
-              + pose_dir
-              + "/*.jpg' -c:v libx264 -vf fps="
-              + str(args.inferenceFps) + " -pix_fmt yuv420p /output/movie.mp4")
+    vidcap.release()
+    outcap.release()
+
+    cv2.destroyAllWindows()


 if __name__ == '__main__':