Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c78ae33
feat: 🚀 Added Non-Maximum Merging to Detections
Oct 13, 2023
57b12e6
Added __setitem__ to Detections and refactored the object prediction …
Oct 18, 2023
9f22273
Added standard full image inference after sliced inference to increas…
Oct 18, 2023
6f47046
Refactored merging of Detection attributes to better work with np.nda…
Oct 18, 2023
5f0dcc2
Merge branch 'develop' into add_nmm_to_detections to resolve conflicts
Apr 9, 2024
166a8da
Implement Feedback
Apr 11, 2024
b159873
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
May 6, 2024
d7e52be
NMM: Add None-checks, fix area normalization, style
May 6, 2024
bee3252
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
97c4071
NMM: Move detections merge into Detections class.
May 6, 2024
204669b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
2eb0c7c
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
LinasKo May 14, 2024
c3b77d0
Rename, remove functions, unit-test & change `merge_object_detection_…
May 14, 2024
8014e88
Test box_non_max_merge
May 14, 2024
26bafec
Test box_non_max_merge, rename threshold,to __init__
May 15, 2024
d2d50fb
renamed bbox -> xyxy
May 15, 2024
2d740bd
fix: merge_object_detection_pair
May 15, 2024
145b5fe
Rename to batch_box_non_max_merge to box_non_max_merge_batch
May 15, 2024
6c40935
box_non_max_merge: use our functions to compute iou
May 15, 2024
53f345e
Minor renaming
May 15, 2024
0e2eec0
Revert np.bool comparisons with `is`
May 15, 2024
559ef90
Simplify box_non_max_merge
May 15, 2024
f8f3647
Removed suprplus NMM code for 20% speedup
May 15, 2024
9024396
Add npt.NDarray[x] types, remove resolution_wh default val
May 17, 2024
6fbca83
Address review comments, simplify merge
May 23, 2024
db1b473
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 23, 2024
0721bc2
Remove _set_at_index
May 23, 2024
530e1d0
Address comments
May 27, 2024
2ee9e08
Renamed to group_overlapping_boxes
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: 🚀 Added Non-Maximum Merging to Detections
  • Loading branch information
mario-dg committed Oct 13, 2023
commit c78ae33e43c95e067e2ae34ff9e7616fe696cac3
107 changes: 107 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
import numpy as np

from supervision.detection.utils import (
batched_greedy_nmm,
box_iou_batch,
extract_ultralytics_masks,
get_merged_bbox,
get_merged_class_id,
get_merged_confidence,
get_merged_mask,
get_merged_tracker_id,
greedy_nmm,
non_max_suppression,
process_roboflow_result,
xywh_to_xyxy,
Expand Down Expand Up @@ -729,6 +737,105 @@ def box_area(self) -> np.ndarray:
"""
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])

def with_nmm(
    self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
    """
    Perform non-maximum merging on the current set of object detections.

    Detections are grouped with `greedy_nmm` (class-agnostic) or
    `batched_greedy_nmm` (per class). Within each group, every candidate whose
    IoU with the progressively merged box exceeds `threshold` is folded into
    the kept detection using the `get_merged_*` helpers.

    Args:
        threshold (float, optional): The intersection-over-union threshold
            to use for non-maximum merging. Defaults to 0.5.
        class_agnostic (bool, optional): Whether to perform class-agnostic
            non-maximum merging. If True, the class_id of each detection
            is ignored when grouping. Defaults to False.

    Returns:
        Detections: A new Detections object with one entry per group of
            merged detections. An empty Detections object is returned as is.

    Raises:
        AssertionError: If `confidence` is None.
            If `class_id` is None and class_agnostic is False.
    """
    if len(self) == 0:
        return self

    assert (
        self.confidence is not None
    ), "Detections confidence must be given for NMM to be executed."

    if class_agnostic:
        predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
        keep_to_merge_list = greedy_nmm(predictions, threshold)
    else:
        # Fail with the documented AssertionError instead of an
        # AttributeError on `None.reshape`.
        assert (
            self.class_id is not None
        ), "Detections class_id must be given for NMM to be executed."
        predictions = np.hstack(
            (
                self.xyxy,
                self.confidence.reshape(-1, 1),
                self.class_id.reshape(-1, 1),
            )
        )
        keep_to_merge_list = batched_greedy_nmm(predictions, threshold)

    result = []

    for keep_ind, merge_ind_list in keep_to_merge_list.items():
        # `self[keep_ind]` builds a new single-detection object on every call,
        # so it is fetched once and the merged values are written back onto
        # this one object. Assigning to a fresh `self[keep_ind]` each time
        # would silently discard the merge.
        # NOTE(review): assumes integer indexing yields a Detections whose
        # arrays have a leading dimension of 1 - confirm against __getitem__.
        merged = self[keep_ind]

        # Running merged values, kept as plain row/scalars for the helpers.
        box = self.xyxy[keep_ind]
        confidence = self.confidence[keep_ind].item()
        class_id = None if self.class_id is None else self.class_id[keep_ind].item()
        mask = None if self.mask is None else self.mask[keep_ind]
        tracker_id = (
            None if self.tracker_id is None else self.tracker_id[keep_ind].item()
        )
        changed = False

        for merge_ind in merge_ind_list:
            candidate_box = self.xyxy[merge_ind]
            iou = box_iou_batch(
                box[np.newaxis], candidate_box[np.newaxis]
            ).item()
            if iou <= threshold:
                continue

            changed = True
            box = get_merged_bbox(box, candidate_box)
            confidence = get_merged_confidence(
                confidence, self.confidence[merge_ind].item()
            )
            if class_id is not None:
                class_id = get_merged_class_id(
                    class_id, self.class_id[merge_ind].item()
                )
            if mask is not None:
                mask = get_merged_mask(mask, self.mask[merge_ind])
            if tracker_id is not None:
                tracker_id = get_merged_tracker_id(
                    tracker_id, self.tracker_id[merge_ind].item()
                )

        if changed:
            # Restore the leading length-1 dimension expected of a
            # single-detection object.
            merged.xyxy = box[np.newaxis]
            merged.confidence = np.array([confidence], dtype=self.confidence.dtype)
            if class_id is not None:
                merged.class_id = np.array([class_id], dtype=self.class_id.dtype)
            if mask is not None:
                merged.mask = mask[np.newaxis]
            if tracker_id is not None:
                merged.tracker_id = np.array(
                    [tracker_id], dtype=self.tracker_id.dtype
                )

        result.append(merged)

    return Detections.merge(result)

def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
Expand Down
17 changes: 14 additions & 3 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class InferenceSlicer:
slices in the format `(width_ratio, height_ratio)`.
iou_threshold (Optional[float]): Intersection over Union (IoU) threshold
used for non-max suppression or non-max merging.
merge_detections (Optional[bool]): Selects how overlapping detections from
different slices are combined. If `True`, Non-Maximum Merging (NMM) is
applied to the final detections; otherwise Non-Maximum Suppression (NMS)
is applied.
callback (Callable): A function that performs inference on a given image
slice and returns detections.
thread_workers (int): Number of threads for parallel execution.
Expand All @@ -53,11 +57,13 @@ def __init__(
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
iou_threshold: Optional[float] = 0.5,
merge_detections: Optional[bool] = False,
thread_workers: int = 1,
):
self.slice_wh = slice_wh
self.overlap_ratio_wh = overlap_ratio_wh
self.iou_threshold = iou_threshold
self.merge_detections = merge_detections
self.callback = callback
self.thread_workers = thread_workers
validate_inference_callback(callback=callback)
Expand Down Expand Up @@ -109,9 +115,14 @@ def __call__(self, image: np.ndarray) -> Detections:
for future in as_completed(futures):
detections_list.append(future.result())

return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)
if self.merge_detections:
return Detections.merge(detections_list=detections_list).with_nmm(
threshold=self.iou_threshold
)
else:
return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)

def _run_callback(self, image, offset) -> Detections:
"""
Expand Down
190 changes: 189 additions & 1 deletion supervision/detection/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -110,6 +110,194 @@ def non_max_suppression(
return keep[sort_index.argsort()]


def greedy_nmm(predictions: np.ndarray, threshold: float = 0.5) -> Dict[int, List[int]]:
    """
    Apply greedy version of non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.

    The highest-scoring remaining box is repeatedly selected as a "keep" box.
    Every remaining box whose IoU with it is not below `threshold` is assigned
    to that keep box and removed from further consideration.

    Args:
        predictions (np.ndarray): An array of shape `(n, 5)` containing
            the bounding boxes coordinates in format `[x1, y1, x2, y2]`
            and the confidence scores.
        threshold (float, optional): The intersection-over-union threshold
            to use for non-maximum merging. Defaults to 0.5.

    Returns:
        Dict[int, List[int]]: Mapping from prediction indices
            to keep to a list of prediction indices to be merged, ordered
            from highest to lowest confidence.
    """
    keep_to_merge_list: Dict[int, List[int]] = {}

    x1 = predictions[:, 0]
    y1 = predictions[:, 1]
    x2 = predictions[:, 2]
    y2 = predictions[:, 3]
    scores = predictions[:, 4]

    areas = (x2 - x1) * (y2 - y1)

    # Ascending by score, so the best remaining candidate is always last.
    order = scores.argsort()

    while len(order) > 0:
        idx = order[-1]
        keep_ind = idx.tolist()
        order = order[:-1]

        if len(order) == 0:
            keep_to_merge_list[keep_ind] = []
            break

        # Intersection of the keep box with every remaining box.
        xx1 = np.maximum(x1[order], x1[idx])
        yy1 = np.maximum(y1[order], y1[idx])
        xx2 = np.minimum(x2[order], x2[idx])
        yy2 = np.minimum(y2[order], y2[idx])

        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h

        union = (areas[order] - inter) + areas[idx]

        # A zero union (both boxes degenerate) would give 0/0 = NaN, which is
        # "not below threshold" and would be reported as a match. Treat it as
        # no overlap instead.
        match_metric_value = np.divide(
            inter,
            union,
            out=np.zeros_like(inter, dtype=float),
            where=union != 0,
        )

        below_threshold = match_metric_value < threshold
        matched_box_indices = order[~below_threshold][::-1]
        unmatched_indices = order[below_threshold]

        order = unmatched_indices[scores[unmatched_indices].argsort()]

        keep_to_merge_list[keep_ind] = matched_box_indices.tolist()

    return keep_to_merge_list


def batched_greedy_nmm(
    predictions: np.ndarray, threshold: float = 0.5
) -> Dict[int, List[int]]:
    """
    Run greedy non-maximum merging separately for every class id.

    Args:
        predictions (np.ndarray): An array of shape `(n, 6)` holding, per row,
            the box coordinates `[x1, y1, x2, y2]`, the confidence score and
            the class id.
        threshold (float, optional): The intersection-over-union threshold
            forwarded to `greedy_nmm`. Defaults to 0.5.

    Returns:
        Dict[int, List[int]]: Mapping from the index of each prediction to
            keep to the indices of the predictions to merge into it. All
            indices refer to rows of `predictions`.
    """
    class_column = predictions[:, 5]
    merged_groups = {}

    for class_value in np.unique(class_column):
        # Row positions (in `predictions`) of the current class.
        (positions,) = np.where(class_column == class_value)
        local_groups = greedy_nmm(predictions[positions], threshold)

        # greedy_nmm answers with indices local to the per-class subset;
        # translate them back to indices into the full array.
        to_global = positions.tolist()
        merged_groups.update(
            {
                to_global[local_keep]: [to_global[local] for local in local_merge]
                for local_keep, local_merge in local_groups.items()
            }
        )

    return merged_groups


def get_merged_bbox(bbox1: np.ndarray, bbox2: np.ndarray) -> np.ndarray:
    """
    Compute the smallest box enclosing two bounding boxes.

    Args:
        bbox1 (np.ndarray): A numpy array of shape `(4,)` holding one box in
            the format `(x_min, y_min, x_max, y_max)`.
        bbox2 (np.ndarray): A numpy array of shape `(4,)` holding one box in
            the format `(x_min, y_min, x_max, y_max)`.

    Returns:
        np.ndarray: A numpy array of shape `(4,)` with the element-wise
            minimum of the two top-left corners followed by the element-wise
            maximum of the two bottom-right corners.
    """
    min_corner = np.minimum(bbox1[:2], bbox2[:2])
    max_corner = np.maximum(bbox1[2:], bbox2[2:])
    enclosing = np.concatenate((min_corner, max_corner), axis=0)
    return enclosing


def get_merged_class_id(id1: int, id2: int) -> int:
    """
    Pick the class id that represents a merged pair of detections.

    Args:
        id1 (int): The first class id.
        id2 (int): The second class id.

    Returns:
        int: The larger of the two class ids.
    """
    # Same selection rule as the built-in max(): the first value wins unless
    # the second is strictly greater.
    if id2 > id1:
        return id2
    return id1


def get_merged_confidence(confidence1: float, confidence2: float) -> float:
    """
    Pick the confidence that represents a merged pair of detections.

    Args:
        confidence1 (float): The first confidence.
        confidence2 (float): The second confidence.

    Returns:
        float: The larger of the two confidences.
    """
    # Same selection rule as the built-in max(): the first value wins unless
    # the second is strictly greater.
    if confidence2 > confidence1:
        return confidence2
    return confidence1


def get_merged_mask(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
    """
    Combine two segmentation masks into their union.

    Args:
        mask1 (np.ndarray): A numpy array of shape `(H, W)`, where `H` and `W`
            are the mask height and width.
        mask2 (np.ndarray): A numpy array of shape `(H, W)`, where `H` and `W`
            are the mask height and width.

    Returns:
        np.ndarray: A numpy array of shape `(H, W)` that is truthy wherever
            `mask1` or `mask2` is truthy.
    """
    union_mask = np.logical_or(mask1, mask2)
    return union_mask


def get_merged_tracker_id(tracker_id1: int, tracker_id2: int) -> int:
    """
    Pick the tracker id that represents a merged pair of detections.

    Args:
        tracker_id1 (int): The first tracker id.
        tracker_id2 (int): The second tracker id.

    Returns:
        int: The larger of the two tracker ids.
    """
    # Same selection rule as the built-in max(): the first value wins unless
    # the second is strictly greater.
    if tracker_id2 > tracker_id1:
        return tracker_id2
    return tracker_id1


def clip_boxes(
boxes_xyxy: np.ndarray, frame_resolution_wh: Tuple[int, int]
) -> np.ndarray:
Expand Down