Skip to content

Commit 22e04d1

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
allow visualizer to work without metadata
Differential Revision: D18021490 fbshipit-source-id: 3d2208914ba9cb8c74b1de48185dbef5909f6977
1 parent 20bef23 commit 22e04d1

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

demo/predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
2121
parallel (bool): whether to run the model in different processes from visualization.
2222
Useful since the visualization logic can be slow.
2323
"""
24-
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
24+
self.metadata = MetadataCatalog.get(
25+
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
26+
)
2527
self.cpu_device = torch.device("cpu")
2628
self.instance_mode = instance_mode
2729

detectron2/utils/video_visualizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def draw_instance_predictions(self, frame, predictions):
8787
]
8888
colors = self._assign_colors(detected)
8989

90-
labels = _create_text_labels(classes, scores, self.metadata.thing_classes)
90+
labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
9191

9292
if self._instance_mode == ColorMode.IMAGE_BW:
9393
# any() returns uint8 tensor

detectron2/utils/visualizer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -686,20 +686,23 @@ def draw_and_connect_keypoints(self, keypoints):
686686
output (VisImage): image object with visualizations.
687687
"""
688688
visible = {}
689+
keypoint_names = self.metadata.get("keypoint_names")
689690
for idx, keypoint in enumerate(keypoints):
690691
# draw keypoint
691692
x, y, prob = keypoint
692693
if prob > _KEYPOINT_THRESHOLD:
693694
self.draw_circle((x, y), color=_RED)
694-
keypoint_name = self.metadata.keypoint_names[idx]
695-
visible[keypoint_name] = (x, y)
696-
697-
for kp0, kp1, color in self.metadata.keypoint_connection_rules:
698-
if kp0 in visible and kp1 in visible:
699-
x0, y0 = visible[kp0]
700-
x1, y1 = visible[kp1]
701-
color = tuple(x / 255.0 for x in color)
702-
self.draw_line([x0, x1], [y0, y1], color=color)
695+
if keypoint_names:
696+
keypoint_name = keypoint_names[idx]
697+
visible[keypoint_name] = (x, y)
698+
699+
if self.metadata.get("keypoint_connection_rules"):
700+
for kp0, kp1, color in self.metadata.keypoint_connection_rules:
701+
if kp0 in visible and kp1 in visible:
702+
x0, y0 = visible[kp0]
703+
x1, y1 = visible[kp1]
704+
color = tuple(x / 255.0 for x in color)
705+
self.draw_line([x0, x1], [y0, y1], color=color)
703706

704707
# draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
705708
# Note that this strategy is specific to person keypoints.

tests/test_visualizer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,15 @@ def test_overlay_rotated_instances(self):
102102
v = Visualizer(img, self.metadata)
103103
output = v.overlay_instances(boxes=rotated_boxes, labels=labels).get_image()
104104
self.assertEqual(output.shape, img.shape)
105+
106+
def test_draw_no_metadata(self):
107+
img, boxes, _, _, masks = self._random_data()
108+
num_inst = len(boxes)
109+
inst = Instances((img.shape[0], img.shape[1]))
110+
inst.pred_classes = torch.randint(0, 80, size=(num_inst,))
111+
inst.scores = torch.rand(num_inst)
112+
inst.pred_boxes = torch.from_numpy(boxes)
113+
inst.pred_masks = torch.from_numpy(np.asarray(masks))
114+
115+
v = Visualizer(img, MetadataCatalog.get("asdfasdf"))
116+
v.draw_instance_predictions(inst)

0 commit comments

Comments
 (0)