Last active
September 30, 2023 13:25
Revisions
-
jakevdp revised this gist
May 19, 2015 . 3 changed files with 74 additions and 51 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -12,19 +12,21 @@ that's not obvious to me. Timings ------- These are the results on my Macbook Pro, running Python 3.4: ``` jakesmac $ python ball_tree_python.py ------------------------------------------------------- 5 neighbors of 1000 points in 3 dimensions random seed = 9742 results match: True True sklearn build: 0.00033 sec python build : 0.053 sec sklearn query: 0.004 sec python query : 1 sec jakesmac $ python ball_tree_numba.py @@ -33,9 +35,9 @@ jakesmac $ python ball_tree_numba.py random seed = 2772 results match: True True sklearn build: 0.0003 sec numba build : 0.00045 sec sklearn query: 0.0041 sec numba query : 0.041 sec ``` This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,12 +1,26 @@ import warnings import numpy as np class FakeJit(object): def __call__(self, *args, **kwargs): if kwargs: if args: raise ValueError() else: return self else: return args[0] from numba import jit as numba_jit #numba_jit = FakeJit() #---------------------------------------------------------------------- # Distance computations @numba_jit def rdist(X1, i1, X2, i2): d = 0 for k in range(X1.shape[1]): @@ -15,17 +29,17 @@ def rdist(X1, i1, X2, i2): return d @numba_jit def min_rdist(node_centroids, node_radius, i_node, X, j): d = rdist(node_centroids, i_node, X, j) return max(0, np.sqrt(d) - node_radius[i_node]) ** 2 #---------------------------------------------------------------------- # Heap for distances and neighbors def heap_create(N, k): distances = np.full((N, k), np.inf, dtype=float) indices = np.zeros((N, k), dtype=int) return distances, indices @@ -36,7 +50,7 @@ def heap_sort(distances, indices): return distances[i, j], indices[i, j] @numba_jit def heap_push(row, val, i_val, distances, indices): size = distances.shape[1] @@ -83,7 +97,7 @@ def heap_push(row, val, i_val, distances, indices): #---------------------------------------------------------------------- # Tools for building the tree @numba_jit def _partition_indices(data, idx_array, idx_start, idx_end, split_index): # Find the split dimension n_features = data.shape[1] @@ -127,8 +141,7 @@ def _partition_indices(data, idx_array, idx_start, idx_end, split_index): right = midindex - 1 @numba_jit def _recursive_build(i_node, idx_start, idx_end, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, @@ -152,8 +165,10 @@ def _recursive_build(i_node, idx_start, idx_end, node_idx_start[i_node] = idx_start node_idx_end[i_node] = idx_end i_child = 2 * i_node + 1 # recursively create subnodes if i_child >= n_nodes: node_is_leaf[i_node] = True if idx_end - idx_start > 2 * leaf_size: # this shouldn't happen if our memory allocation is correct. @@ -172,22 +187,21 @@ def _recursive_build(i_node, idx_start, idx_end, else: # split node and recursively construct child nodes. node_is_leaf[i_node] = False n_mid = int((idx_end + idx_start) // 2) _partition_indices(data, idx_array, idx_start, idx_end, n_mid) _recursive_build(i_child, idx_start, n_mid, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, n_nodes, leaf_size) _recursive_build(i_child + 1, n_mid, idx_end, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, n_nodes, leaf_size) #---------------------------------------------------------------------- # Tools for querying the tree @numba_jit def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end): @@ -261,8 +275,8 @@ def __init__(self, data, leaf_size=40): # determine number of levels in the tree, and from this # the number of nodes in the tree. This results in leaf nodes # with numbers of points betweeen leaf_size and 2 * leaf_size self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1) // self.leaf_size))) self.n_nodes = int(2 ** self.n_levels) - 1 # allocate arrays for storage @@ -318,17 +332,21 @@ def query(self, X, k=1, sort_results=True): #---------------------------------------------------------------------- # Testing function def test_tree(N=1000, D=3, K=5, LS=40): from time import time from sklearn.neighbors import BallTree as skBallTree rseed = np.random.randint(10000) print("-------------------------------------------------------") print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D)) print("random seed = {0}".format(rseed)) np.random.seed(rseed) X = np.random.random((N, D)) # pre-run to jit compile the code BallTree(X, leaf_size=LS).query(X, K) t0 = time() bt1 = skBallTree(X, leaf_size=LS) t1 = time() @@ -340,13 +358,14 @@ def test_tree(N=1000, D=3, K=5, LS=40): dist2, ind2 = bt2.query(X, K) t4 = time() print("results match: {0} {1}".format(np.allclose(dist1, dist2), np.allclose(ind1, ind2))) print("") print("sklearn build: {0:.2g} sec".format(t1 - t0)) print("numba build : {0:.2g} sec".format(t3 - t2)) print("") print("sklearn query: {0:.2g} sec".format(t2 - t1)) print("numba query : {0:.2g} sec".format(t4 - t3)) if __name__ == '__main__': This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,4 @@ from __future__ import division, print_function import numpy as np @@ -19,8 +20,8 @@ def __init__(self, data, leaf_size=40): # determine number of levels in the tree, and from this # the number of nodes in the tree. This results in leaf nodes # with numbers of points betweeen leaf_size and 2 * leaf_size self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1) // self.leaf_size))) self.n_nodes = int(2 ** self.n_levels) - 1 # allocate arrays for storage @@ -60,7 +61,7 @@ def _recursive_build(self, i_node, idx_start, idx_end): else: # split node and recursively construct child nodes. self.node_is_leaf[i_node] = False n_mid = int((idx_end + idx_start) // 2) _partition_indices(self.data, self.idx_array, idx_start, idx_end, n_mid) self._recursive_build(2 * i_node + 1, idx_start, n_mid) @@ -71,8 +72,8 @@ def init_node(self, i_node, idx_start, idx_end): for j in range(self.n_features): self.node_centroids[i_node, j] = 0 for i in range(idx_start, idx_end): self.node_centroids[i_node, j] += self.data[self.idx_array[i], j] self.node_centroids[i_node, j] /= (idx_end - idx_start) # determine Node radius @@ -270,9 +271,9 @@ def test_tree(N=1000, D=3, K=5, LS=40): from sklearn.neighbors import BallTree as skBallTree rseed = np.random.randint(10000) print("-------------------------------------------------------") print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D)) print("random seed = {0}".format(rseed)) np.random.seed(rseed) X = np.random.random((N, D)) @@ -287,13 +288,14 @@ def test_tree(N=1000, D=3, K=5, LS=40): dist2, ind2 = bt2.query(X, K) t4 = time() print("results match: {0} {1}".format(np.allclose(dist1, dist2), np.allclose(ind1, ind2))) print("") print("sklearn build: {0:.2g} sec".format(t1 - t0)) print("python build : {0:.2g} sec".format(t3 - t2)) print("") print("sklearn query: {0:.2g} sec".format(t2 - t1)) print("python query : {0:.2g} sec".format(t4 - t3)) if __name__ == '__main__': -
jakevdp revised this gist
Mar 21, 2013 . 3 changed files with 4 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -16,7 +16,7 @@ These are the results on my Macbook Pro: ``` jakesmac $ python ball_tree_python.py ------------------------------------------------------- 5 neighbors of 1000 points in 3 dimensions random seed = 9742 results match: True True @@ -29,7 +29,7 @@ python query : 2.3 sec jakesmac $ python ball_tree_numba.py ------------------------------------------------------- 5 neighbors of 1000 points in 3 dimensions random seed = 2772 results match: True True This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -324,7 +324,7 @@ def test_tree(N=1000, D=3, K=5, LS=40): rseed = np.random.randint(10000) print "-------------------------------------------------------" print "%i neighbors of %i points in %i dimensions" % (K, N, D) print "random seed =", rseed np.random.seed(rseed) X = np.random.random((N, D)) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -271,7 +271,7 @@ def test_tree(N=1000, D=3, K=5, LS=40): rseed = np.random.randint(10000) print "-------------------------------------------------------" print "%i neighbors of %i points in %i dimensions" % (K, N, D) print "random seed =", rseed np.random.seed(rseed) X = np.random.random((N, D)) -
jakevdp revised this gist
Mar 21, 2013 . 1 changed file with 353 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,353 @@ import warnings import numpy as np from numba import autojit, jit, void, double, int_ #---------------------------------------------------------------------- # Distance computations @jit(double(double[:,:], int_, double[:,:], int_)) def rdist(X1, i1, X2, i2): d = 0 for k in range(X1.shape[1]): tmp = (X1[i1, k] - X2[i2, k]) d += tmp * tmp return d @jit(double(double[:,:], double[:], int_, double[:,:], int_)) def min_rdist(node_centroids, node_radius, i_node, X, j): d = rdist(node_centroids, i_node, X, j) return max(0, np.sqrt(d) - node_radius[i_node]) ** 2 #---------------------------------------------------------------------- # Heap for distances and neighbors def heap_create(N, k): distances = np.zeros((N, k), dtype=float) distances.fill(np.inf) indices = np.zeros((N, k), dtype=int) return distances, indices def heap_sort(distances, indices): i = np.arange(len(distances), dtype=int)[:, None] j = np.argsort(distances, 1) return distances[i, j], indices[i, j] @jit(void(int_, double, int_, double[:,:], int_[:,:])) def heap_push(row, val, i_val, distances, indices): size = distances.shape[1] # check if val should be in heap if val > distances[row, 0]: return # insert val at position zero distances[row, 0] = val indices[row, 0] = i_val #descend the heap, swapping values until the max heap criterion is met i = 0 while True: ic1 = 2 * i + 1 ic2 = ic1 + 1 if ic1 >= size: break elif ic2 >= size: if distances[row, ic1] > val: i_swap = ic1 else: break elif distances[row, ic1] >= distances[row, ic2]: if val < distances[row, ic1]: i_swap = ic1 else: break else: if val < distances[row, ic2]: i_swap = ic2 else: break distances[row, i] = distances[row, i_swap] indices[row, i] = indices[row, i_swap] i = i_swap distances[row, i] = val indices[row, i] = i_val #---------------------------------------------------------------------- # Tools for building the tree @jit(void(double[:,:], int_[:], int_, int_, int_)) def _partition_indices(data, idx_array, idx_start, idx_end, split_index): # Find the split dimension n_features = data.shape[1] split_dim = 0 max_spread = 0 for j in range(n_features): max_val = -np.inf min_val = np.inf for i in range(idx_start, idx_end): val = data[idx_array[i], j] max_val = max(max_val, val) min_val = min(min_val, val) if max_val - min_val > max_spread: max_spread = max_val - min_val split_dim = j # Partition using the split dimension left = idx_start right = idx_end - 1 while True: midindex = left for i in range(left, right): d1 = data[idx_array[i], split_dim] d2 = data[idx_array[right], split_dim] if d1 < d2: tmp = idx_array[i] idx_array[i] = idx_array[midindex] idx_array[midindex] = tmp midindex += 1 tmp = idx_array[midindex] idx_array[midindex] = idx_array[right] idx_array[right] = tmp if midindex == split_index: break elif midindex < split_index: left = midindex + 1 else: right = midindex - 1 @jit(void(int_, int_, int_, double[:,:], double[:,:], double[:], int_[:], int_[:], int_[:], int_[:], int_, int_)) def _recursive_build(i_node, idx_start, idx_end, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, n_nodes, leaf_size): # determine Node centroid for j in range(data.shape[1]): node_centroids[i_node, j] = 0 for i in range(idx_start, idx_end): node_centroids[i_node, j] += data[idx_array[i], j] node_centroids[i_node, j] /= (idx_end - idx_start) # determine Node radius sq_radius = 0.0 for i in range(idx_start, idx_end): sq_dist = rdist(node_centroids, i_node, data, idx_array[i]) if sq_dist > sq_radius: sq_radius = sq_dist # set node properties node_radius[i_node] = np.sqrt(sq_radius) node_idx_start[i_node] = idx_start node_idx_end[i_node] = idx_end # recursively create subnodes if 2 * i_node + 1 >= n_nodes: node_is_leaf[i_node] = True if idx_end - idx_start > 2 * leaf_size: # this shouldn't happen if our memory allocation is correct. # We'll proactively prevent memory errors, but raise a # warning saying we're doing so. warnings.warn("Internal: memory layout is flawed: " "not enough nodes allocated") pass elif idx_end - idx_start < 2: # again, this shouldn't happen if our memory allocation is correct. warnings.warn("Internal: memory layout is flawed: " "too many nodes allocated") node_is_leaf[i_node] = True else: # split node and recursively construct child nodes. node_is_leaf[i_node] = False n_mid = int((idx_end + idx_start) / 2) _partition_indices(data, idx_array, idx_start, idx_end, n_mid) _recursive_build(2 * i_node + 1, idx_start, n_mid, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, n_nodes, leaf_size) _recursive_build(2 * i_node + 2, n_mid, idx_end, data, node_centroids, node_radius, idx_array, node_idx_start, node_idx_end, node_is_leaf, n_nodes, leaf_size) #---------------------------------------------------------------------- # Tools for querying the tree @jit(void(int_, double[:,:], int_, double[:,:], int_[:,:], double, double[:,:], int_[:], double[:,:], double[:], int_[:], int_[:], int_[:])) def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end): #------------------------------------------------------------ # Case 1: query point is outside node radius: # trim it from the query if sq_dist_LB > heap_distances[i_pt, 0]: pass #------------------------------------------------------------ # Case 2: this is a leaf node. Update set of nearby points elif node_is_leaf[i_node]: for i in range(node_idx_start[i_node], node_idx_end[i_node]): dist_pt = rdist(data, idx_array[i], X, i_pt) if dist_pt < heap_distances[i_pt, 0]: heap_push(i_pt, dist_pt, idx_array[i], heap_distances, heap_indices) #------------------------------------------------------------ # Case 3: Node is not a leaf. Recursively query subnodes # starting with the closest else: i1 = 2 * i_node + 1 i2 = i1 + 1 sq_dist_LB_1 = min_rdist(node_centroids, node_radius, i1, X, i_pt) sq_dist_LB_2 = min_rdist(node_centroids, node_radius, i2, X, i_pt) # recursively query subnodes if sq_dist_LB_1 <= sq_dist_LB_2: _query_recursive(i1, X, i_pt, heap_distances, heap_indices, sq_dist_LB_1, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end) _query_recursive(i2, X, i_pt, heap_distances, heap_indices, sq_dist_LB_2, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end) else: _query_recursive(i2, X, i_pt, heap_distances, heap_indices, sq_dist_LB_2, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end) _query_recursive(i1, X, i_pt, heap_distances, heap_indices, sq_dist_LB_1, data, idx_array, node_centroids, node_radius, node_is_leaf, node_idx_start, node_idx_end) #---------------------------------------------------------------------- # The Ball Tree object class BallTree(object): def __init__(self, data, leaf_size=40): self.data = data self.leaf_size = leaf_size # validate data if self.data.size == 0: raise ValueError("X is an empty array") if leaf_size < 1: raise ValueError("leaf_size must be greater than or equal to 1") self.n_samples = self.data.shape[0] self.n_features = self.data.shape[1] # determine number of levels in the tree, and from this # the number of nodes in the tree. This results in leaf nodes # with numbers of points betweeen leaf_size and 2 * leaf_size self.n_levels = (np.log2(max(1, (self.n_samples - 1) / self.leaf_size)) + 1) self.n_nodes = int(2 ** self.n_levels) - 1 # allocate arrays for storage self.idx_array = np.arange(self.n_samples, dtype=int) self.node_radius = np.zeros(self.n_nodes, dtype=float) self.node_idx_start = np.zeros(self.n_nodes, dtype=int) self.node_idx_end = np.zeros(self.n_nodes, dtype=int) self.node_is_leaf = np.zeros(self.n_nodes, dtype=int) self.node_centroids = np.zeros((self.n_nodes, self.n_features), dtype=float) # Allocate tree-specific data from TreeBase _recursive_build(0, 0, self.n_samples, self.data, self.node_centroids, self.node_radius, self.idx_array, self.node_idx_start, self.node_idx_end, self.node_is_leaf, self.n_nodes, self.leaf_size) def query(self, X, k=1, sort_results=True): X = np.asarray(X, dtype=float) if X.shape[-1] != self.n_features: raise ValueError("query data dimension must " "match training data dimension") if self.data.shape[0] < k: raise ValueError("k must be less than or equal " "to the number of training points") # flatten X, and save original shape information Xshape = X.shape X = X.reshape((-1, self.data.shape[1])) # initialize heap for neighbors heap_distances, heap_indices = heap_create(X.shape[0], k) for i in range(X.shape[0]): sq_dist_LB = min_rdist(self.node_centroids, self.node_radius, 0, X, i) _query_recursive(0, X, i, heap_distances, heap_indices, sq_dist_LB, self.data, self.idx_array, self.node_centroids, self.node_radius, self.node_is_leaf, self.node_idx_start, self.node_idx_end) distances, indices = heap_sort(heap_distances, heap_indices) distances = np.sqrt(distances) # deflatten results return (distances.reshape(Xshape[:-1] + (k,)), indices.reshape(Xshape[:-1] + (k,))) #---------------------------------------------------------------------- # Testing function def test_tree(N=1000, D=3, K=5, LS=40): from time import time from sklearn.neighbors import BallTree as skBallTree rseed = np.random.randint(10000) print "-------------------------------------------------------" print "%i neighbors of %i points in %i dimensions" % (N, D, K) print "random seed =", rseed np.random.seed(rseed) X = np.random.random((N, D)) t0 = time() bt1 = skBallTree(X, leaf_size=LS) t1 = time() dist1, ind1 = bt1.query(X, K) t2 = time() bt2 = BallTree(X, leaf_size=LS) t3 = time() dist2, ind2 = bt2.query(X, K) t4 = time() print "results match:", np.allclose(dist1, dist2), np.allclose(ind1, ind2) print print "sklearn build: %.2g sec" % (t1 - t0) print "numba build : %.2g sec" % (t3 - t2) print print "sklearn query: %.2g sec" % (t2 - t1) print "numba query : %.2g sec" % (t4 - t3) if __name__ == '__main__': test_tree() -
jakevdp created this gist
Mar 21, 2013 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,41 @@ Numba Ball Tree ================ This is a quick attempt at writing a ball tree for nearest neighbor searches using numba. I've included a pure python version, and a version with numba jit decorators. Because class support in numba is not yet complete, all the code is factored out to stand-alone functions in the numba version. The resulting code produced by numba is about ~10 times slower than the cython ball tree in scikit-learn. My guess is that part of this stems from lack of inlining in numba, while the rest is due to some sort of overhead that's not obvious to me. Timings ------- These are the results on my Macbook Pro: ``` jakesmac $ python ball_tree_python.py ------------------------------------------------------- 1000 neighbors of 3 points in 5 dimensions random seed = 9742 results match: True True sklearn build: 0.00058 sec python build : 0.11 sec sklearn query: 0.0049 sec python query : 2.3 sec jakesmac $ python ball_tree_numba.py ------------------------------------------------------- 1000 neighbors of 3 points in 5 dimensions random seed = 2772 results match: True True sklearn build: 0.00032 sec numba build : 0.0039 sec sklearn query: 0.0048 sec numba query : 0.029 sec ``` This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,300 @@ import numpy as np class BallTree(object): def __init__(self, data, leaf_size=40): self.data = data self.leaf_size = leaf_size # validate data if self.data.size == 0: raise ValueError("X is an empty array") if leaf_size < 1: raise ValueError("leaf_size must be greater than or equal to 1") self.n_samples = self.data.shape[0] self.n_features = self.data.shape[1] # determine number of levels in the tree, and from this # the number of nodes in the tree. This results in leaf nodes # with numbers of points betweeen leaf_size and 2 * leaf_size self.n_levels = (np.log2(max(1, (self.n_samples - 1) / self.leaf_size)) + 1) self.n_nodes = int(2 ** self.n_levels) - 1 # allocate arrays for storage self.idx_array = np.arange(self.n_samples, dtype=int) self.node_radius = np.zeros(self.n_nodes, dtype=float) self.node_idx_start = np.zeros(self.n_nodes, dtype=int) self.node_idx_end = np.zeros(self.n_nodes, dtype=int) self.node_is_leaf = np.zeros(self.n_nodes, dtype=int) self.node_centroids = np.zeros((self.n_nodes, self.n_features), dtype=float) # Allocate tree-specific data from TreeBase self._recursive_build(0, 0, self.n_samples) def _recursive_build(self, i_node, idx_start, idx_end): # initialize node data self.init_node(i_node, idx_start, idx_end) if 2 * i_node + 1 >= self.n_nodes: self.node_is_leaf[i_node] = True if idx_end - idx_start > 2 * self.leaf_size: # this shouldn't happen if our memory allocation is correct # we'll proactively prevent memory errors, but raise a # warning saying we're doing so. import warnings warnings.warn("Internal: memory layout is flawed: " "not enough nodes allocated") elif idx_end - idx_start < 2: # again, this shouldn't happen if our memory allocation # is correct. Raise a warning. import warnings warnings.warn("Internal: memory layout is flawed: " "too many nodes allocated") self.node_is_leaf[i_node] = True else: # split node and recursively construct child nodes. self.node_is_leaf[i_node] = False n_mid = int((idx_end + idx_start) / 2) _partition_indices(self.data, self.idx_array, idx_start, idx_end, n_mid) self._recursive_build(2 * i_node + 1, idx_start, n_mid) self._recursive_build(2 * i_node + 2, n_mid, idx_end) def init_node(self, i_node, idx_start, idx_end): # determine Node centroid for j in range(self.n_features): self.node_centroids[i_node, j] = 0 for i in range(idx_start, idx_end): self.node_centroids[i_node, j] +=\ self.data[self.idx_array[i], j] self.node_centroids[i_node, j] /= (idx_end - idx_start) # determine Node radius sq_radius = 0 for i in range(idx_start, idx_end): sq_dist = self.rdist(self.node_centroids, i_node, self.data, self.idx_array[i]) sq_radius = max(sq_radius, sq_dist) self.node_radius[i_node] = np.sqrt(sq_radius) self.node_idx_start[i_node] = idx_start self.node_idx_end[i_node] = idx_end nbrhd = self.data[self.idx_array[idx_start:idx_end]] def rdist(self, X1, i1, X2, i2): d = 0 for k in range(self.n_features): tmp = (X1[i1, k] - X2[i2, k]) d += tmp * tmp return d def min_rdist(self, i_node, X, j): d = self.rdist(self.node_centroids, i_node, X, j) return max(0, np.sqrt(d) - self.node_radius[i_node]) ** 2 def query(self, X, k=1, sort_results=True): X = np.asarray(X, dtype=float) if X.shape[-1] != self.n_features: raise ValueError("query data dimension must " "match training data dimension") if self.data.shape[0] < k: raise ValueError("k must be less than or equal " "to the number of training points") # flatten X, and save original shape information Xshape = X.shape X = X.reshape((-1, self.data.shape[1])) # initialize heap for neighbors heap = NeighborsHeap(X.shape[0], k) for i in range(X.shape[0]): sq_dist_LB = self.min_rdist(0, X, i) self._query_recursive(0, X, i, heap, sq_dist_LB) distances, indices = heap.get_arrays(sort=sort_results) distances = np.sqrt(distances) # deflatten results return (distances.reshape(Xshape[:-1] + (k,)), indices.reshape(Xshape[:-1] + (k,))) def _query_recursive(self, i_node, X, i_pt, heap, sq_dist_LB): #------------------------------------------------------------ # Case 1: query point is outside node radius: # trim it from the query if sq_dist_LB > heap.largest(i_pt): pass #------------------------------------------------------------ # Case 2: this is a leaf node. Update set of nearby points elif self.node_is_leaf[i_node]: for i in range(self.node_idx_start[i_node], self.node_idx_end[i_node]): dist_pt = self.rdist(self.data, self.idx_array[i], X, i_pt) if dist_pt < heap.largest(i_pt): heap.push(i_pt, dist_pt, self.idx_array[i]) #------------------------------------------------------------ # Case 3: Node is not a leaf. Recursively query subnodes # starting with the closest else: i1 = 2 * i_node + 1 i2 = i1 + 1 sq_dist_LB_1 = self.min_rdist(i1, X, i_pt) sq_dist_LB_2 = self.min_rdist(i2, X, i_pt) # recursively query subnodes if sq_dist_LB_1 <= sq_dist_LB_2: self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1) self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2) else: self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2) self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1) def _partition_indices(data, idx_array, idx_start, idx_end, split_index): # Find the split dimension n_features = data.shape[1] split_dim = 0 max_spread = 0 for j in range(n_features): max_val = -np.inf min_val = np.inf for i in range(idx_start, idx_end): val = data[idx_array[i], j] max_val = max(max_val, val) min_val = min(min_val, val) if max_val - min_val > max_spread: max_spread = max_val - min_val split_dim = j # Partition using the split dimension left = idx_start right = idx_end - 1 while True: midindex = left for i in range(left, right): d1 = data[idx_array[i], split_dim] d2 = data[idx_array[right], split_dim] if d1 < d2: tmp = idx_array[i] idx_array[i] = idx_array[midindex] idx_array[midindex] = tmp midindex += 1 tmp = idx_array[midindex] idx_array[midindex] = idx_array[right] idx_array[right] = tmp if midindex == split_index: break elif midindex < split_index: left = midindex + 1 else: right = midindex - 1 class NeighborsHeap: def __init__(self, n_pts, n_nbrs): self.distances = np.zeros((n_pts, n_nbrs), dtype=float) + np.inf self.indices = np.zeros((n_pts, n_nbrs), dtype=int) def get_arrays(self, sort=True): if sort: i = np.arange(len(self.distances), dtype=int)[:, None] j = np.argsort(self.distances, 1) return self.distances[i, j], self.indices[i, j] else: return self.distances, self.indices def largest(self, row): return self.distances[row, 0] def push(self, row, val, i_val): size = self.distances.shape[1] # check if val should be in heap if val > self.distances[row, 0]: return # insert val at position zero self.distances[row, 0] = val self.indices[row, 0] = i_val #descend the heap, swapping values until the max heap criterion is met i = 0 while True: ic1 = 2 * i + 1 ic2 = ic1 + 1 if ic1 >= size: break elif ic2 >= size: if self.distances[row, ic1] > val: i_swap = ic1 else: break elif self.distances[row, ic1] >= self.distances[row, ic2]: if val < self.distances[row, ic1]: i_swap = ic1 else: break else: if val < self.distances[row, ic2]: i_swap = ic2 else: break self.distances[row, i] = self.distances[row, i_swap] self.indices[row, i] = self.indices[row, i_swap] i = i_swap self.distances[row, i] = val self.indices[row, i] = i_val def test_tree(N=1000, D=3, K=5, LS=40): from time import time from sklearn.neighbors import BallTree as skBallTree rseed = np.random.randint(10000) print "-------------------------------------------------------" print "%i neighbors of %i points in %i dimensions" % (N, D, K) print "random seed =", rseed np.random.seed(rseed) X = np.random.random((N, D)) t0 = time() bt1 = skBallTree(X, leaf_size=LS) t1 = time() dist1, ind1 = bt1.query(X, K) t2 = time() bt2 = BallTree(X, leaf_size=LS) t3 = time() dist2, ind2 = bt2.query(X, K) t4 = time() print "results match:", np.allclose(dist1, dist2), np.allclose(ind1, ind2) print print "sklearn build: %.2g sec" % (t1 - t0) print "python build : %.2g sec" % (t3 - t2) print print "sklearn query: %.2g sec" % (t2 - t1) print "python query : %.2g sec" % (t4 - t3) if __name__ == '__main__': test_tree()