Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ python tools/train.py \
--cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
```

### Visualization

#### Visualizing predictions on COCO val

```
python visualization/plot_coco.py \
--prediction output/coco/w48_384x288_adam_lr1e-3/results/keypoints_val2017_results_0.json \
--save-path visualization/results

```


<img src="figures\visualization\coco\score_610_id_2685_000000002685.png" height="215"><img src="figures\visualization\coco\score_710_id_153229_000000153229.png" height="215"><img src="figures\visualization\coco\score_755_id_343561_000000343561.png" height="215">

<img src="figures\visualization\coco\score_755_id_559842_000000559842.png" height="209"><img src="figures\visualization\coco\score_770_id_6954_000000006954.png" height="209"><img src="figures\visualization\coco\score_919_id_53626_000000053626.png" height="209">

### Other applications
Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
309 changes: 309 additions & 0 deletions visualization/plot_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun ([email protected])
# Modified by Depu Meng ([email protected])
# ------------------------------------------------------------------------------

import argparse
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import os


class ColorStyle:
def __init__(self, color, link_pairs, point_color):
self.color = color
self.link_pairs = link_pairs
self.point_color = point_color

for i in range(len(self.color)):
self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))

self.ring_color = []
for i in range(len(self.point_color)):
self.ring_color.append(tuple(np.array(self.point_color[i])/255.))

# Xiaochu Style
# (R,G,B)
color1 = [(179,0,0),(228,26,28),(255,255,51),
(49,163,84), (0,109,45), (255,255,51),
(240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
(217,95,14), (254,153,41),(255,255,51),
(44,127,184),(0,0,255)]

link_pairs1 = [
[15, 13], [13, 11], [11, 5],
[12, 14], [14, 16], [12, 6],
[3, 1],[1, 2],[1, 0],[0, 2],[2,4],
[9, 7], [7,5], [5, 6],
[6, 8], [8, 10],
]

point_color1 = [(240,2,127),(240,2,127),(240,2,127),
(240,2,127), (240,2,127),
(255,255,51),(255,255,51),
(254,153,41),(44,127,184),
(217,95,14),(0,0,255),
(255,255,51),(255,255,51),(228,26,28),
(49,163,84),(252,176,243),(0,176,240),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142)]

xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)


# Chunhua Style
# (R,G,B)
color2 = [(252,176,243),(252,176,243),(252,176,243),
(0,176,240), (0,176,240), (0,176,240),
(240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
(255,255,0), (255,255,0),(169, 209, 142),
(169, 209, 142),(169, 209, 142)]

link_pairs2 = [
[15, 13], [13, 11], [11, 5],
[12, 14], [14, 16], [12, 6],
[3, 1],[1, 2],[1, 0],[0, 2],[2,4],
[9, 7], [7,5], [5, 6], [6, 8], [8, 10],
]

point_color2 = [(240,2,127),(240,2,127),(240,2,127),
(240,2,127), (240,2,127),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142),
(252,176,243),(0,176,240),(252,176,243),
(0,176,240),(252,176,243),(0,176,240),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142),
(255,255,0),(169, 209, 142)]

chunhua_style = ColorStyle(color2, link_pairs2, point_color2)

def parse_args():
parser = argparse.ArgumentParser(description='Visualize COCO predictions')
# general
parser.add_argument('--image-path',
help='Path of COCO val images',
type=str,
default='data/coco/images/val2017/'
)

parser.add_argument('--gt-anno',
help='Path of COCO val annotation',
type=str,
default='data/coco/annotations/person_keypoints_val2017.json'
)

parser.add_argument('--save-path',
help="Path to save the visualizations",
type=str,
default='visualization/coco/')

parser.add_argument('--prediction',
help="Prediction file to visualize",
type=str,
required=True)

parser.add_argument('--style',
help="Style of the visualization: Chunhua style or Xiaochu style",
type=str,
default='chunhua')

args = parser.parse_args()

return args


def map_joint_dict(joints):
joints_dict = {}
for i in range(joints.shape[0]):
x = int(joints[i][0])
y = int(joints[i][1])
id = i
joints_dict[id] = (x, y)

return joints_dict

def plot(data, gt_file, img_path, save_path,
link_pairs, ring_color, save=True):

# joints
coco = COCO(gt_file)
coco_dt = coco.loadRes(data)
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
coco_eval._prepare()
gts_ = coco_eval._gts
dts_ = coco_eval._dts

p = coco_eval.params
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)

# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
threshold = 0.3
joint_thres = 0.2
for catId in catIds:
for imgId in p.imgIds[:5000]:
# dimention here should be Nxm
gts = gts_[imgId, catId]
dts = dts_[imgId, catId]
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
dts = [dts[i] for i in inds]
if len(dts) > p.maxDets[-1]:
dts = dts[0:p.maxDets[-1]]
if len(gts) == 0 or len(dts) == 0:
continue

sum_score = 0
num_box = 0
img_name = str(imgId).zfill(12)

# Read Images
img_file = img_path + img_name + '.jpg'
data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
h = data_numpy.shape[0]
w = data_numpy.shape[1]

# Plot
fig = plt.figure(figsize=(w/100, h/100), dpi=100)
ax = plt.subplot(1,1,1)
bk = plt.imshow(data_numpy[:,:,::-1])
bk.set_zorder(-1)
print(img_name)
for j, gt in enumerate(gts):
# matching dt_box and gt_box
bb = gt['bbox']
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2

# create bounds for ignore regions(double the gt bbox)
g = np.array(gt['keypoints'])
#xg = g[0::3]; yg = g[1::3];
vg = g[2::3]

for i, dt in enumerate(dts):
# Calculate IoU
dt_bb = dt['bbox']
dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2
dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2

ol_x = min(x1, dt_x1) - max(x0, dt_x0)
ol_y = min(y1, dt_y1) - max(y0, dt_y0)
ol_area = ol_x * ol_y
s_x = max(x1, dt_x1) - min(x0, dt_x0)
s_y = max(y1, dt_y1) - min(y0, dt_y0)
sum_area = s_x * s_y
iou = ol_area / (sum_area + np.spacing(1))
score = dt['score']

if iou < 0.1 or score < threshold:
continue
else:
print('iou: ', iou)
dt_w = dt_x1 - dt_x0
dt_h = dt_y1 - dt_y0
ref = min(dt_w, dt_h)
num_box += 1
sum_score += dt['score']
dt_joints = np.array(dt['keypoints']).reshape(17,-1)
joints_dict = map_joint_dict(dt_joints)

# stick
for k, link_pair in enumerate(link_pairs):
if link_pair[0] in joints_dict \
and link_pair[1] in joints_dict:
if dt_joints[link_pair[0],2] < joint_thres \
or dt_joints[link_pair[1],2] < joint_thres \
or vg[link_pair[0]] == 0 \
or vg[link_pair[1]] == 0:
continue
if k in range(6,11):
lw = 1
else:
lw = ref / 100.
line = mlines.Line2D(
np.array([joints_dict[link_pair[0]][0],
joints_dict[link_pair[1]][0]]),
np.array([joints_dict[link_pair[0]][1],
joints_dict[link_pair[1]][1]]),
ls='-', lw=lw, alpha=1, color=link_pair[2],)
line.set_zorder(0)
ax.add_line(line)
# black ring
for k in range(dt_joints.shape[0]):
if dt_joints[k,2] < joint_thres \
or vg[link_pair[0]] == 0 \
or vg[link_pair[1]] == 0:
continue
if dt_joints[k,0] > w or dt_joints[k,1] > h:
continue
if k in range(5):
radius = 1
else:
radius = ref / 100

circle = mpatches.Circle(tuple(dt_joints[k,:2]),
radius=radius,
ec='black',
fc=ring_color[k],
alpha=1,
linewidth=1)
circle.set_zorder(1)
ax.add_patch(circle)

avg_score = (sum_score / (num_box+np.spacing(1)))*1000

plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.axis('off')
plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
plt.margins(0,0)
if save:
plt.savefig(save_path + \
'score_'+str(np.int(avg_score))+ \
'_id_'+str(imgId)+ \
'_'+img_name + '.png',
format='png', bbox_inckes='tight', dpi=100)
plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf',
bbox_inckes='tight', dpi=100)
# plt.show()
plt.close()

if __name__ == '__main__':

args = parse_args()
if args.style == 'xiaochu':
# Xiaochu Style
colorstyle = xiaochu_style
elif args.style == 'chunhua':
# Chunhua Style
colorstyle = chunhua_style
else:
raise Exception('Invalid color style')

save_path = args.save_path
img_path = args.image_path
if not os.path.exists(save_path):
try:
os.makedirs(save_path)
except Exception:
print('Fail to make {}'.format(save_path))


with open(args.prediction) as f:
data = json.load(f)
gt_file = args.gt_anno
plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True)