7
7
from __future__ import absolute_import
8
8
from __future__ import division
9
9
from __future__ import print_function
10
-
11
- import time
10
+
12
11
import logging
12
+ import time
13
13
import os
14
+ from collections import defaultdict
15
+ import ujson as json
14
16
15
17
import numpy as np
16
18
import torch
17
19
18
20
from core .evaluate import accuracy
19
- from core .inference import get_final_preds
21
+ from core .inference import get_final_preds , get_max_preds
20
22
from utils .transforms import flip_back
21
23
from utils .vis import save_debug_images
22
24
@@ -110,8 +112,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
110
112
)
111
113
all_boxes = np .zeros ((num_samples , 6 ))
112
114
image_path = []
113
- filenames = []
114
- imgnums = []
115
+ image_ids = []
115
116
idx = 0
116
117
with torch .no_grad ():
117
118
end = time .time ()
@@ -179,6 +180,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
179
180
all_boxes [idx :idx + num_images , 4 ] = np .prod (s * 200 , 1 )
180
181
all_boxes [idx :idx + num_images , 5 ] = score
181
182
image_path .extend (meta ['image' ])
183
+ if config .DATASET .DATASET == 'posetrack' :
184
+ image_ids .extend (meta ['image_id' ].numpy ())
182
185
183
186
idx += num_images
184
187
@@ -198,8 +201,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
198
201
prefix )
199
202
200
203
name_values , perf_indicator = val_dataset .evaluate (
201
- config , all_preds , output_dir , all_boxes , image_path ,
202
- filenames , imgnums
204
+ config , all_preds , output_dir , all_boxes , image_path , image_ids = image_ids
203
205
)
204
206
205
207
model_name = config .MODEL .NAME
@@ -240,6 +242,107 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
240
242
return perf_indicator
241
243
242
244
245
def inference(config, image_loader, image_dataset, model, output_dir):
    """Run pose inference over a detection-cropped image loader and dump results.

    For every batch the model produces joint heatmaps (optionally averaged with
    a horizontally-flipped pass), which are decoded to image-space keypoints.
    Per-sample keypoints and boxes are accumulated, grouped by frame, and
    written to ``<output_dir>/box_keypoints.json``.

    Args:
        config: experiment config node (reads MODEL.NUM_JOINTS, TEST.FLIP_TEST,
            TEST.SHIFT_HEATMAP, DATASET.DATASET, PRINT_FREQ).
        image_loader: iterable yielding (input, target, target_weight, meta)
            batches; meta supplies 'center', 'scale', 'score', 'bbox_tlwh',
            'image' and (for the 'mot' dataset) 'image_id'.
        image_dataset: dataset object; len() gives sample count and
            ``flip_pairs`` describes left/right joint swaps for flip-test.
        model: pose network in eval mode mapping images to heatmaps.
        output_dir: directory that receives debug images and the JSON file.

    Returns:
        None. Side effects: debug images, log lines, box_keypoints.json.
    """
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(image_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    all_boxes = np.zeros((num_samples, 5))
    all_image_pathes = []
    all_image_ids = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (input, target, target_weight, meta) in enumerate(image_loader):
            num_images = input.size(0)
            # compute output; multi-stage models return a list of heatmaps,
            # in which case only the last (finest) stage is used
            outputs = model(input)
            if isinstance(outputs, list):
                output = outputs[-1]
            else:
                output = outputs

            if config.TEST.FLIP_TEST:
                # this part is ugly, because pytorch has not supported negative index
                # input_flipped = model(input[:, :, :, ::-1])
                input_flipped = np.flip(input.cpu().numpy(), 3).copy()
                input_flipped = torch.from_numpy(input_flipped).cuda()
                outputs_flipped = model(input_flipped)
                if isinstance(outputs_flipped, list):
                    output_flipped = outputs_flipped[-1]
                else:
                    output_flipped = outputs_flipped

                # un-flip the heatmaps and swap left/right joint channels
                output_flipped = flip_back(output_flipped.cpu().numpy(),
                                           image_dataset.flip_pairs)
                output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

                # feature is not aligned, shift flipped heatmap for higher accuracy
                if config.TEST.SHIFT_HEATMAP:
                    output_flipped[:, :, :, 1:] = \
                        output_flipped.clone()[:, :, :, 0:-1]
                    # output_flipped[:, :, :, 0] = 0

                # average original and flipped predictions
                output = (output + output_flipped) * 0.5

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()
            tlwhs = meta['bbox_tlwh'].numpy()
            output = output.data.cpu()

            # decode heatmaps into image-space coordinates + confidences
            preds, maxvals = get_final_preds(config, output.numpy(), c, s)

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[idx:idx + num_images, 0:4] = tlwhs
            all_boxes[idx:idx + num_images, 4] = score
            all_image_pathes.extend(meta['image'])
            if config.DATASET.DATASET == 'mot':
                # MOT meta carries (sequence name, frame id) pairs
                seq_names, frame_ids = meta['image_id']
                frame_ids = frame_ids.numpy().astype(int)
                all_image_ids.extend(list(zip(seq_names, frame_ids)))

            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                          i, len(image_loader), batch_time=batch_time)
                logger.info(msg)

                prefix = '{}_{}'.format(os.path.join(output_dir, 'inference'), i)
                # heatmap-space argmax; *4 maps heatmap coords back to the
                # input crop (assumes a fixed 4x downsample — TODO confirm
                # against the model's output stride)
                pred, _ = get_max_preds(output.numpy())
                save_debug_images(config, input, meta, target, pred * 4, output, prefix)

    # write output: group per-sample results by frame id
    frame_results = defaultdict(list)
    for image_id, pred, box in zip(all_image_ids, all_preds, all_boxes):
        frame_results[image_id].append((pred.astype(float).tolist(), box.astype(float).tolist()))

    final_results = {}
    for image_id, results in frame_results.items():
        keypoints, boxes = zip(*results)
        final_results[image_id] = {'keypoints': keypoints, 'boxes': boxes}

    # NOTE(review): image_id keys are (seq_name, frame_id) tuples; stdlib json
    # rejects non-string dict keys, so this relies on the `ujson` alias's key
    # coercion — verify the downstream reader agrees on the key format.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'box_keypoints.json'), 'w') as f:
        json.dump(final_results, f)
    logger.info('Save results to {}'.format(os.path.join(output_dir, 'box_keypoints.json')))
344
+
345
+
243
346
# markdown format output
244
347
def _print_name_value (name_value , full_arch_name ):
245
348
names = name_value .keys ()
0 commit comments