Skip to content

Commit 5427046

Browse files
authored
Merge pull request #25 from mortacious/bugfix/numba_construction_failure
Fixed typing issue in numba construction
2 parents c6e1791 + 7b0890f commit 5427046

File tree

2 files changed

+40
-24
lines changed

2 files changed

+40
-24
lines changed

numba_kdtree/kd_tree.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numba.extending import overload_method
77
from .ckdtree_ctypes import ckdtree as ckdtree_ct
88
import warnings
9+
from typing import Optional, Any
910

1011

1112
__all__ = ["KDTree"]
@@ -82,49 +83,61 @@ def __new__(cls, ckdtree, root_bbox, data, idx):
8283
data,
8384
idx)
8485
@property
85-
def root_bbox(self):
86+
def root_bbox(self) -> DataArray:
8687
return _KDTree_get_root_bbox(self)
8788

8889
@property
89-
def data(self):
90+
def data(self) -> DataArray:
9091
return _KDTree_get_data(self)
9192

9293
@property
93-
def idx(self):
94+
def idx(self) -> DataArray:
9495
return _KDTree_get_idx(self)
9596

9697
@property
97-
def size(self):
98+
def size(self) -> int:
9899
return _KDTree_get_size(self)
99100

100101
@property
101-
def leafsize(self):
102+
def leafsize(self) -> int:
102103
return _KDTree_get_leafsize(self)
103104

104-
def built(self):
105+
def built(self) -> bool:
105106
return _KDTree_built(self)
106107

107-
def __del__(self):
108+
def __del__(self) -> None:
108109
try:
109110
self.free_index()
110111
except ModuleNotFoundError:
111112
# HACK: we are in the process of shutting down the interpreter so calling the external c function
112113
# might not be possible any more. For now just ignore this
113114
pass
114115

115-
def __reduce__(self):
116+
def __reduce__(self) -> Any:
116117
"""Pickle support
117118
"""
118119
args = _KDTree_reduce_args(self)
119120
return _restore_kdtree, args
120121

121-
def free_index(self):
122+
def free_index(self) -> None:
122123
_KDTree_free(self)
123124

124-
def query(self, X, k=1, p=2.0, eps=0.0, distance_upper_bound=np.inf, workers=None):
125+
def query(self,
126+
X: DataArray,
127+
k: int = 1,
128+
p: float = 2.0,
129+
eps: float = 0.0,
130+
distance_upper_bound: float = np.inf,
131+
workers: Optional[int] = None) -> tuple[DataArray, DataArray, DataArray]:
125132
return _KDTree_query(self, X, k, p, eps, distance_upper_bound, workers=workers)
126133

127-
def query_radius(self, X, r, p=2.0, eps=0.0, return_sorted=False, return_length=False, workers=None):
134+
def query_radius(self, X: DataArray,
135+
r: float,
136+
p: float = 2.0,
137+
eps: float = 0.0,
138+
return_sorted: bool = False,
139+
return_length: bool = False,
140+
workers: Optional[int] = None) -> list[DataArray]:
128141
return _KDTree_query_radius(self, X, r, p, eps, return_sorted, return_length, workers=workers)
129142

130143
structref.define_proxy(_KDTree, KDTreeType,
@@ -454,8 +467,8 @@ def _restore_kdtree_impl_impl(tree_buffer, data, root_bbox, leafsize, indices):
454467
def _restore_kdtree(tree_buffer, data, root_bbox, leafsize, indices):
455468
return _restore_kdtree_impl(tree_buffer, data, root_bbox, leafsize, indices)
456469

457-
# constructor method
458-
def KDTree(data: DataArray, leafsize: int = 10, compact: bool = False, balanced: bool = False, root_bbox=None):
470+
# constructor function
471+
def KDTree(data: DataArray, leafsize: int = 10, compact: bool = False, balanced: bool = False, root_bbox: Optional[DataArray] = None):
459472
if data.dtype == np.float32:
460473
conv_dtype = np.float32
461474
else:
@@ -469,6 +482,7 @@ def KDTree(data: DataArray, leafsize: int = 10, compact: bool = False, balanced:
469482
mins = np.amin(data, axis=0) if n_data > 0 else np.zeros(n_features, dtype=conv_dtype)
470483
maxes = np.amax(data, axis=0) if n_data > 0 else np.zeros(n_features, dtype=conv_dtype)
471484
root_bbox = np.vstack((mins, maxes))
485+
472486
root_bbox = np.ascontiguousarray(root_bbox, dtype=conv_dtype)
473487

474488
idx = np.arange(n_data, dtype=INT_TYPE)
@@ -496,22 +510,24 @@ def KDTree_impl(data, leafsize=10, compact=False, balanced=False, root_bbox=None
496510
n_data, n_features = data.shape
497511

498512
if root_bbox is None:
499-
root_bbox = np.empty((2, 3), dtype=data.dtype)
500-
root_bbox[0] = cmax
501-
root_bbox[1] = cmin
513+
# compute the bounding box
514+
root_bbox_ = np.empty((2, 3), dtype=data.dtype)
515+
root_bbox_[0] = cmax
516+
root_bbox_[1] = cmin
502517

503518
for i in range(data.shape[0]):
504519
for j in range(data.shape[1]):
505-
if data[i, j] < root_bbox[0, j]:
506-
root_bbox[0, j] = data[i, j]
507-
if data[i, j] > root_bbox[1, j]:
508-
root_bbox[1, j] = data[i, j]
509-
# compute the bounding box
510-
root_bbox = np.ascontiguousarray(root_bbox).astype(conv_dtype)
520+
if data[i, j] < root_bbox_[0, j]:
521+
root_bbox_[0, j] = data[i, j]
522+
if data[i, j] > root_bbox_[1, j]:
523+
root_bbox_[1, j] = data[i, j]
524+
else:
525+
root_bbox_ = root_bbox
526+
root_bbox__ = np.ascontiguousarray(root_bbox_).astype(conv_dtype)
511527

512528
idx = np.arange(n_data, dtype=INT_TYPE)
513529

514-
kdtree = _make_kdtree(data, root_bbox, idx, leafsize, balanced, compact)
530+
kdtree = _make_kdtree(data, root_bbox__, idx, leafsize, balanced, compact)
515531

516532
return kdtree
517533

tests/test_kdtree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def query_numba_parallel(data, kdtree, k):
243243
def test_construct_in_numba_function(data):
244244
@nb.njit(nogil=True, fastmath=True)
245245
def construct_kdtree_in_numba(data, compact=False, balanced=False):
246-
kdtree = KDTree(data, leafsize=10, compact=compact, balanced=balanced)
246+
kdtree = KDTree(data, compact=compact, balanced=balanced)
247247
return kdtree
248248

249249
num_executions = 10

0 commit comments

Comments
 (0)