Merged
Changes from 1 commit
29 commits
c78ae33
feat: 🚀 Added Non-Maximum Merging to Detections
Oct 13, 2023
57b12e6
Added __setitem__ to Detections and refactored the object prediction …
Oct 18, 2023
9f22273
Added standard full image inference after sliced inference to increas…
Oct 18, 2023
6f47046
Refactored merging of Detection attributes to better work with np.nda…
Oct 18, 2023
5f0dcc2
Merge branch 'develop' into add_nmm_to_detections to resolve conflicts
Apr 9, 2024
166a8da
Implement Feedback
Apr 11, 2024
b159873
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
May 6, 2024
d7e52be
NMM: Add None-checks, fix area normalization, style
May 6, 2024
bee3252
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
97c4071
NMM: Move detections merge into Detections class.
May 6, 2024
204669b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
2eb0c7c
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
LinasKo May 14, 2024
c3b77d0
Rename, remove functions, unit-test & change `merge_object_detection_…
May 14, 2024
8014e88
Test box_non_max_merge
May 14, 2024
26bafec
Test box_non_max_merge, rename threshold, to __init__
May 15, 2024
d2d50fb
renamed bbox -> xyxy
May 15, 2024
2d740bd
fix: merge_object_detection_pair
May 15, 2024
145b5fe
Rename batch_box_non_max_merge to box_non_max_merge_batch
May 15, 2024
6c40935
box_non_max_merge: use our functions to compute iou
May 15, 2024
53f345e
Minor renaming
May 15, 2024
0e2eec0
Revert np.bool comparisons with `is`
May 15, 2024
559ef90
Simplify box_non_max_merge
May 15, 2024
f8f3647
Removed surplus NMM code for 20% speedup
May 15, 2024
9024396
Add npt.NDarray[x] types, remove resolution_wh default val
May 17, 2024
6fbca83
Address review comments, simplify merge
May 23, 2024
db1b473
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 23, 2024
0721bc2
Remove _set_at_index
May 23, 2024
530e1d0
Address comments
May 27, 2024
2ee9e08
Renamed to group_overlapping_boxes
May 27, 2024
NMM: Add None-checks, fix area normalization, style
Linas Kondrackis committed May 6, 2024
commit d7e52bee264fb1b3b5c47a3f27b5eb67deae86a6
181 changes: 132 additions & 49 deletions supervision/detection/core.py
@@ -30,14 +30,16 @@
def _merge_object_detection_pair(det1: Detections, det2: Detections) -> Detections:
"""
Merges two Detections objects into a single Detections object.
Assumes each Detections contains exactly one object.

A `winning` detection is determined based on the confidence score of the two
input detections. This winning detection is then used to specify which `class_id`,
`tracker_id`, and `data` to include in the merged Detections object.
input detections. This winning detection is then used to specify which
`class_id`, `tracker_id`, and `data` to include in the merged Detections object.

The resulting `confidence` of the merged object is calculated by the weighted
contribution of each detection to the merged object.
The bounding boxes and masks of the two input detections are merged into a single
bounding box and mask, respectively.
The bounding boxes and masks of the two input detections are merged into a
single bounding box and mask, respectively.

Args:
det1 (Detections):
@@ -47,45 +49,79 @@ def _merge_object_detection_pair(det1: Detections, det2: Detections) -> Detections:

Returns:
Detections: A new Detections object, with merged attributes.

Raises:
ValueError: If the input Detections objects do not have exactly 1 detected
object.

Example:
```python
import cv2
import supervision as sv
from inference import get_model

image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = get_model(model_id="yolov8s-640")

result = model.infer(image)[0]
detections = sv.Detections.from_inference(result)

merged_detections = _merge_object_detection_pair(
detections[0], detections[1])
```
"""
assert (
len(det1) == len(det2) == 1
), "Both Detections should have exactly 1 detected object."
winning_det = det1 if det1.confidence.item() > det2.confidence.item() else det2
if len(det1) != 1 or len(det2) != 1:
raise ValueError(
"Both Detections should have exactly 1 detected object.")

if det2.confidence is None:
winning_det = det1
elif det1.confidence is None:
winning_det = det2
elif det1.confidence[0] >= det2.confidence[0]:
winning_det = det1
else:
winning_det = det2

area_det1 = (det1.xyxy[0][2] - det1.xyxy[0][0]) * (
det1.xyxy[0][3] - det1.xyxy[0][1]
)
area_det2 = (det2.xyxy[0][2] - det2.xyxy[0][0]) * (
det2.xyxy[0][3] - det2.xyxy[0][1]
)

merged_x1, merged_y1 = np.minimum(det1.xyxy[0][:2], det2.xyxy[0][:2])
merged_x2, merged_y2 = np.maximum(det1.xyxy[0][2:], det2.xyxy[0][2:])
merged_area = (merged_x2 - merged_x1) * (merged_y2 - merged_y1)

merged_conf = (
area_det1 * det1.confidence.item() + area_det2 * det2.confidence.item()
) / merged_area
merged_bbox = [np.concatenate([merged_x1, merged_y1, merged_x2, merged_y2])]
merged_class_id = winning_det.class_id.item()
merged_tracker_id = None
merged_mask = None
merged_data = None

if det1.mask and det2.mask:
merged_xy = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]])

winning_class_id = winning_det.class_id

if det1.confidence is None or det2.confidence is None:
merged_confidence = None
else:
merged_confidence = (
area_det1 * det1.confidence[0] + area_det2 * det2.confidence[0]
) / (area_det1 + area_det2)
merged_confidence = np.array([merged_confidence])

merged_mask = None
if det1.mask is not None and det2.mask is not None:
merged_mask = np.logical_or(det1.mask, det2.mask)
if det1.tracker_id and det2.tracker_id:
merged_tracker_id = winning_det.tracker_id.item()

winning_tracker_id = winning_det.tracker_id

winning_data = None
if det1.data and det2.data:
merged_data = winning_det.data
winning_data = winning_det.data

return Detections(
xyxy=merged_bbox,
xyxy=merged_xy,
mask=merged_mask,
confidence=merged_conf,
class_id=merged_class_id,
tracker_id=merged_tracker_id,
data=merged_data,
confidence=merged_confidence,
class_id=winning_class_id,
tracker_id=winning_tracker_id,
data=winning_data,
)
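
For reference, a minimal standalone sketch of the merge arithmetic used above: the union box is the element-wise min/max of the two corner pairs, and the merged confidence is the area-weighted average this commit switches to (normalizing by `area_det1 + area_det2` instead of the union area). Plain NumPy only; the boxes and scores are made up for illustration.

```python
import numpy as np

# Two overlapping single-object detections: (x1, y1, x2, y2) and confidence.
xyxy_1, conf_1 = np.array([10.0, 10.0, 50.0, 50.0]), 0.9
xyxy_2, conf_2 = np.array([30.0, 30.0, 70.0, 70.0]), 0.6

area_1 = (xyxy_1[2] - xyxy_1[0]) * (xyxy_1[3] - xyxy_1[1])  # 1600.0
area_2 = (xyxy_2[2] - xyxy_2[0]) * (xyxy_2[3] - xyxy_2[1])  # 1600.0

# Union box: element-wise min of the top-left corners, max of the bottom-right.
merged_xyxy = np.concatenate(
    [np.minimum(xyxy_1[:2], xyxy_2[:2]), np.maximum(xyxy_1[2:], xyxy_2[2:])]
)

# Confidence: each detection contributes in proportion to its own area.
merged_conf = (area_1 * conf_1 + area_2 * conf_2) / (area_1 + area_2)

print(merged_xyxy, merged_conf)  # [10. 10. 70. 70.] 0.75
```

With two equally sized boxes the result is simply the mean of the two confidences; a larger box pulls the merged confidence toward its own score.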


@@ -260,7 +296,8 @@ def from_yolov5(cls, yolov5_results) -> Detections:
detections = sv.Detections.from_yolov5(result)
```
"""
yolov5_detections_predictions = yolov5_results.pred[0].cpu().cpu().numpy()
yolov5_detections_predictions = yolov5_results.pred[0].cpu(
).cpu().numpy()

return cls(
xyxy=yolov5_detections_predictions[:, :4],
@@ -307,7 +344,8 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:

if "obb" in ultralytics_results and ultralytics_results.obb is not None:
class_id = ultralytics_results.obb.cls.cpu().numpy().astype(int)
class_names = np.array([ultralytics_results.names[i] for i in class_id])
class_names = np.array(
[ultralytics_results.names[i] for i in class_id])
oriented_box_coordinates = ultralytics_results.obb.xyxyxyxy.cpu().numpy()
return cls(
xyxy=ultralytics_results.obb.xyxy.cpu().numpy(),
@@ -323,7 +361,8 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:
)

class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
class_names = np.array([ultralytics_results.names[i] for i in class_id])
class_names = np.array([ultralytics_results.names[i]
for i in class_id])
return cls(
xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
confidence=ultralytics_results.boxes.conf.cpu().numpy(),
@@ -411,7 +450,8 @@ def from_tensorflow(
return cls(
xyxy=boxes,
confidence=tensorflow_results["detection_scores"][0].numpy(),
class_id=tensorflow_results["detection_classes"][0].numpy().astype(int),
class_id=tensorflow_results["detection_classes"][0].numpy().astype(
int),
)

@classmethod
@@ -448,7 +488,8 @@ def from_deepsparse(cls, deepsparse_results) -> Detections:
return cls(
xyxy=np.array(deepsparse_results.boxes[0]),
confidence=np.array(deepsparse_results.scores[0]),
class_id=np.array(deepsparse_results.labels[0]).astype(float).astype(int),
class_id=np.array(deepsparse_results.labels[0]).astype(
float).astype(int),
)

@classmethod
@@ -535,24 +576,29 @@ class names. If provided, the resulting Detections object will contain
Class names values can be accessed using `detections["class_name"]`.
""" # noqa: E501 // docs

class_ids = transformers_results["labels"].cpu().detach().numpy().astype(int)
class_ids = transformers_results["labels"].cpu(
).detach().numpy().astype(int)
data = {}
if id2label is not None:
class_names = np.array([id2label[class_id] for class_id in class_ids])
class_names = np.array([id2label[class_id]
for class_id in class_ids])
data[CLASS_NAME_DATA_FIELD] = class_names
if "boxes" in transformers_results:
return cls(
xyxy=transformers_results["boxes"].cpu().detach().numpy(),
confidence=transformers_results["scores"].cpu().detach().numpy(),
confidence=transformers_results["scores"].cpu(
).detach().numpy(),
class_id=class_ids,
data=data,
)
elif "masks" in transformers_results:
masks = transformers_results["masks"].cpu().detach().numpy().astype(bool)
masks = transformers_results["masks"].cpu(
).detach().numpy().astype(bool)
return cls(
xyxy=mask_to_xyxy(masks),
mask=masks,
confidence=transformers_results["scores"].cpu().detach().numpy(),
confidence=transformers_results["scores"].cpu(
).detach().numpy(),
class_id=class_ids,
data=data,
)
@@ -595,7 +641,8 @@ class IDs, and confidences of the predictions.
"""

return cls(
xyxy=detectron2_results["instances"].pred_boxes.tensor.cpu().numpy(),
xyxy=detectron2_results["instances"].pred_boxes.tensor.cpu(
).numpy(),
confidence=detectron2_results["instances"].scores.cpu().numpy(),
class_id=detectron2_results["instances"]
.pred_classes.cpu()
@@ -638,7 +685,8 @@ def from_inference(cls, roboflow_result: Union[dict, Any]) -> Detections:
Class names values can be accessed using `detections["class_name"]`.
"""
with suppress(AttributeError):
roboflow_result = roboflow_result.dict(exclude_none=True, by_alias=True)
roboflow_result = roboflow_result.dict(
exclude_none=True, by_alias=True)
xyxy, confidence, class_id, masks, trackers, data = process_roboflow_result(
roboflow_result=roboflow_result
)
@@ -730,7 +778,8 @@ def from_sam(cls, sam_result: List[dict]) -> Detections:
)

xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
mask = np.array([mask["segmentation"] for mask in sorted_generated_masks])
mask = np.array([mask["segmentation"]
for mask in sorted_generated_masks])

if np.asarray(xywh).shape[0] == 0:
return cls.empty()
@@ -957,7 +1006,8 @@ def stack_or_none(name: str):
if all(d.__getattribute__(name) is None for d in detections_list):
return None
if any(d.__getattribute__(name) is None for d in detections_list):
raise ValueError(f"All or none of the '{name}' fields must be None")
raise ValueError(
f"All or none of the '{name}' fields must be None")
return (
np.vstack([d.__getattribute__(name) for d in detections_list])
if name == "mask"
@@ -1128,6 +1178,34 @@ def __setitem__(self, key: str, value: Union[np.ndarray, List]):

self.data[key] = value

def _set_at_index(self, index: int, other: Detections):
Collaborator

Wouldn't placing this code as part of the __setitem__ method make more sense? The flow below feels quite natural to me.

detections_1 = sv.Detections(...)
detections_2 = sv.Detections(...)
detections_1[0] = detections_2[0]

Collaborator

__setitem__

detections_2[0]
detections_2["class_name"]

__getitem__

detections_2[0]
detections_2[1:3]
detections_2[[1, 2, 3]]
detections_2[[False, True, False]]
detections_2["class_name"]

Contributor

@LinasKo LinasKo May 23, 2024
_set_at_index was not required. I removed it entirely, and did not add any logic to __setitem__.

"""
Set detection values (xyxy, confidence, ...) at a specified index
to those of another Detections object, at index 0.

Args:
index (int): The index in current detection, where values
will be set.
other (Detections): Detections object with exactly one element
to set the values from.

Raises:
ValueError: If `other` is not made of exactly one element.
"""
if len(other) != 1:
raise ValueError(
"Detection to set from must have exactly one element.")

self.xyxy[index] = other.xyxy[0]
if self.mask is not None and other.mask is not None:
self.mask[index] = other.mask[0]
if self.confidence is not None and other.confidence is not None:
self.confidence[index] = other.confidence[0]
if self.class_id is not None and other.class_id is not None:
self.class_id[index] = other.class_id[0]
if self.tracker_id is not None and other.tracker_id is not None:
self.tracker_id[index] = other.tracker_id[0]
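
For context, a short sketch of how the `_set_at_index` helper added here is meant to be used: copy the values of a one-element Detections (for example, a freshly merged pair) into an existing row in place. The constructor values are made up for illustration, and note that, per the review thread above, this helper was removed again in a later commit.

```python
import numpy as np
import supervision as sv

detections = sv.Detections(
    xyxy=np.array([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]]),
    confidence=np.array([0.8, 0.7]),
    class_id=np.array([1, 1]),
)

# Integer indexing yields a one-element Detections; _set_at_index copies its
# xyxy / confidence / class_id (and mask / tracker_id when present) into row 0.
detections._set_at_index(0, detections[1])
```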

@property
def area(self) -> np.ndarray:
"""
@@ -1188,7 +1266,8 @@ def with_nms(
), "Detections confidence must be given for NMS to be executed."

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
predictions = np.hstack(
(self.xyxy, self.confidence.reshape(-1, 1)))
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you"
@@ -1244,9 +1323,14 @@ def with_nmm(
), "Detections confidence must be given for NMM to be executed."

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
predictions = np.hstack(
(self.xyxy, self.confidence.reshape(-1, 1)))
keep_to_merge_list = non_max_merge(predictions, threshold)
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you"
" intended to perform class agnostic NMM set class_agnostic=True."
)
predictions = np.hstack(
(
self.xyxy,
@@ -1257,16 +1341,15 @@
keep_to_merge_list = batch_non_max_merge(predictions, threshold)

result = []

for keep_ind, merge_ind_list in keep_to_merge_list.items():
for merge_ind in merge_ind_list:
if (
box_iou_batch(self[keep_ind].xyxy, self[merge_ind].xyxy).item()
> threshold
):
self[keep_ind] = _merge_object_detection_pair(
box_iou = box_iou_batch(
self[keep_ind].xyxy, self[merge_ind].xyxy)[0]
if box_iou > threshold:
merged_detection = _merge_object_detection_pair(
self[keep_ind], self[merge_ind]
)
self._set_at_index(keep_ind, merged_detection)
result.append(self[keep_ind])

return Detections.merge(result)
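
Finally, a hedged usage sketch of the `with_nmm` flow above: boxes whose IoU exceeds the threshold are grouped and merged rather than suppressed. The boxes and scores below are invented for illustration, and the call assumes the signature shown in this diff (`threshold`, `class_agnostic`).

```python
import numpy as np
import supervision as sv

detections = sv.Detections(
    xyxy=np.array([
        [10.0, 10.0, 60.0, 60.0],
        [15.0, 15.0, 65.0, 65.0],      # overlaps the first box (IoU ~0.68)
        [200.0, 200.0, 260.0, 260.0],  # isolated box
    ]),
    confidence=np.array([0.9, 0.8, 0.7]),
    class_id=np.array([0, 0, 1]),
)

# Non-Maximum Merging: the two overlapping class-0 boxes become one union box
# with an area-weighted confidence; the isolated box is kept unchanged.
merged = detections.with_nmm(threshold=0.5)
print(len(merged))  # expected: 2
```

Compared with `with_nms`, which would simply drop the lower-confidence overlapping box, NMM keeps its contribution to both the merged box and the merged confidence.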