Commit e1af02e

cherry pick from human pose estimation
1 parent b0199ef commit e1af02e

16 files changed: +9126 −49 lines

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+[submodule "lib/third_part/poseval"]
+	path = lib/third_part/poseval
+	url = git@github.com:longcw/poseval.git

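(Note: this diff only records the submodule pointer; after checkout the poseval code itself is fetched with the standard command git submodule update --init lib/third_part/poseval.)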
experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ CUDNN:
   DETERMINISTIC: false
   ENABLED: true
 DATA_DIR: ''
-GPUS: (0,1,2,3)
+GPUS: (0,)
 OUTPUT_DIR: 'output'
 LOG_DIR: 'log'
 WORKERS: 24

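For context, GPUS in these experiment YAMLs lists the CUDA device ids that the train/test scripts pass to torch.nn.DataParallel, so (0,) pins a run to a single GPU. A minimal sketch of that consumption pattern, with a toy module standing in for the real pose network (the names below are illustrative, not part of this commit):

import torch
import torch.nn as nn

# Hypothetical stand-in for the pose model; the real one is built from the config.
model = nn.Conv2d(3, 17, kernel_size=1)

gpus = (0,)  # what GPUS: (0,) in the YAML amounts to
if torch.cuda.is_available():
    model = nn.DataParallel(model, device_ids=list(gpus)).cuda()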
experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ CUDNN:
   DETERMINISTIC: false
   ENABLED: true
 DATA_DIR: ''
-GPUS: (0,1,2,3)
+GPUS: (0,)
 OUTPUT_DIR: 'output'
 LOG_DIR: 'log'
 WORKERS: 24

lib/core/evaluate.py

Lines changed: 0 additions & 2 deletions
@@ -69,5 +69,3 @@ def accuracy(output, target, hm_type='gaussian', thr=0.5):
     if cnt != 0:
         acc[0] = avg_acc
     return acc, avg_acc, cnt, pred
-
-

lib/core/function.py

Lines changed: 110 additions & 7 deletions
@@ -7,16 +7,18 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
-import time
+
 import logging
+import time
 import os
+from collections import defaultdict
+import ujson as json

 import numpy as np
 import torch

 from core.evaluate import accuracy
-from core.inference import get_final_preds
+from core.inference import get_final_preds, get_max_preds
 from utils.transforms import flip_back
 from utils.vis import save_debug_images

@@ -110,8 +112,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     )
     all_boxes = np.zeros((num_samples, 6))
     image_path = []
-    filenames = []
-    imgnums = []
+    image_ids = []
     idx = 0
     with torch.no_grad():
         end = time.time()
@@ -179,6 +180,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
             all_boxes[idx:idx + num_images, 5] = score
             image_path.extend(meta['image'])
+            if config.DATASET.DATASET == 'posetrack':
+                image_ids.extend(meta['image_id'].numpy())

             idx += num_images

@@ -198,8 +201,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                                   prefix)

         name_values, perf_indicator = val_dataset.evaluate(
-            config, all_preds, output_dir, all_boxes, image_path,
-            filenames, imgnums
+            config, all_preds, output_dir, all_boxes, image_path, image_ids=image_ids
         )

         model_name = config.MODEL.NAME
@@ -240,6 +242,107 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     return perf_indicator


+def inference(config, image_loader, image_dataset, model, output_dir):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    acc = AverageMeter()
+
+    # switch to evaluate mode
+    model.eval()
+
+    num_samples = len(image_dataset)
+    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
+                         dtype=np.float32)
+    all_boxes = np.zeros((num_samples, 5))
+    all_image_pathes = []
+    all_image_ids = []
+    idx = 0
+    with torch.no_grad():
+        end = time.time()
+        for i, (input, target, target_weight, meta) in enumerate(image_loader):
+            num_images = input.size(0)
+            # compute output
+            outputs = model(input)
+            if isinstance(outputs, list):
+                output = outputs[-1]
+            else:
+                output = outputs
+
+            if config.TEST.FLIP_TEST:
+                # this part is ugly, because pytorch has not supported negative index
+                # input_flipped = model(input[:, :, :, ::-1])
+                input_flipped = np.flip(input.cpu().numpy(), 3).copy()
+                input_flipped = torch.from_numpy(input_flipped).cuda()
+                outputs_flipped = model(input_flipped)
+                if isinstance(outputs_flipped, list):
+                    output_flipped = outputs_flipped[-1]
+                else:
+                    output_flipped = outputs_flipped
+
+                output_flipped = flip_back(output_flipped.cpu().numpy(),
+                                           image_dataset.flip_pairs)
+                output_flipped = torch.from_numpy(output_flipped.copy()).cuda()
+
+                # feature is not aligned, shift flipped heatmap for higher accuracy
+                if config.TEST.SHIFT_HEATMAP:
+                    output_flipped[:, :, :, 1:] = \
+                        output_flipped.clone()[:, :, :, 0:-1]
+                    # output_flipped[:, :, :, 0] = 0
+
+                output = (output + output_flipped) * 0.5
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            c = meta['center'].numpy()
+            s = meta['scale'].numpy()
+            score = meta['score'].numpy()
+            tlwhs = meta['bbox_tlwh'].numpy()
+            output = output.data.cpu()
+
+            preds, maxvals = get_final_preds(config, output.numpy(), c, s)
+
+            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
+            all_preds[idx:idx + num_images, :, 2:3] = maxvals
+            # double check this all_boxes parts
+            all_boxes[idx:idx + num_images, 0:4] = tlwhs
+            all_boxes[idx:idx + num_images, 4] = score
+            all_image_pathes.extend(meta['image'])
+            if config.DATASET.DATASET == 'mot':
+                seq_names, frame_ids = meta['image_id']
+                frame_ids = frame_ids.numpy().astype(int)
+                all_image_ids.extend(list(zip(seq_names, frame_ids)))
+
+            idx += num_images
+
+            if i % config.PRINT_FREQ == 0:
+                msg = 'Test: [{0}/{1}]\t' \
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
+                          i, len(image_loader), batch_time=batch_time)
+                logger.info(msg)
+
+                prefix = '{}_{}'.format(os.path.join(output_dir, 'inference'), i)
+                pred, _ = get_max_preds(output.numpy())
+                save_debug_images(config, input, meta, target, pred * 4, output, prefix)
+
+    # write output
+    frame_results = defaultdict(list)
+    for image_id, pred, box in zip(all_image_ids, all_preds, all_boxes):
+        frame_results[image_id].append((pred.astype(float).tolist(), box.astype(float).tolist()))
+
+    final_results = {}
+    for image_id, results in frame_results.items():
+        keypoints, boxes = zip(*results)
+        final_results[image_id] = {'keypoints': keypoints, 'boxes': boxes}
+
+    if not os.path.isdir(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, 'box_keypoints.json'), 'w') as f:
+        json.dump(final_results, f)
+    logger.info('Save results to {}'.format(os.path.join(output_dir, 'box_keypoints.json')))
+
+
 # markdown format output
 def _print_name_value(name_value, full_arch_name):
     names = name_value.keys()

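The inference() function added above groups each detection's keypoints and box by frame id and dumps the result to box_keypoints.json. A hedged reader sketch for that file (path and layout inferred from the code above; note JSON object keys always come back as strings):

import json

# Path as written by inference(); adjust output_dir to your run.
with open('output/box_keypoints.json') as f:
    results = json.load(f)

for image_id, frame in results.items():
    # Parallel per-person lists: a (num_joints x 3) keypoint list of
    # [x, y, confidence] and a 5-element box [x, y, w, h, det_score].
    for kps, box in zip(frame['keypoints'], frame['boxes']):
        print(image_id, len(kps), box)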
lib/dataset/JointsDataset.py

Lines changed: 5 additions & 5 deletions
@@ -46,8 +46,8 @@ def __init__(self, cfg, root, image_set, is_train, transform=None):
         self.prob_half_body = cfg.DATASET.PROB_HALF_BODY
         self.color_rgb = cfg.DATASET.COLOR_RGB

-        self.target_type = cfg.MODEL.TARGET_TYPE
         self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)
+        self.target_type = cfg.MODEL.TARGET_TYPE
         self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)
         self.sigma = cfg.MODEL.SIGMA
         self.use_different_joints_weight = cfg.LOSS.USE_DIFFERENT_JOINTS_WEIGHT
@@ -114,8 +114,8 @@ def __getitem__(self, idx):
         db_rec = copy.deepcopy(self.db[idx])

         image_file = db_rec['image']
-        filename = db_rec['filename'] if 'filename' in db_rec else ''
-        imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
+        image_id = db_rec.get('image_id', -1)
+        bbox_tlwh = db_rec.get('bbox_tlwh', (0, 0, 0, 0))

         if self.data_format == 'zip':
             from utils import zipreader
@@ -185,8 +185,8 @@ def __getitem__(self, idx):

         meta = {
             'image': image_file,
-            'filename': filename,
-            'imgnum': imgnum,
+            'image_id': image_id,
+            'bbox_tlwh': bbox_tlwh,
             'joints': joints,
             'joints_vis': joints_vis,
             'center': c,

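One detail these meta changes rely on: image_id and bbox_tlwh are numeric, so PyTorch's default collation batches them into tensors, which is what lets validate() and inference() call meta['image_id'].numpy() directly. A self-contained sketch of that behavior (the toy dataset is a stand-in, not the real JointsDataset):

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Mimics the new meta fields; not the real JointsDataset."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        meta = {
            'image_id': idx,
            'bbox_tlwh': np.array([0., 0., 10., 20.], dtype=np.float32),
        }
        return torch.zeros(3, 4, 4), meta

loader = DataLoader(ToyDataset(), batch_size=2)
for images, meta in loader:
    # default_collate turns per-sample numbers/arrays into batched tensors.
    print(meta['image_id'].numpy(), meta['bbox_tlwh'].shape)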
lib/dataset/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,3 +10,5 @@

 from .mpii import MPIIDataset as mpii
 from .coco import COCODataset as coco
+from .posetrack import PoseTrackDataset as posetrack
+from .mot import MOTDataset as mot

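These aliases make the new datasets addressable by the DATASET.DATASET string in the config; the train/test scripts resolve the class by name from the dataset package, roughly as sketched below (the config value and the commented construction are illustrative):

import dataset  # the package whose __init__.py is extended above

name = 'posetrack'  # would come from config.DATASET.DATASET

# getattr here mirrors the eval('dataset.' + name) lookup such scripts typically use.
DatasetClass = getattr(dataset, name)
# ds = DatasetClass(cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, transform)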
lib/dataset/coco.py

Lines changed: 19 additions & 13 deletions
@@ -8,15 +8,17 @@
 from __future__ import division
 from __future__ import print_function

-from collections import defaultdict
-from collections import OrderedDict
 import logging
 import os
+import pickle
+from collections import defaultdict
+from collections import OrderedDict

+# import json_tricks as json
+import ujson as json
+import numpy as np
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
-import json_tricks as json
-import numpy as np

 from dataset.JointsDataset import JointsDataset
 from nms.nms import oks_nms
@@ -51,8 +53,15 @@ class COCODataset(JointsDataset):
     [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
     [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
     '''
+
+    num_joints = 17
+    flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
+                  [9, 10], [11, 12], [13, 14], [15, 16]]
+    upper_body_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+    lower_body_ids = (11, 12, 13, 14, 15, 16)
+
     def __init__(self, cfg, root, image_set, is_train, transform=None):
-        super().__init__(cfg, root, image_set, is_train, transform)
+        super(COCODataset, self).__init__(cfg, root, image_set, is_train, transform)
         self.nms_thre = cfg.TEST.NMS_THRE
         self.image_thre = cfg.TEST.IMAGE_THRE
         self.soft_nms = cfg.TEST.SOFT_NMS
@@ -87,12 +96,11 @@ def __init__(self, cfg, root, image_set, is_train, transform=None):
         self.num_images = len(self.image_set_index)
         logger.info('=> num_images: {}'.format(self.num_images))

-        self.num_joints = 17
-        self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
-                           [9, 10], [11, 12], [13, 14], [15, 16]]
+        self.num_joints = COCODataset.num_joints
+        self.flip_pairs = COCODataset.flip_pairs
+        self.upper_body_ids = COCODataset.upper_body_ids
+        self.lower_body_ids = COCODataset.lower_body_ids
         self.parent_ids = None
-        self.upper_body_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
-        self.lower_body_ids = (11, 12, 13, 14, 15, 16)

         self.joints_weight = np.array(
             [
@@ -201,8 +209,6 @@ def _load_coco_keypoint_annotation_kernal(self, index):
                 'scale': scale,
                 'joints_3d': joints_3d,
                 'joints_3d_vis': joints_3d_vis,
-                'filename': '',
-                'imgnum': 0,
             })

         return rec
@@ -378,7 +384,7 @@ def _write_coco_keypoint_results(self, keypoints, res_file):
         ]

         results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
-        logger.info('=> writing results json to %s' % res_file)
+        logger.info('=> Writing results json to %s' % res_file)
         with open(res_file, 'w') as f:
             json.dump(results, f, sort_keys=True, indent=4)
         try:

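Hoisting num_joints, flip_pairs, upper_body_ids and lower_body_ids to class attributes makes the skeleton metadata readable without constructing a dataset (construction needs a config plus annotation files), while the instance assignments in __init__ keep existing call sites working. A small illustration (usage hypothetical):

from dataset.coco import COCODataset

# Readable straight off the class now; no cfg or annotations required.
print(COCODataset.num_joints)        # 17
print(COCODataset.flip_pairs[:2])    # [[1, 2], [3, 4]]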