
Commit 5f7bf08

modify script for keypoint generation
1 parent 8268d4c commit 5f7bf08

File tree

1 file changed (+103, -84)


demo/inference.py

Lines changed: 103 additions & 84 deletions
@@ -176,7 +176,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Train keypoints network')
     # general
     parser.add_argument('--cfg', type=str, required=True)
-    parser.add_argument('--video_dir', type=str, required=True)
+    parser.add_argument('--video_root', type=str, required=True)
     parser.add_argument('--output_dir', type=str, default='/output/')
     parser.add_argument('--writeBoxFrames', action='store_true')

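Note: with --video_dir renamed to --video_root, the script now takes a directory containing one folder of frames per video rather than a single video directory. A minimal invocation after this change (the config and data paths below are placeholders, not taken from the commit):

    python demo/inference.py --cfg experiments/inference-config.yaml --video_root /data/videos/ --output_dir /output/
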
@@ -201,10 +201,13 @@ def get_image_paths(video_path):
     images_list = full_image_paths
     return images_list

+def get_keypoints_for_video(video_path):
+    return
+
 def main():
+    print("change from pycharm")
     # transformation
     pose_transform = transforms.Compose([
-        transforms.Resize((224,224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
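Note: removing transforms.Resize((224,224)) leaves only ToTensor and Normalize in pose_transform, so the transform assumes each person crop already matches the model input size. In the upstream HRNet demo this script appears to be based on, get_pose_estimation_prediction warps each detected box to cfg.MODEL.IMAGE_SIZE before applying the transform; a rough sketch of that upstream pattern (not part of this diff):

    # Sketch of the upstream HRNet demo pattern, hypothetical for this fork:
    # affine-warp the person crop to the model input size, then transform.
    trans = get_affine_transform(center, scale, 0, cfg.MODEL.IMAGE_SIZE)
    model_input = cv2.warpAffine(
        image_pose, trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)
    model_input = pose_transform(model_input)  # ToTensor + Normalize only
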
@@ -217,7 +220,7 @@ def main():

     args = parse_args()
     update_config(cfg, args)
-    video_path = args.video_dir
+    video_root = args.video_root
     pose_dir = prepare_output_dirs(args.output_dir)
     csv_output_rows = []

@@ -237,90 +240,106 @@ def main():
     pose_model.to(CTX)
     pose_model.eval()

-    # Loading an video
-    # vidcap = cv2.VideoCapture(args.videoFile)
-    count = 0
-    for image_path in get_image_paths(video_path):
-        print(f"reading image {image_path}")
-        image_bgr = cv2.imread(image_path)
-        total_now = time.time()
-        count += 1
-
-        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-
-        # Clone 2 image for person detection and pose estimation
-        if cfg.DATASET.COLOR_RGB:
-            image_per = image_rgb.copy()
-            image_pose = image_rgb.copy()
-        else:
-            image_per = image_bgr.copy()
-            image_pose = image_bgr.copy()
-
-        # Clone 1 image for debugging purpose
-        image_debug = image_bgr.copy()
-
-        # object detection box
-        now = time.time()
-        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
-        then = time.time()
-        print("Find person bbox in: {} sec".format(then - now))
-
-        # Can not find people. Move to next frame
-        if not pred_boxes:
+    video_names = os.listdir(video_root)
+    video_path_list = [os.path.join(video_root, x) for x in video_names]
+
+    for video_path in video_path_list:
+        # Loading an video
+        # vidcap = cv2.VideoCapture(args.videoFile)
+        count = 0
+        tensor_list = []
+
+        for image_path in get_image_paths(video_path):
+            # print(f"reading image {image_path}")
+            image_bgr = cv2.imread(image_path)
+            # print(f"image has shape {image_bgr.shape}")
+            total_now = time.time()
             count += 1
-            continue

-        if args.writeBoxFrames:
+            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+            # Clone 2 image for person detection and pose estimation
+            if cfg.DATASET.COLOR_RGB:
+                image_per = image_rgb.copy()
+                image_pose = image_rgb.copy()
+            else:
+                image_per = image_bgr.copy()
+                image_pose = image_bgr.copy()
+
+            # Clone 1 image for debugging purpose
+            image_debug = image_bgr.copy()
+
+            # object detection box
+            now = time.time()
+            pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+            then = time.time()
+            # print("Find person bbox in: {} sec".format(then - now))
+
+            # Can not find people. Move to next frame
+            if not pred_boxes:
+                count += 1
+                continue
+
+            if args.writeBoxFrames:
+                for box in pred_boxes:
+                    cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                                  thickness=3)  # Draw Rectangle with the coordinates
+
+            # pose estimation : for multiple people
+            centers = []
+            scales = []
             for box in pred_boxes:
-                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
-                              thickness=3)  # Draw Rectangle with the coordinates
-
-        # pose estimation : for multiple people
-        centers = []
-        scales = []
-        for box in pred_boxes:
-            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-            centers.append(center)
-            scales.append(scale)
-
-        now = time.time()
-        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
-        then = time.time()
-        print("Find person pose in: {} sec".format(then - now))
-
-        new_csv_row = []
-        for coords in pose_preds:
-            # Draw each point on image
-            for coord in coords:
-                x_coord, y_coord = int(coord[0]), int(coord[1])
-                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
-                new_csv_row.extend([x_coord, y_coord])
-
-        total_then = time.time()
-
-        text = "{:03.2f} sec".format(total_then - total_now)
-        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
-                    1, (0, 0, 255), 2, cv2.LINE_AA)
-
-        cv2.imshow("pos", image_debug)
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-
-        csv_output_rows.append(new_csv_row)
-        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
-        cv2.imwrite(img_file, image_debug)
-
-
-    # write csv
-    csv_headers = ['frame']
-    for keypoint in COCO_KEYPOINT_INDEXES.values():
-        csv_headers.extend([keypoint+'_x', keypoint+'_y'])
-
-    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
-    with open(csv_output_filename, 'w', newline='') as csvfile:
-        csvwriter = csv.writer(csvfile)
-        csvwriter.writerow(csv_headers)
-        csvwriter.writerows(csv_output_rows)
+                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                centers.append(center)
+                scales.append(scale)
+
+            now = time.time()
+            pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+            then = time.time()
+            # print("Find person pose in: {} sec".format(then - now))
+
+            new_csv_row = []
+            needed_indices = [0,5,6,7,8,9,10]
+            needed_points = torch.from_numpy(pose_preds[:, needed_indices, :])
+            # print(f"pose pred is {pose_preds.shape}")
+            tensor_list.append(needed_points)
+            for coords in needed_points:
+                # Draw each point on image
+                for coord in coords:
+                    x_coord, y_coord = int(coord[0]), int(coord[1])
+                    cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                    new_csv_row.extend([x_coord, y_coord])
+
+            total_then = time.time()
+
+            # text = "{:03.2f} sec".format(total_then - total_now)
+            # cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+            #             1, (0, 0, 255), 2, cv2.LINE_AA)
+
+            # cv2.imshow("pos", image_debug)
+            # if cv2.waitKey(1) & 0xFF == ord('q'):
+            #     break
+
+            csv_output_rows.append(new_csv_row)
+            # img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+            # cv2.imwrite(img_file, image_debug)
+
+
+        all_point_tensor = torch.cat(tensor_list)
+        print(f"all point tensor has shape {all_point_tensor.shape}")
+        video_name = video_path.split("/")[-1]
+        torch.save(all_point_tensor, os.path.join(args.output_dir, f"{video_name}.npy"))
+        # write csv
+        csv_headers = ['frame']
+        for keypoint in COCO_KEYPOINT_INDEXES.values():
+            csv_headers.extend([keypoint+'_x', keypoint+'_y'])
+
+        csv_output_filename = os.path.join(args.output_dir, f'{video_name}.csv')
+        with open(csv_output_filename, 'w', newline='') as csvfile:
+            csvwriter = csv.writer(csvfile)
+            csvwriter.writerow(csv_headers)
+            csvwriter.writerows(csv_output_rows)
+        print(f"finished getting keypoints of video {video_name}")


 if __name__ == '__main__':
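Two notes for downstream use of the new outputs. First, needed_indices = [0,5,6,7,8,9,10] selects the upper-body joints under the standard COCO keypoint ordering (which COCO_KEYPOINT_INDEXES follows):

    # Standard COCO keypoint ordering for the selected indices:
    NEEDED_KEYPOINTS = {
        0: 'nose',
        5: 'left_shoulder', 6: 'right_shoulder',
        7: 'left_elbow', 8: 'right_elbow',
        9: 'left_wrist', 10: 'right_wrist',
    }

Second, the per-video tensor is written with torch.save, so despite the .npy extension the file is PyTorch's pickled format, not NumPy's .npy format; it should be read back with torch.load rather than numpy.load. A minimal loading sketch (the path is a placeholder):

    import torch

    # Expected shape: (total_detections_across_frames, 7, 2),
    # i.e. torch.cat over per-frame (num_persons, 7, 2) tensors.
    points = torch.load('/output/some_video.npy')
    print(points.shape)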