Skip to content

Commit 559ef90

Browse files
author
Linas Kondrackis
committed
Simplify box_non_max_merge
1 parent 0e2eec0 commit 559ef90

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

supervision/detection/utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def box_non_max_merge(
292292
Dict[int, List[int]]: Mapping from prediction indices
293293
to keep to a list of prediction indices to be merged.
294294
"""
295-
keep_to_merge_list = {}
295+
keep_to_merge_list: Dict[int, List[int]] = {}
296296

297297
scores = predictions[:, 4]
298298
order = scores.argsort()
@@ -307,17 +307,11 @@ def box_non_max_merge(
307307
break
308308

309309
ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
310+
ious = ious.flatten()
310311

311-
below_threshold = (ious < iou_threshold).astype(np.uint8)
312-
matched_box_indices = np.flip(order[np.where(below_threshold == 0)[0]])
313-
unmatched_indices = order[np.where(below_threshold == 1)[0]]
314-
315-
order = unmatched_indices[scores[unmatched_indices].argsort()]
316-
317-
keep_to_merge_list[idx.tolist()] = []
318-
319-
for matched_box_ind in matched_box_indices.tolist():
320-
keep_to_merge_list[idx.tolist()].append(matched_box_ind)
312+
above_threshold = ious >= iou_threshold
313+
keep_to_merge_list[idx] = np.flip(order[above_threshold]).tolist()
314+
order = order[~above_threshold]
321315

322316
return keep_to_merge_list
323317

0 commit comments

Comments
 (0)