Skip to content

Commit 8203ac3

Browse files
author
Felix Igelbrink
committed
automatic array conversion and check for invalid shape
1 parent 0351362 commit 8203ac3

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

numba_kdtree/kd_tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ def _query_radius_impl(self, X, r, p=2.0, eps=0.0, return_sorted=False, return_l
382382
else:
383383
r_ = _convert_to_valid_input/(r, 1, dtype_npy).squeeze()
384384

385+
if r_.shape != (n_queries,):
386+
raise ValueError("Invalid shape for r. Must be broadcastable to the number of queries.")
387+
385388
if p < 1:
386389
raise ValueError("Only p-norms with 1<=p<=infinity permitted")
387390

tests/test_kdtree.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def test_kdtree_query_radius_array(data, kdtree, scipy_kdtree):
179179
assert len(ii[i]) == len(ii_scipy[i])
180180
assert np.all(ii[i] == ii_scipy[i]), "Not equal for i={}".format(i)
181181

182+
# invalid array shapes should not work
183+
with pytest.raises(ValueError):
184+
ii_invalid = kdtree.query_radius(data[:100], r=radii[:50], return_sorted=True)
185+
182186
num_executions = 5
183187
r_benchmark = np.linspace(0.01, 0.05, 500)
184188
runtime_kdtree_query = timeit(lambda: kdtree.query_radius(data[:500], r=r_benchmark, return_sorted=True),

0 commit comments

Comments
 (0)