Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c78ae33
feat: 🚀 Added Non-Maximum Merging to Detections
Oct 13, 2023
57b12e6
Added __setitem__ to Detections and refactored the object prediction …
Oct 18, 2023
9f22273
Added standard full image inference after sliced inference to increas…
Oct 18, 2023
6f47046
Refactored merging of Detection attributes to better work with np.nda…
Oct 18, 2023
5f0dcc2
Merge branch 'develop' into add_nmm_to_detections to resolve conflicts
Apr 9, 2024
166a8da
Implement Feedback
Apr 11, 2024
b159873
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
May 6, 2024
d7e52be
NMM: Add None-checks, fix area normalization, style
May 6, 2024
bee3252
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
97c4071
NMM: Move detections merge into Detections class.
May 6, 2024
204669b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
2eb0c7c
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
LinasKo May 14, 2024
c3b77d0
Rename, remove functions, unit-test & change `merge_object_detection_…
May 14, 2024
8014e88
Test box_non_max_merge
May 14, 2024
26bafec
Test box_non_max_merge, rename threshold,to __init__
May 15, 2024
d2d50fb
renamed bbox -> xyxy
May 15, 2024
2d740bd
fix: merge_object_detection_pair
May 15, 2024
145b5fe
Rename to batch_box_non_max_merge to box_non_max_merge_batch
May 15, 2024
6c40935
box_non_max_merge: use our functions to compute iou
May 15, 2024
53f345e
Minor renaming
May 15, 2024
0e2eec0
Revert np.bool comparisons with `is`
May 15, 2024
559ef90
Simplify box_non_max_merge
May 15, 2024
f8f3647
Removed suprplus NMM code for 20% speedup
May 15, 2024
9024396
Add npt.NDarray[x] types, remove resolution_wh default val
May 17, 2024
6fbca83
Address review comments, simplify merge
May 23, 2024
db1b473
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 23, 2024
0721bc2
Remove _set_at_index
May 23, 2024
530e1d0
Address comments
May 27, 2024
2ee9e08
Renamed to group_overlapping_boxes
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
import numpy as np

from supervision.detection.utils import (
batched_greedy_nmm,
box_iou_batch,
extract_ultralytics_masks,
get_merged_bbox,
get_merged_class_id,
get_merged_confidence,
get_merged_mask,
get_merged_tracker_id,
greedy_nmm,
non_max_suppression,
process_roboflow_result,
xywh_to_xyxy,
Expand Down Expand Up @@ -59,6 +67,27 @@ def _validate_tracker_id(tracker_id: Any, n: int) -> None:
raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")


def _merge_object_detection_pair(pred1: Detections, pred2: Detections) -> Detections:
    """
    Merge a pair of detections into one `Detections` object.

    The box, confidence and class id are combined with the corresponding
    `get_merged_*` helpers. The mask and the tracker id are merged only when
    both inputs provide them; otherwise the merged value is `None`.

    Args:
        pred1 (Detections): First detection of the pair.
        pred2 (Detections): Second detection of the pair.

    Returns:
        Detections: A new object built from the merged attributes.
    """
    merged_bbox = get_merged_bbox(pred1.xyxy, pred2.xyxy)
    merged_conf = get_merged_confidence(pred1.confidence, pred2.confidence)
    merged_class_id = get_merged_class_id(pred1.class_id, pred2.class_id)
    merged_tracker_id = None
    merged_mask = None

    # Optional fields are either None or arrays, so compare against None
    # explicitly: truth-testing an array raises for more than one element and
    # would treat a single tracker id of 0 as "missing".
    if pred1.mask is not None and pred2.mask is not None:
        merged_mask = get_merged_mask(pred1.mask, pred2.mask)
    if pred1.tracker_id is not None and pred2.tracker_id is not None:
        merged_tracker_id = get_merged_tracker_id(pred1.tracker_id, pred2.tracker_id)

    return Detections(
        xyxy=merged_bbox,
        mask=merged_mask,
        confidence=merged_conf,
        class_id=merged_class_id,
        tracker_id=merged_tracker_id,
    )


@dataclass
class Detections:
"""
Expand Down Expand Up @@ -660,6 +689,38 @@ def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:

raise ValueError(f"{anchor} is not supported.")

def __setitem__(
    self, index: Union[int, slice, List[int], np.ndarray], value: Detections
) -> None:
    """
    Overwrite the detections selected by `index` with those from `value`.

    `xyxy` is always written. Each of `mask`, `confidence`, `class_id` and
    `tracker_id` is written only when this object already holds that field
    (i.e. it is not `None`); the matching field of `value` is assigned as is.

    Args:
        index (Union[int, slice, List[int], np.ndarray]):
            The position or positions to overwrite. A plain `int` is
            wrapped in a list before indexing.
        value (Detections): The detections supplying the replacement values.

    Example:
        ```python
        >>> import supervision as sv

        >>> detections = sv.Detections(...)

        >>> detections[0] = sv.Detections(...)
        ```
    """
    target = [index] if isinstance(index, int) else index

    self.xyxy[target] = value.xyxy

    # Same order as the fields are declared above; untouched when absent here.
    for field_name in ("mask", "confidence", "class_id", "tracker_id"):
        current = getattr(self, field_name)
        if current is not None:
            current[target] = getattr(value, field_name)

def __getitem__(
self, index: Union[int, slice, List[int], np.ndarray]
) -> Detections:
Expand Down Expand Up @@ -729,6 +790,64 @@ def box_area(self) -> np.ndarray:
"""
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])

def with_nmm(
    self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
    """
    Perform non-maximum merging on the current set of object detections.

    Groups of overlapping detections are computed by `greedy_nmm`
    (class-agnostic) or `batched_greedy_nmm` (per class). Within each group,
    every candidate whose IoU with the kept detection exceeds `threshold` is
    merged into it. This object is left unmodified.

    Args:
        threshold (float, optional): The intersection-over-union threshold
            to use for non-maximum merging. Defaults to 0.5.
        class_agnostic (bool, optional): Whether to perform class-agnostic
            non-maximum merging. If True, the class_id of each detection
            will be ignored. Defaults to False.

    Returns:
        Detections: A new Detections object containing the subset of detections
            after non-maximum merging. If there are no detections, `self` is
            returned as is.

    Raises:
        AssertionError: If `threshold` is outside `[0, 1]`.
            If `confidence` is None.
            If `class_id` is None and class_agnostic is False.
    """
    from copy import deepcopy

    if len(self) == 0:
        return self

    assert 0.0 <= threshold <= 1.0, "Threshold must be between 0 and 1."

    assert (
        self.confidence is not None
    ), "Detections confidence must be given for NMM to be executed."

    if class_agnostic:
        predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
        keep_to_merge_list = greedy_nmm(predictions, threshold)
    else:
        # Fail with the documented AssertionError instead of an
        # AttributeError from calling `.reshape` on None.
        assert self.class_id is not None, (
            "Detections class_id must be given for NMM to be executed. If you"
            " intended to perform class agnostic NMM set class_agnostic=True."
        )
        predictions = np.hstack(
            (
                self.xyxy,
                self.confidence.reshape(-1, 1),
                self.class_id.reshape(-1, 1),
            )
        )
        keep_to_merge_list = batched_greedy_nmm(predictions, threshold)

    # Merging writes intermediate results back by index; do that on a private
    # copy so the caller's detections are not altered as a side effect.
    working = deepcopy(self)
    result = []

    for keep_ind, merge_ind_list in keep_to_merge_list.items():
        for merge_ind in merge_ind_list:
            # The kept box may have grown from earlier merges, so re-check
            # the overlap against its current extent before merging.
            if (
                box_iou_batch(
                    working[keep_ind].xyxy, working[merge_ind].xyxy
                ).item()
                > threshold
            ):
                working[keep_ind] = _merge_object_detection_pair(
                    working[keep_ind], working[merge_ind]
                )
        result.append(working[keep_ind])

    return Detections.merge(result)

def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
Expand Down
24 changes: 21 additions & 3 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class InferenceSlicer:
slices in the format `(width_ratio, height_ratio)`.
iou_threshold (Optional[float]): Intersection over Union (IoU) threshold
used for non-max suppression or non-max merging.
merge_detections (Optional[bool]): Whether to merge overlapping detections
from all slices instead of suppressing them. If `True`, Non-Maximum Merging
(NMM) is applied to the combined detections; otherwise Non-Maximum
Suppression (NMS) is applied.
perform_standard_pred (Optional[bool]): Whether to perform inference on the
whole image in addition to the slices to increase the accuracy of
large object detection.
callback (Callable): A function that performs inference on a given image
slice and returns detections.
thread_workers (int): Number of threads for parallel execution.
Expand All @@ -53,11 +59,15 @@ def __init__(
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
iou_threshold: Optional[float] = 0.5,
merge_detections: Optional[bool] = False,
perform_standard_pred: Optional[bool] = False,
thread_workers: int = 1,
):
self.slice_wh = slice_wh
self.overlap_ratio_wh = overlap_ratio_wh
self.iou_threshold = iou_threshold
self.merge_detections = merge_detections
self.perform_standard_pred = perform_standard_pred
self.callback = callback
self.thread_workers = thread_workers
validate_inference_callback(callback=callback)
Expand Down Expand Up @@ -109,9 +119,17 @@ def __call__(self, image: np.ndarray) -> Detections:
for future in as_completed(futures):
detections_list.append(future.result())

return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)
if self.perform_standard_pred:
detections_list.append(self.callback(image))

if self.merge_detections:
return Detections.merge(detections_list=detections_list).with_nmm(
threshold=self.iou_threshold
)
else:
return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)

def _run_callback(self, image, offset) -> Detections:
"""
Expand Down
Loading