diff --git a/supervision/__init__.py b/supervision/__init__.py index 16de484a3..816142b90 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -44,6 +44,7 @@ from supervision.detection.tools.smoother import DetectionsSmoother from supervision.detection.utils import ( box_iou_batch, + box_non_max_merge, box_non_max_suppression, calculate_masks_centroids, clip_boxes, diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 0ba9e4f42..f85d403d7 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -8,6 +8,8 @@ from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES from supervision.detection.utils import ( + box_iou_batch, + box_non_max_merge, box_non_max_suppression, calculate_masks_centroids, extract_ultralytics_masks, @@ -1150,3 +1152,193 @@ def with_nms( ) return self[indices] + + def with_nmm( + self, threshold: float = 0.5, class_agnostic: bool = False + ) -> Detections: + """ + Perform non-maximum merging on the current set of object detections. + + Args: + threshold (float, optional): The intersection-over-union threshold + to use for non-maximum merging. Defaults to 0.5. + class_agnostic (bool, optional): Whether to perform class-agnostic + non-maximum merging. If True, the class_id of each detection + will be ignored. Defaults to False. + + Returns: + Detections: A new Detections object containing the subset of detections + after non-maximum merging. + + Raises: + AssertionError: If `confidence` is None or `class_id` is None and + class_agnostic is False. + """ + if len(self) == 0: + return self + + assert ( + self.confidence is not None + ), "Detections confidence must be given for NMM to be executed." + + if class_agnostic: + predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1))) + else: + assert self.class_id is not None, ( + "Detections class_id must be given for NMM to be executed. 
If you" + " intended to perform class agnostic NMM set class_agnostic=True." + ) + predictions = np.hstack( + ( + self.xyxy, + self.confidence.reshape(-1, 1), + self.class_id.reshape(-1, 1), + ) + ) + + merge_groups = box_non_max_merge( + predictions=predictions, iou_threshold=threshold + ) + + result = [] + for merge_group in merge_groups: + unmerged_detections = [self[i] for i in merge_group] + merged_detections = merge_inner_detections_objects( + unmerged_detections, threshold + ) + result.append(merged_detections) + + return Detections.merge(result) + + +def merge_inner_detection_object_pair( + detections_1: Detections, detections_2: Detections +) -> Detections: + """ + Merges two Detections objects into a single Detections object. + Assumes each Detections contains exactly one object. + + A `winning` detection is determined based on the confidence score of the two + input detections. This winning detection is then used to specify which + `class_id`, `tracker_id`, and `data` to include in the merged Detections object. + + The resulting `confidence` of the merged object is calculated by the weighted + contribution of each detection to the merged object. + The bounding boxes and masks of the two input detections are merged into a + single bounding box and mask, respectively. + + Args: + detections_1 (Detections): + The first Detections object + detections_2 (Detections): + The second Detections object + + Returns: + Detections: A new Detections object, with merged attributes. + + Raises: + ValueError: If the input Detections objects do not have exactly 1 detected + object.
+ + Example: + ```python + import cv2 + import supervision as sv + from inference import get_model + + image = cv2.imread(<SOURCE_IMAGE_PATH>) + model = get_model(model_id="yolov8s-640") + + result = model.infer(image)[0] + detections = sv.Detections.from_inference(result) + + merged_detections = merge_inner_detection_object_pair( + detections[0], detections[1]) + ``` + """ + if len(detections_1) != 1 or len(detections_2) != 1: + raise ValueError("Both Detections should have exactly 1 detected object.") + + validate_fields_both_defined_or_none(detections_1, detections_2) + + xyxy_1 = detections_1.xyxy[0] + xyxy_2 = detections_2.xyxy[0] + if detections_1.confidence is None and detections_2.confidence is None: + merged_confidence = None + else: + detection_1_area = (xyxy_1[2] - xyxy_1[0]) * (xyxy_1[3] - xyxy_1[1]) + detections_2_area = (xyxy_2[2] - xyxy_2[0]) * (xyxy_2[3] - xyxy_2[1]) + merged_confidence = ( + detection_1_area * detections_1.confidence[0] + + detections_2_area * detections_2.confidence[0] + ) / (detection_1_area + detections_2_area) + merged_confidence = np.array([merged_confidence]) + + merged_x1, merged_y1 = np.minimum(xyxy_1[:2], xyxy_2[:2]) + merged_x2, merged_y2 = np.maximum(xyxy_1[2:], xyxy_2[2:]) + merged_xyxy = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]]) + + if detections_1.mask is None and detections_2.mask is None: + merged_mask = None + else: + merged_mask = np.logical_or(detections_1.mask, detections_2.mask) + + if detections_1.confidence is None and detections_2.confidence is None: + winning_detection = detections_1 + elif detections_1.confidence[0] >= detections_2.confidence[0]: + winning_detection = detections_1 + else: + winning_detection = detections_2 + + return Detections( + xyxy=merged_xyxy, + mask=merged_mask, + confidence=merged_confidence, + class_id=winning_detection.class_id, + tracker_id=winning_detection.tracker_id, + data=winning_detection.data, + ) + + +def merge_inner_detections_objects( + detections: List[Detections], threshold=0.5
+) -> Detections: + """ + Given N detections each of length 1 (exactly one object inside), combine them into a + single detection object of length 1. The contained inner object will be the merged + result of all the input detections. + + For example, this lets you merge N boxes into one big box, N masks into one mask, + etc. + """ + detections_1 = detections[0] + for detections_2 in detections[1:]: + box_iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy)[0] + if box_iou < threshold: + break + detections_1 = merge_inner_detection_object_pair(detections_1, detections_2) + return detections_1 + + +def validate_fields_both_defined_or_none( + detections_1: Detections, detections_2: Detections +) -> None: + """ + Verify that for each optional field in the Detections, both instances either have + the field set to None or both have it set to non-None values. + + `data` field is ignored. + + Raises: + ValueError: If one field is None and the other is not, for any of the fields. + """ + attributes = ["mask", "confidence", "class_id", "tracker_id"] + for attribute in attributes: + value_1 = getattr(detections_1, attribute) + value_2 = getattr(detections_2, attribute) + + if (value_1 is None) != (value_2 is None): + raise ValueError( + f"Field '{attribute}' should be consistently None or not None in both " + "Detections." + ) diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 9232089e5..74726995e 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -3,6 +3,7 @@ import cv2 import numpy as np +import numpy.typing as npt from supervision.config import CLASS_NAME_DATA_FIELD @@ -274,6 +275,91 @@ def box_non_max_suppression( return keep[sort_index.argsort()] +def group_overlapping_boxes( + predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5 +) -> List[List[int]]: + """ + Apply greedy version of non-maximum merging to avoid detecting too many + overlapping bounding boxes for a given object. 
+ + Args: + predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing + the bounding boxes coordinates in format `[x1, y1, x2, y2]` + and the confidence scores. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum merging. Defaults to 0.5. + + Returns: + List[List[int]]: Groups of prediction indices to be merged. + Each group may have 1 or more elements. + """ + merge_groups: List[List[int]] = [] + + scores = predictions[:, 4] + order = scores.argsort() + + while len(order) > 0: + idx = int(order[-1]) + + order = order[:-1] + if len(order) == 0: + merge_groups.append([idx]) + break + + merge_candidate = np.expand_dims(predictions[idx], axis=0) + ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4]) + ious = ious.flatten() + + above_threshold = ious >= iou_threshold + merge_group = [idx] + np.flip(order[above_threshold]).tolist() + merge_groups.append(merge_group) + order = order[~above_threshold] + return merge_groups + + +def box_non_max_merge( + predictions: npt.NDArray[np.float64], + iou_threshold: float = 0.5, +) -> List[List[int]]: + """ + Apply greedy version of non-maximum merging per category to avoid detecting + too many overlapping bounding boxes for a given object. + + Args: + predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)` + containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`, + the confidence scores and class_ids. Omit class_id column to allow + detections of different classes to be merged. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum merging. Defaults to 0.5. + + Returns: + List[List[int]]: Groups of prediction indices to be merged. + Each group may have 1 or more elements.
+ """ + if predictions.shape[1] == 5: + return group_overlapping_boxes(predictions, iou_threshold) + + category_ids = predictions[:, 5] + merge_groups = [] + for category_id in np.unique(category_ids): + curr_indices = np.where(category_ids == category_id)[0] + merge_class_groups = group_overlapping_boxes( + predictions[curr_indices], iou_threshold + ) + + for merge_class_group in merge_class_groups: + merge_groups.append(curr_indices[merge_class_group].tolist()) + + for merge_group in merge_groups: + if len(merge_group) == 0: + raise ValueError( + f"Empty group detected when non-max-merging " + f"detections: {merge_groups}" + ) + return merge_groups + + def clip_boxes(xyxy: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray: """ Clips bounding boxes coordinates to fit within the frame resolution. @@ -346,7 +432,7 @@ def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: `(x_min, y_min, x_max, y_max)` for each mask """ n = masks.shape[0] - bboxes = np.zeros((n, 4), dtype=int) + xyxy = np.zeros((n, 4), dtype=int) for i, mask in enumerate(masks): rows, cols = np.where(mask) @@ -354,9 +440,9 @@ def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: if len(rows) > 0 and len(cols) > 0: x_min, x_max = np.min(cols), np.max(cols) y_min, y_max = np.min(rows), np.max(rows) - bboxes[i, :] = [x_min, y_min, x_max, y_max] + xyxy[i, :] = [x_min, y_min, x_max, y_max] - return bboxes + return xyxy def mask_to_polygons(mask: np.ndarray) -> List[np.ndarray]: @@ -592,16 +678,18 @@ def process_roboflow_result( return xyxy, confidence, class_id, masks, tracker_id, data -def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray: +def move_boxes( + xyxy: npt.NDArray[np.float64], offset: npt.NDArray[np.int32] +) -> npt.NDArray[np.float64]: """ Parameters: - xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes - coordinates in format `[x1, y1, x2, y2]` + xyxy (npt.NDArray[np.float64]): An array of shape `(n, 4)` containing the + bounding boxes coordinates in 
format `[x1, y1, x2, y2]` offset (np.array): An array of shape `(2,)` containing offset values in format is `[dx, dy]`. Returns: - np.ndarray: Repositioned bounding boxes. + npt.NDArray[np.float64]: Repositioned bounding boxes. Example: ```python @@ -622,24 +710,25 @@ def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray: def move_masks( - masks: np.ndarray, - offset: np.ndarray, - resolution_wh: Tuple[int, int] = None, -) -> np.ndarray: + masks: npt.NDArray[np.bool_], + offset: npt.NDArray[np.int32], + resolution_wh: Tuple[int, int], +) -> npt.NDArray[np.bool_]: """ Offset the masks in an array by the specified (x, y) amount. Args: - masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. - Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the - dimensions of each mask. - offset (np.ndarray): An array of shape `(2,)` containing non-negative int values - `[dx, dy]`. + masks (npt.NDArray[np.bool_]): A 3D array of binary masks corresponding to the + predictions. Shape: `(N, H, W)`, where N is the number of predictions, and + H, W are the dimensions of each mask. + offset (npt.NDArray[np.int32]): An array of shape `(2,)` containing non-negative + int values `[dx, dy]`. resolution_wh (Tuple[int, int]): The width and height of the desired mask resolution. Returns: - (np.ndarray) repositioned masks, optionally padded to the specified shape. + (npt.NDArray[np.bool_]) repositioned masks, optionally padded to the specified + shape. """ if offset[0] < 0 or offset[1] < 0: @@ -655,19 +744,21 @@ def move_masks( return mask_array -def scale_boxes(xyxy: np.ndarray, factor: float) -> np.ndarray: +def scale_boxes( + xyxy: npt.NDArray[np.float64], factor: float +) -> npt.NDArray[np.float64]: """ Scale the dimensions of bounding boxes. 
Parameters: - xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes - coordinates in format `[x1, y1, x2, y2]` + xyxy (npt.NDArray[np.float64]): An array of shape `(n, 4)` containing the + bounding boxes coordinates in format `[x1, y1, x2, y2]` factor (float): A float value representing the factor by which the box dimensions are scaled. A factor greater than 1 enlarges the boxes, while a factor less than 1 shrinks them. Returns: - np.ndarray: Scaled bounding boxes. + npt.NDArray[np.float64]: Scaled bounding boxes. Example: ```python @@ -735,19 +826,19 @@ def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray]) def merge_data( - data_list: List[Dict[str, Union[np.ndarray, List]]], -) -> Dict[str, Union[np.ndarray, List]]: + data_list: List[Dict[str, Union[npt.NDArray[np.generic], List]]], +) -> Dict[str, Union[npt.NDArray[np.generic], List]]: """ Merges the data payloads of a list of Detections instances. Args: data_list: The data payloads of the Detections instances. Each data payload is a dictionary with the same keys, and the values are either lists or - np.ndarray. + npt.NDArray[np.generic]. Returns: A single data payload containing the merged data, preserving the original data - types (list or np.ndarray). + types (list or npt.NDArray[np.generic]). 
Raises: ValueError: If data values within a single object have different lengths or if diff --git a/test/detection/test_core.py b/test/detection/test_core.py index 12f3de281..af1d58762 100644 --- a/test/detection/test_core.py +++ b/test/detection/test_core.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from supervision.detection.core import Detections +from supervision.detection.core import Detections, merge_inner_detection_object_pair from supervision.geometry.core import Position PREDICTIONS = np.array( @@ -421,3 +421,172 @@ def test_equal( detections_a: Detections, detections_b: Detections, expected_result: bool ) -> None: assert (detections_a == detections_b) == expected_result + + +@pytest.mark.parametrize( + "detection_1, detection_2, expected_result, exception", + [ + ( + mock_detections( + xyxy=[[10, 10, 30, 30]], + ), + mock_detections( + xyxy=[[10, 10, 30, 30]], + ), + mock_detections( + xyxy=[[10, 10, 30, 30]], + ), + DoesNotRaise(), + ), # Merge with self + ( + mock_detections( + xyxy=[[10, 10, 30, 30]], + ), + Detections.empty(), + None, + pytest.raises(ValueError), + ), # merge with empty: error + ( + mock_detections( + xyxy=[[10, 10, 30, 30]], + ), + mock_detections( + xyxy=[[10, 10, 30, 30], [40, 40, 60, 60]], + ), + None, + pytest.raises(ValueError), + ), # merge with 2+ objects: error + ( + mock_detections( + xyxy=[[10, 10, 30, 30]], + confidence=[0.1], + class_id=[1], + mask=[np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=bool)], + tracker_id=[1], + data={"key_1": [1]}, + ), + mock_detections( + xyxy=[[20, 20, 40, 40]], + confidence=[0.1], + class_id=[2], + mask=[np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[2], + data={"key_2": [2]}, + ), + mock_detections( + xyxy=[[10, 10, 40, 40]], + confidence=[0.1], + class_id=[1], + mask=[np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[1], + data={"key_1": [1]}, + ), + DoesNotRaise(), + ), # Same confidence - merge box & mask, tie-break to 
detection_1 + ( + mock_detections( + xyxy=[[0, 0, 20, 20]], + confidence=[0.1], + class_id=[1], + mask=[np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=bool)], + tracker_id=[1], + data={"key_1": [1]}, + ), + mock_detections( + xyxy=[[10, 10, 50, 50]], + confidence=[0.2], + class_id=[2], + mask=[np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[2], + data={"key_2": [2]}, + ), + mock_detections( + xyxy=[[0, 0, 50, 50]], + confidence=[(1 * 0.1 + 4 * 0.2) / 5], + class_id=[2], + mask=[np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[2], + data={"key_2": [2]}, + ), + DoesNotRaise(), + ), # Different confidence, different area + ( + mock_detections( + xyxy=[[10, 10, 30, 30]], + confidence=None, + class_id=[1], + mask=[np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=bool)], + tracker_id=[1], + data={"key_1": [1]}, + ), + mock_detections( + xyxy=[[20, 20, 40, 40]], + confidence=None, + class_id=[2], + mask=[np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[2], + data={"key_2": [2]}, + ), + mock_detections( + xyxy=[[10, 10, 40, 40]], + confidence=None, + class_id=[1], + mask=[np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=bool)], + tracker_id=[1], + data={"key_1": [1]}, + ), + DoesNotRaise(), + ), # No confidence at all + ( + mock_detections( + xyxy=[[0, 0, 20, 20]], + confidence=None, + ), + mock_detections( + xyxy=[[10, 10, 30, 30]], + confidence=[0.2], + ), + None, + pytest.raises(ValueError), + ), # confidence: None + [x] + ( + mock_detections( + xyxy=[[0, 0, 20, 20]], + mask=[np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=bool)], + ), + mock_detections( + xyxy=[[10, 10, 30, 30]], + mask=None, + ), + None, + pytest.raises(ValueError), + ), # mask: None + [x] + ( + mock_detections(xyxy=[[0, 0, 20, 20]], tracker_id=[1]), + mock_detections( + xyxy=[[10, 10, 30, 30]], + tracker_id=None, + ), + None, + pytest.raises(ValueError), + ), # tracker_id: None + [] + ( + mock_detections(xyxy=[[0, 0, 20, 
20]], class_id=[1]), + mock_detections( + xyxy=[[10, 10, 30, 30]], + class_id=None, + ), + None, + pytest.raises(ValueError), + ), # class_id: None + [] + ], +) +def test_merge_inner_detection_object_pair( + detection_1: Detections, + detection_2: Detections, + expected_result: Optional[Detections], + exception: Exception, +): + with exception: + result = merge_inner_detection_object_pair(detection_1, detection_2) + assert result == expected_result diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py index 097c5c6e5..b62faa619 100644 --- a/test/detection/test_utils.py +++ b/test/detection/test_utils.py @@ -11,6 +11,7 @@ clip_boxes, filter_polygons_by_area, get_data_item, + group_overlapping_boxes, mask_non_max_suppression, merge_data, move_boxes, @@ -127,6 +128,133 @@ def test_box_non_max_suppression( assert np.array_equal(result, expected_result) +@pytest.mark.parametrize( + "predictions, iou_threshold, expected_result, exception", + [ + ( + np.empty(shape=(0, 5), dtype=float), + 0.5, + [], + DoesNotRaise(), + ), + ( + np.array([[0, 0, 10, 10, 1.0]]), + 0.5, + [[0]], + DoesNotRaise(), + ), + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # High overlap, tie-break to second det + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 0.99]]), + 0.5, + [[0, 1]], + DoesNotRaise(), + ), # High overlap, merge to high confidence + ( + np.array([[0, 0, 10, 10, 0.99], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # (test symmetry) High overlap, merge to high confidence + ( + np.array([[0, 0, 10, 10, 0.90], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # (test symmetry) High overlap, merge to high confidence + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 1.0, + [[1], [0]], + DoesNotRaise(), + ), # High IOU required + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 0.0, + [[1, 0]], + DoesNotRaise(), + ), # No IOU required + ( + np.array([[0, 0, 10, 10, 
1.0], [0, 0, 5, 5, 0.9]]), + 0.25, + [[0, 1]], + DoesNotRaise(), + ), # Below IOU requirement + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]), + 0.26, + [[0], [1]], + DoesNotRaise(), + ), # Above IOU requirement + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0], [0, 0, 8, 8, 1.0]]), + 0.5, + [[2, 1, 0]], + DoesNotRaise(), + ), # 3 boxes + ( + np.array( + [ + [0, 0, 10, 10, 1.0], + [0, 0, 9, 9, 1.0], + [5, 5, 10, 10, 1.0], + [6, 6, 10, 10, 1.0], + [9, 9, 10, 10, 1.0], + ] + ), + 0.5, + [[4], [3, 2], [1, 0]], + DoesNotRaise(), + ), # 5 boxes, 2 merges, 1 separate + ( + np.array( + [ + [0, 0, 2, 1, 1.0], + [1, 0, 3, 1, 1.0], + [2, 0, 4, 1, 1.0], + [3, 0, 5, 1, 1.0], + [4, 0, 6, 1, 1.0], + ] + ), + 0.33, + [[4, 3], [2, 1], [0]], + DoesNotRaise(), + ), # sequential merge, half overlap + ( + np.array( + [ + [0, 0, 2, 1, 0.9], + [1, 0, 3, 1, 0.9], + [2, 0, 4, 1, 1.0], + [3, 0, 5, 1, 0.9], + [4, 0, 6, 1, 0.9], + ] + ), + 0.33, + [[2, 3, 1], [4], [0]], + DoesNotRaise(), + ), # confidence + ], +) +def test_group_overlapping_boxes( + predictions: np.ndarray, + iou_threshold: float, + expected_result: List[List[int]], + exception: Exception, +) -> None: + with exception: + result = group_overlapping_boxes( + predictions=predictions, iou_threshold=iou_threshold + ) + + assert result == expected_result + + @pytest.mark.parametrize( "predictions, masks, iou_threshold, expected_result, exception", [