66from numba .extending import overload_method
77from .ckdtree_ctypes import ckdtree as ckdtree_ct
88import warnings
9+ from typing import Optional , Any
910
1011
1112__all__ = ["KDTree" ]
@@ -82,49 +83,61 @@ def __new__(cls, ckdtree, root_bbox, data, idx):
8283 data ,
8384 idx )
8485 @property
85- def root_bbox (self ):
86+ def root_bbox (self ) -> DataArray :
8687 return _KDTree_get_root_bbox (self )
8788
8889 @property
89- def data (self ):
90+ def data (self ) -> DataArray :
9091 return _KDTree_get_data (self )
9192
9293 @property
93- def idx (self ):
94+ def idx (self ) -> DataArray :
9495 return _KDTree_get_idx (self )
9596
9697 @property
97- def size (self ):
98+ def size (self ) -> int :
9899 return _KDTree_get_size (self )
99100
100101 @property
101- def leafsize (self ):
102+ def leafsize (self ) -> int :
102103 return _KDTree_get_leafsize (self )
103104
104- def built (self ):
105+ def built (self ) -> bool :
105106 return _KDTree_built (self )
106107
107- def __del__ (self ):
108+ def __del__ (self ) -> None :
108109 try :
109110 self .free_index ()
110111 except ModuleNotFoundError :
111112 # HACK: we are in the process of shutting down the interpreter so calling the external c function
112113 # might not be possible any more. For now just ignore this
113114 pass
114115
115- def __reduce__ (self ):
116+ def __reduce__ (self ) -> Any :
116117 """Pickle support
117118 """
118119 args = _KDTree_reduce_args (self )
119120 return _restore_kdtree , args
120121
121- def free_index (self ):
122+ def free_index (self ) -> None :
122123 _KDTree_free (self )
123124
124- def query (self , X , k = 1 , p = 2.0 , eps = 0.0 , distance_upper_bound = np .inf , workers = None ):
125+ def query (self ,
126+ X : DataArray ,
127+ k : int = 1 ,
128+ p : float = 2.0 ,
129+ eps : float = 0.0 ,
130+ distance_upper_bound : float = np .inf ,
131+ workers : Optional [int ] = None ) -> tuple [DataArray , DataArray , DataArray ]:
125132 return _KDTree_query (self , X , k , p , eps , distance_upper_bound , workers = workers )
126133
127- def query_radius (self , X , r , p = 2.0 , eps = 0.0 , return_sorted = False , return_length = False , workers = None ):
134+ def query_radius (self , X : DataArray ,
135+ r : float ,
136+ p : float = 2.0 ,
137+ eps : float = 0.0 ,
138+ return_sorted : bool = False ,
139+ return_length : bool = False ,
140+ workers : Optional [int ] = None ) -> list [DataArray ]:
128141 return _KDTree_query_radius (self , X , r , p , eps , return_sorted , return_length , workers = workers )
129142
130143structref .define_proxy (_KDTree , KDTreeType ,
@@ -454,8 +467,8 @@ def _restore_kdtree_impl_impl(tree_buffer, data, root_bbox, leafsize, indices):
454467def _restore_kdtree (tree_buffer , data , root_bbox , leafsize , indices ):
455468 return _restore_kdtree_impl (tree_buffer , data , root_bbox , leafsize , indices )
456469
457- # constructor method
458- def KDTree (data : DataArray , leafsize : int = 10 , compact : bool = False , balanced : bool = False , root_bbox = None ):
470+ # constructor function
471+ def KDTree (data : DataArray , leafsize : int = 10 , compact : bool = False , balanced : bool = False , root_bbox : Optional [ DataArray ] = None ):
459472 if data .dtype == np .float32 :
460473 conv_dtype = np .float32
461474 else :
@@ -469,6 +482,7 @@ def KDTree(data: DataArray, leafsize: int = 10, compact: bool = False, balanced:
469482 mins = np .amin (data , axis = 0 ) if n_data > 0 else np .zeros (n_features , dtype = conv_dtype )
470483 maxes = np .amax (data , axis = 0 ) if n_data > 0 else np .zeros (n_features , dtype = conv_dtype )
471484 root_bbox = np .vstack ((mins , maxes ))
485+
472486 root_bbox = np .ascontiguousarray (root_bbox , dtype = conv_dtype )
473487
474488 idx = np .arange (n_data , dtype = INT_TYPE )
@@ -496,22 +510,24 @@ def KDTree_impl(data, leafsize=10, compact=False, balanced=False, root_bbox=None
496510 n_data , n_features = data .shape
497511
498512 if root_bbox is None :
499- root_bbox = np .empty ((2 , 3 ), dtype = data .dtype )
500- root_bbox [0 ] = cmax
501- root_bbox [1 ] = cmin
513+ # compute the bounding box
514+ root_bbox_ = np .empty ((2 , 3 ), dtype = data .dtype )
515+ root_bbox_ [0 ] = cmax
516+ root_bbox_ [1 ] = cmin
502517
503518 for i in range (data .shape [0 ]):
504519 for j in range (data .shape [1 ]):
505- if data [i , j ] < root_bbox [0 , j ]:
506- root_bbox [0 , j ] = data [i , j ]
507- if data [i , j ] > root_bbox [1 , j ]:
508- root_bbox [1 , j ] = data [i , j ]
509- # compute the bounding box
510- root_bbox = np .ascontiguousarray (root_bbox ).astype (conv_dtype )
520+ if data [i , j ] < root_bbox_ [0 , j ]:
521+ root_bbox_ [0 , j ] = data [i , j ]
522+ if data [i , j ] > root_bbox_ [1 , j ]:
523+ root_bbox_ [1 , j ] = data [i , j ]
524+ else :
525+ root_bbox_ = root_bbox
526+ root_bbox__ = np .ascontiguousarray (root_bbox_ ).astype (conv_dtype )
511527
512528 idx = np .arange (n_data , dtype = INT_TYPE )
513529
514- kdtree = _make_kdtree (data , root_bbox , idx , leafsize , balanced , compact )
530+ kdtree = _make_kdtree (data , root_bbox__ , idx , leafsize , balanced , compact )
515531
516532 return kdtree
517533
0 commit comments