Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c78ae33
feat: 🚀 Added Non-Maximum Merging to Detections
Oct 13, 2023
57b12e6
Added __setitem__ to Detections and refactored the object prediction …
Oct 18, 2023
9f22273
Added standard full image inference after sliced inference to increas…
Oct 18, 2023
6f47046
Refactored merging of Detection attributes to better work with np.nda…
Oct 18, 2023
5f0dcc2
Merge branch 'develop' into add_nmm_to_detections to resolve conflicts
Apr 9, 2024
166a8da
Implement Feedback
Apr 11, 2024
b159873
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
May 6, 2024
d7e52be
NMM: Add None-checks, fix area normalization, style
May 6, 2024
bee3252
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
97c4071
NMM: Move detections merge into Detections class.
May 6, 2024
204669b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
2eb0c7c
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
LinasKo May 14, 2024
c3b77d0
Rename, remove functions, unit-test & change `merge_object_detection_…
May 14, 2024
8014e88
Test box_non_max_merge
May 14, 2024
26bafec
Test box_non_max_merge, rename threshold,to __init__
May 15, 2024
d2d50fb
renamed bbox -> xyxy
May 15, 2024
2d740bd
fix: merge_object_detection_pair
May 15, 2024
145b5fe
Rename to batch_box_non_max_merge to box_non_max_merge_batch
May 15, 2024
6c40935
box_non_max_merge: use our functions to compute iou
May 15, 2024
53f345e
Minor renaming
May 15, 2024
0e2eec0
Revert np.bool comparisons with `is`
May 15, 2024
559ef90
Simplify box_non_max_merge
May 15, 2024
f8f3647
Removed suprplus NMM code for 20% speedup
May 15, 2024
9024396
Add npt.NDarray[x] types, remove resolution_wh default val
May 17, 2024
6fbca83
Address review comments, simplify merge
May 23, 2024
db1b473
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 23, 2024
0721bc2
Remove _set_at_index
May 23, 2024
530e1d0
Address comments
May 27, 2024
2ee9e08
Renamed to group_overlapping_boxes
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement Feedback
  • Loading branch information
mario-dg committed Apr 11, 2024
commit 166a8da9a07b20852c4559624fe029fc87bc8751
154 changes: 95 additions & 59 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@

from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.utils import (
batched_greedy_nmm,
batch_non_max_merge,
box_iou_batch,
box_non_max_suppression,
calculate_masks_centroids,
extract_ultralytics_masks,
get_data_item,
get_merged_bbox,
get_merged_class_id,
get_merged_confidence,
get_merged_mask,
get_merged_tracker_id,
greedy_nmm,
is_data_equal,
mask_non_max_suppression,
mask_to_xyxy,
merge_data,
non_max_merge,
process_roboflow_result,
validate_detections_fields,
xywh_to_xyxy,
Expand All @@ -32,24 +27,65 @@
from supervision.utils.internal import deprecated


def _merge_object_detection_pair(det1: Detections, det2: Detections) -> Detections:
    """
    Merges two single-object Detections into a single Detections object.

    A `winning` detection is determined based on the confidence score of the two
    input detections (on a tie, `det2` wins). The winning detection specifies
    which `class_id`, `tracker_id`, and `data` are carried over to the merged
    Detections object.
    The resulting `confidence` is the area-weighted average of the two input
    confidences, so it always stays between the two input values.
    The bounding boxes are merged into the smallest box enclosing both inputs,
    and the masks (when both are present) are merged with a logical OR.

    Args:
        det1 (Detections):
            The first Detections object. Must contain exactly one detection
            and a confidence value.
        det2 (Detections):
            The second Detections object. Must contain exactly one detection
            and a confidence value.

    Returns:
        Detections: A new Detections object with a single, merged detection.

    Raises:
        AssertionError: If either input does not contain exactly one detection,
            or if either input has no confidence.
    """
    assert (
        len(det1) == len(det2) == 1
    ), "Both Detections should have exactly 1 detected object."
    assert (
        det1.confidence is not None and det2.confidence is not None
    ), "Both Detections must have a confidence to be merged."

    box1, box2 = det1.xyxy[0], det2.xyxy[0]
    conf1, conf2 = det1.confidence.item(), det2.confidence.item()
    winning_det = det1 if conf1 > conf2 else det2

    area_det1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_det2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Normalise by the sum of both areas: this is a true weighted average.
    # Dividing by the area of the merged box over-counts the overlap and can
    # push the confidence above 1.0 for strongly overlapping boxes.
    total_area = area_det1 + area_det2
    if total_area > 0:
        merged_conf = (area_det1 * conf1 + area_det2 * conf2) / total_area
    else:
        # Degenerate (zero-area) boxes: no meaningful weights, keep the best score.
        merged_conf = max(conf1, conf2)

    merged_x1, merged_y1 = np.minimum(box1[:2], box2[:2])
    merged_x2, merged_y2 = np.maximum(box1[2:], box2[2:])
    # The four corners are scalars, so build the (1, 4) array directly;
    # np.concatenate cannot join zero-dimensional values.
    merged_bbox = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]])

    # Optional fields are compared against None explicitly: the truth value of
    # a multi-element array is ambiguous, and a legitimate id of 0 is falsy.
    merged_mask = None
    if det1.mask is not None and det2.mask is not None:
        merged_mask = np.logical_or(det1.mask, det2.mask)

    merged_class_id = None
    if winning_det.class_id is not None:
        merged_class_id = np.array([winning_det.class_id.item()])

    merged_tracker_id = None
    if det1.tracker_id is not None and det2.tracker_id is not None:
        merged_tracker_id = np.array([winning_det.tracker_id.item()])

    # NOTE(review): assumes `Detections.data` is a dict, so an empty dict is
    # used for "no data" rather than None -- confirm against the Detections
    # constructor.
    merged_data = dict(winning_det.data) if det1.data and det2.data else {}

    return Detections(
        xyxy=merged_bbox,
        mask=merged_mask,
        confidence=np.array([merged_conf]),
        class_id=merged_class_id,
        tracker_id=merged_tracker_id,
        data=merged_data,
    )


Expand Down Expand Up @@ -1091,22 +1127,24 @@ def box_area(self) -> np.ndarray:
"""
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])

def with_nmm(
def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
"""
Perform non-maximum merging on the current set of object detections.
Performs non-max suppression on detection set. If the detections result
from a segmentation model, the IoU mask is applied. Otherwise, box IoU is used.

Args:
threshold (float, optional): The intersection-over-union threshold
to use for non-maximum merging. Defaults to 0.5.
to use for non-maximum suppression. The lower the value, the more
restrictive the NMS becomes. Defaults to 0.5.
class_agnostic (bool, optional): Whether to perform class-agnostic
non-maximum merging. If True, the class_id of each detection
non-maximum suppression. If True, the class_id of each detection
will be ignored. Defaults to False.

Returns:
Detections: A new Detections object containing the subset of detections
after non-maximum merging.
after non-maximum suppression.

Raises:
AssertionError: If `confidence` is None and class_agnostic is False.
Expand All @@ -1115,58 +1153,52 @@ def with_nmm(
if len(self) == 0:
return self

assert 0.0 <= threshold <= 1.0, "Threshold must be between 0 and 1."

assert (
self.confidence is not None
), "Detections confidence must be given for NMM to be executed."
), "Detections confidence must be given for NMS to be executed."

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
keep_to_merge_list = greedy_nmm(predictions, threshold)
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you"
" intended to perform class agnostic NMS set class_agnostic=True."
)
predictions = np.hstack(
(
self.xyxy,
self.confidence.reshape(-1, 1),
self.class_id.reshape(-1, 1),
)
)
keep_to_merge_list = batched_greedy_nmm(predictions, threshold)

result = []

for keep_ind, merge_ind_list in keep_to_merge_list.items():
for merge_ind in merge_ind_list:
if (
box_iou_batch(self[keep_ind].xyxy, self[merge_ind].xyxy).item()
> threshold
):
self[keep_ind] = _merge_object_detection_pair(
self[keep_ind], self[merge_ind]
)
result.append(self[keep_ind])
if self.mask is not None:
indices = mask_non_max_suppression(
predictions=predictions, masks=self.mask, iou_threshold=threshold
)
else:
indices = box_non_max_suppression(
predictions=predictions, iou_threshold=threshold
)

return Detections.merge(result)
return self[indices]

def with_nms(
def with_nmm(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
"""
Performs non-max suppression on detection set. If the detections result
from a segmentation model, the IoU mask is applied. Otherwise, box IoU is used.
Perform non-maximum merging on the current set of object detections.

Args:
threshold (float, optional): The intersection-over-union threshold
to use for non-maximum suppression. The lower the value, the more
restrictive the NMS becomes. Defaults to 0.5.
to use for non-maximum merging. Defaults to 0.5.
class_agnostic (bool, optional): Whether to perform class-agnostic
non-maximum suppression. If True, the class_id of each detection
non-maximum merging. If True, the class_id of each detection
will be ignored. Defaults to False.

Returns:
Detections: A new Detections object containing the subset of detections
after non-maximum suppression.
after non-maximum merging.

Raises:
AssertionError: If `confidence` is None and class_agnostic is False.
Expand All @@ -1175,32 +1207,36 @@ def with_nms(
if len(self) == 0:
return self

assert 0.0 <= threshold <= 1.0, "Threshold must be between 0 and 1."

assert (
self.confidence is not None
), "Detections confidence must be given for NMS to be executed."
), "Detections confidence must be given for NMM to be executed."

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
keep_to_merge_list = non_max_merge(predictions, threshold)
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you"
" intended to perform class agnostic NMS set class_agnostic=True."
)
predictions = np.hstack(
(
self.xyxy,
self.confidence.reshape(-1, 1),
self.class_id.reshape(-1, 1),
)
)
keep_to_merge_list = batch_non_max_merge(predictions, threshold)

if self.mask is not None:
indices = mask_non_max_suppression(
predictions=predictions, masks=self.mask, iou_threshold=threshold
)
else:
indices = box_non_max_suppression(
predictions=predictions, iou_threshold=threshold
)
result = []

return self[indices]
for keep_ind, merge_ind_list in keep_to_merge_list.items():
for merge_ind in merge_ind_list:
if (
box_iou_batch(self[keep_ind].xyxy, self[merge_ind].xyxy).item()
> threshold
):
self[keep_ind] = _merge_object_detection_pair(
self[keep_ind], self[merge_ind]
)
result.append(self[keep_ind])

return Detections.merge(result)
24 changes: 3 additions & 21 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ class InferenceSlicer:
slices in the format `(width_ratio, height_ratio)`.
iou_threshold (Optional[float]): Intersection over Union (IoU) threshold
used for non-max suppression.
merge_detections (Optional[bool]): Whether to merge the detection from all
slices or simply concatenate them. If `True`, Non-Maximum Merging (NMM),
otherwise Non-Maximum Suppression (NMS), is applied to the detections.
perform_standard_pred (Optional[bool]): Whether to perform inference on the
whole image in addition to the slices to increase the accuracy of
large object detection.
callback (Callable): A function that performs inference on a given image
slice and returns detections.
thread_workers (int): Number of threads for parallel execution.
Expand All @@ -59,15 +53,11 @@ def __init__(
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
iou_threshold: Optional[float] = 0.5,
merge_detections: Optional[bool] = False,
perform_standard_pred: Optional[bool] = False,
thread_workers: int = 1,
):
self.slice_wh = slice_wh
self.overlap_ratio_wh = overlap_ratio_wh
self.iou_threshold = iou_threshold
self.merge_detections = merge_detections
self.perform_standard_pred = perform_standard_pred
self.callback = callback
self.thread_workers = thread_workers

Expand Down Expand Up @@ -118,17 +108,9 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
for future in as_completed(futures):
detections_list.append(future.result())

if self.perform_standard_pred:
detections_list.append(self.callback(image))

if self.merge_detections:
return Detections.merge(detections_list=detections_list).with_nmm(
threshold=self.iou_threshold
)
else:
return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)
return Detections.merge(detections_list=detections_list).with_nms(
threshold=self.iou_threshold
)

def _run_callback(self, image, offset) -> Detections:
"""
Expand Down
Loading