Skip to content

Commit f992216

Browse files
author
Aleexo
committed
add pattern matching
1 parent cd20887 commit f992216

File tree

3 files changed

+380
-0
lines changed

3 files changed

+380
-0
lines changed

supervision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from supervision.detection.annotate import BoxAnnotator
3737
from supervision.detection.core import Detections
3838
from supervision.detection.line_zone import LineZone, LineZoneAnnotator
39+
from supervision.detection.match_pattern import MatchPattern
3940
from supervision.detection.tools.csv_sink import CSVSink
4041
from supervision.detection.tools.inference_slicer import InferenceSlicer
4142
from supervision.detection.tools.json_sink import JSONSink
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import itertools
2+
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
3+
4+
import numpy as np
5+
6+
from supervision.detection.core import Detections
7+
from supervision.geometry.core import Position
8+
9+
10+
class _Constraint:
11+
"""
12+
A constraint is a rule that a pattern must follow. It is defined
13+
by a function and its arguments. The arguments are strings that specify an object
14+
of the pattern and one of its fields.
15+
For example, this constraint tests that the objects A and B of your pattern have
16+
the same class:
17+
```python
18+
_Constraint(lambda x, y: x == y, "A.class_id", "B.class_id")
19+
```
20+
21+
!!! tip
22+
23+
You can use a value instead of a function as the criteria. It will check that
24+
the arguments are all equal to this value. For instance, this constraint tests
25+
that the object A of your pattern has a class_id equal to 1.
26+
```python
27+
_Constraint(1, "A.class_id")
28+
```
29+
This works with any number of arguments, so you can check several objects at
30+
once:
31+
```python
32+
_Constraint(1, "A.class_id", "B.class_id")
33+
```
34+
"""
35+
36+
def __init__(
37+
self, criteria: Union[Callable[..., bool], Any], arguments: List[str]
38+
) -> None:
39+
"""
40+
Args:
41+
criteria (Callable): A function that takes N arguments and returns a
42+
boolean. Criteria can also be any value, in which case the constraint
43+
checks that every argument is equal to this value.
44+
*arguments (str): A list of N strings that will be given as arguments for
45+
the criteria. The arguments should look like "name.field". The name of
46+
the object can be any name that doesn't contain a dot (`.`). The field
47+
should be one of the following:
48+
- `xyxy`, `mask`, `class_id`, `confidence`, or `tracker_id`
49+
- one of the `Position` enum strings
50+
- a field from the `data` attribute of your detections
51+
"""
52+
validate_arguments(arguments)
53+
self.arguments = arguments
54+
if callable(criteria):
55+
self.criteria = criteria
56+
else:
57+
self.criteria = lambda *args: all(equality(arg, criteria) for arg in args)
58+
59+
60+
def validate_arguments(arguments: List[str]) -> None:
61+
for argument in arguments:
62+
if argument.count(".") != 1:
63+
raise ValueError(
64+
f"Constraint argument should be `name.field`, got: '{argument}'"
65+
)
66+
67+
68+
def equality(arg1, arg2):
69+
if isinstance(arg1, np.ndarray) or isinstance(arg2, np.ndarray):
70+
return (arg1 == arg2).all()
71+
return arg1 == arg2
72+
73+
74+
class MatchPattern:
75+
"""
76+
A pattern is a set of constraints that apply to detections. You can think of
77+
patterns as regex for detections. `MatchPattern` will return all matches that
78+
satisfy all the constraints.
79+
80+
A pattern is described as named boxes organized according to rules. Each rule is
81+
given as a constraint. For instance "BoxA and BoxB should have the same class",
82+
"Boxes A and B should overlap", etc. The constraints are functions that apply to
83+
fields from the detections (the `class_id`, the `xyxy` coordinates, etc.).
84+
85+
For example, if you want to search for a cat and a dog that have the same center
86+
point you can use the following pattern:
87+
```python
88+
import cv2
89+
import supervision as sv
90+
from ultralytics import YOLO
91+
92+
image = cv2.imread(<SOURCE_IMAGE_PATH>)
93+
model = YOLO('yolov8s.pt')
94+
95+
pattern = sv.MatchPattern(
96+
[
97+
(lambda class_id: class_id == 0, ["Cat.class_id"]), # class_id for cat is 0
98+
(1, ["Dog.class_id"]), # class_id for dog is 1
99+
(
100+
lambda dog_center, cat_center: dog_center == cat_center),
101+
["Dog.CENTER", "Cat.CENTER"]
102+
),
103+
]
104+
)
105+
106+
result = model(image)[0]
107+
detections = sv.Detections.from_ultralytics(result)
108+
matches = pattern.match(detections)
109+
```
110+
111+
This will return all the matches that satisfy the constraints. The result is a list
112+
of `Detections`. A field `match_name` is added to the Detections.data to keep
113+
track of the names in your pattern.
114+
```python
115+
first_match = matches[0]
116+
first_match["match_name"] # ["Cat", "Dog"]
117+
```
118+
"""
119+
120+
def __init__(
121+
self,
122+
constraints: List[Tuple[Union[Callable[..., bool], Any], List[str]]],
123+
):
124+
"""
125+
Args:
126+
constraints (List[Tuple[Union[Callable[..., bool], Any], List[str]]]):
127+
A list of constraints. Each constraint contains a criterion and a list
128+
of arguments:
129+
- `criteria` is a function that returns a boolean value. See
130+
`_Constraint` for more information.
131+
- arguments is a list of strings. Each argument is composed of
132+
`name.field`. The field should be one of the following:
133+
- `xyxy`, `mask`, `class_id`, `confidence`, or `tracker_id`
134+
- one of the `Position` enum strings
135+
- a field from the `data` attribute of your detections
136+
"""
137+
self._constraints: List[_Constraint] = []
138+
for constraint in constraints:
139+
criteria, arguments = constraint
140+
self.add_constraint(criteria, arguments)
141+
142+
def add_constraint(
143+
self, criteria: Union[Callable[..., bool], Any], arguments: List[str]
144+
) -> None:
145+
"""
146+
Adds a constraint to the matching pattern.
147+
Args:
148+
criteria: A function that returns a boolean value or any value you want to
149+
match with the `arguments`. See `_Constraint` for more details.
150+
arguments: A list of strings. See `_Constraint` for more details.
151+
"""
152+
self._constraints.append(_Constraint(criteria, arguments))
153+
154+
def match(self, detections: Detections) -> List[Detections]:
155+
"""
156+
Matches the pattern of the constraints to the detections.
157+
158+
Args:
159+
detections (Detections): Detections to match the pattern with.
160+
161+
Returns:
162+
List[Detections]: List of detections that match the constraints. A specific
163+
field `match_name` is added to the matches to keep track of the names
164+
specified in the pattern arguments.
165+
"""
166+
combinations = self._generate_combinations(len(detections))
167+
168+
names = self._get_names_from_constraints()
169+
index = 0
170+
while index < len(combinations):
171+
combination = dict(zip(names, combinations[index]))
172+
template_kwargs = {
173+
name: detections[int(box_index)]
174+
for name, box_index in combination.items()
175+
}
176+
for constraint in self._constraints:
177+
criteria_args = [
178+
_get_argument(template_kwargs, detections, arg)
179+
for arg in constraint.arguments
180+
]
181+
if not constraint.criteria(*criteria_args):
182+
incompatible_boxes = {
183+
arg_name: combination[arg_name]
184+
for arg_name in self._get_names_from_arguments(
185+
constraint.arguments
186+
)
187+
}
188+
filter_bool = np.ones(len(combinations), dtype=bool)
189+
for name, values in incompatible_boxes.items():
190+
filter_bool &= combinations[name] == values
191+
combinations = combinations[~filter_bool]
192+
break
193+
else:
194+
index += 1
195+
196+
results: List[Detections] = []
197+
for valid_combination in combinations:
198+
indexes = list(valid_combination)
199+
matching_boxes: Detections = detections[indexes] # type: ignore
200+
matching_boxes["match_name"] = names
201+
results.append(matching_boxes)
202+
203+
return results
204+
205+
def _get_names_from_constraints(self) -> List[str]:
206+
"""
207+
Returns the object names used in the constraints.
208+
"""
209+
arguments = [
210+
arg for constraint in self._constraints for arg in constraint.arguments
211+
]
212+
return self._get_names_from_arguments(arguments)
213+
214+
def _get_names_from_arguments(self, arguments: Iterable[str]) -> List[str]:
215+
"""
216+
Returns the object names used in the arguments. Sorted and unique.
217+
"""
218+
return sorted(
219+
list({arg.split(".")[0] if "." in arg else arg for arg in arguments})
220+
)
221+
222+
def _generate_combinations(self, num_detections) -> np.ndarray:
223+
"""
224+
Generates all the possible combinations for the pattern matching.
225+
Returns an array of shape (N, M) where N is the number of combinations and M is
226+
the number of objects in the pattern. Each row corresponds to the set of indexes
227+
from detections.
228+
"""
229+
names = self._get_names_from_constraints()
230+
return np.fromiter(
231+
itertools.permutations(range(num_detections), len(names)),
232+
np.dtype([(name, int) for name in names]),
233+
)
234+
235+
236+
def _get_argument(kwargs: Dict[str, Any], detections: Detections, argument: str) -> Any:
237+
name, subfield = argument.split(".")
238+
if subfield in ["xyxy", "mask", "class_id", "confidence", "tracker_id"]:
239+
return getattr(kwargs[name], subfield)[0]
240+
if subfield in Position.list():
241+
return kwargs[name].get_anchors_coordinates(Position[subfield])[0]
242+
if subfield in detections.data:
243+
return kwargs[name][subfield][0]
244+
raise ValueError(f"Unknown field '{subfield}' for object '{name}'")
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
import pytest
3+
4+
from supervision import Detections, MatchPattern
5+
6+
7+
@pytest.mark.parametrize(
8+
"constraints",
9+
[
10+
(
11+
[
12+
(1, ["Box1.class_id"]),
13+
(0.1, ["Box1.confidence"]),
14+
([0, 0, 15, 15], ["Box2.xyxy"]),
15+
]
16+
), # Test constraints with values
17+
(
18+
[
19+
(lambda id: id == 1, ["Box1.class_id"]),
20+
(lambda score: score == 0.1, ["Box1.confidence"]),
21+
(lambda xyxy: xyxy[3] == 15, ["Box2.xyxy"]),
22+
]
23+
), # Test constraints with functions
24+
(
25+
[
26+
(lambda id: id == 1, ["Box1.class_id"]),
27+
(lambda xyxy1, xyxy2: xyxy1[0] == xyxy2[0], ["Box1.xyxy", "Box2.xyxy"]),
28+
]
29+
), # Test constraints with multiple arguments
30+
],
31+
)
32+
def test_match_pattern(constraints):
33+
detections = Detections(
34+
xyxy=np.array(
35+
[
36+
[0, 0, 10, 10],
37+
[0, 0, 15, 15],
38+
[5, 5, 20, 20],
39+
]
40+
),
41+
confidence=np.array([0.1, 0.2, 0.3]),
42+
class_id=np.array([1, 2, 3]),
43+
)
44+
45+
expected_result = [
46+
Detections(
47+
xyxy=np.array(
48+
[
49+
[0, 0, 10, 10],
50+
[0, 0, 15, 15],
51+
]
52+
),
53+
confidence=np.array([0.1, 0.2]),
54+
class_id=np.array([1, 2]),
55+
data={"match_name": np.array(["Box1", "Box2"])},
56+
)
57+
]
58+
59+
matches = MatchPattern(constraints).match(detections)
60+
61+
assert matches == expected_result
62+
63+
64+
def test_match_pattern_with_2_results():
65+
detections = Detections(
66+
xyxy=np.array(
67+
[
68+
[0, 0, 10, 10],
69+
[0, 0, 15, 15],
70+
[5, 5, 20, 20],
71+
]
72+
),
73+
confidence=np.array([0.1, 0.2, 0.3]),
74+
class_id=np.array([1, 2, 3]),
75+
)
76+
77+
expected_result = [
78+
Detections(
79+
xyxy=np.array(
80+
[
81+
[0, 0, 10, 10],
82+
]
83+
),
84+
confidence=np.array([0.1]),
85+
class_id=np.array([1]),
86+
data={"match_name": np.array(["Box1"])},
87+
),
88+
Detections(
89+
xyxy=np.array(
90+
[
91+
[0, 0, 15, 15],
92+
]
93+
),
94+
confidence=np.array([0.2]),
95+
class_id=np.array([2]),
96+
data={"match_name": np.array(["Box1"])},
97+
),
98+
]
99+
100+
matches = MatchPattern([[lambda xyxy: xyxy[0] == 0, ["Box1.xyxy"]]]).match(
101+
detections
102+
)
103+
104+
assert matches == expected_result
105+
106+
107+
def test_add_constraint():
108+
detections = Detections(
109+
xyxy=np.array(
110+
[
111+
[0, 0, 10, 10],
112+
[0, 0, 15, 15],
113+
[5, 5, 20, 20],
114+
]
115+
),
116+
confidence=np.array([0.1, 0.2, 0.3]),
117+
class_id=np.array([1, 2, 3]),
118+
)
119+
120+
expected_result = [
121+
Detections(
122+
xyxy=np.array(
123+
[
124+
[0, 0, 10, 10],
125+
]
126+
),
127+
confidence=np.array([0.1]),
128+
class_id=np.array([1]),
129+
data={"match_name": np.array(["Box1"])},
130+
)
131+
]
132+
pattern = MatchPattern([])
133+
pattern.add_constraint(lambda id: id == 1, ["Box1.class_id"])
134+
matches = pattern.match(detections)
135+
assert matches == expected_result

0 commit comments

Comments
 (0)