"""Ball tree for k-nearest-neighbor searches, compiled with numba.

The tree is stored in flat arrays using an implicit binary-heap layout:
the children of node ``i`` are nodes ``2 * i + 1`` and ``2 * i + 2``.
Distances are handled internally as squared Euclidean distances and are
only converted to true distances at the end of ``BallTree.query``.
"""
import warnings

import numpy as np
from numba import jit as numba_jit
import numba


#----------------------------------------------------------------------
# Distance computations

@numba.jit(nopython=True)
def rdist(X1, i1, X2, i2):
    """Return the squared Euclidean distance between two rows.

    Parameters
    ----------
    X1, X2 : 2-D arrays
        Point sets; the number of columns of ``X1`` is the number of
        coordinates that are compared.
    i1, i2 : int
        Row of ``X1`` and row of ``X2`` to compare.
    """
    d = 0
    for k in range(X1.shape[1]):
        tmp = (X1[i1, k] - X2[i2, k])
        d += tmp * tmp
    return d


@numba.jit(nopython=True)
def min_rdist(node_centroids, node_radius, i_node, X, j):
    """Return a lower bound on the squared distance from a point to a node.

    The bound is ``max(0, |X[j] - centroid| - radius) ** 2``: it is zero
    when the point lies inside the ball of node ``i_node``.
    """
    d = rdist(node_centroids, i_node, X, j)
    return np.square(max(0, np.sqrt(d) - node_radius[i_node]))


#----------------------------------------------------------------------
# Heap for distances and neighbors

@numba.jit(nopython=True)
def heap_create(N, k):
    """Allocate ``N`` max-heaps holding ``k`` (distance, index) pairs each.

    Returns
    -------
    distances : (N, k) float64 array
        Initialised to the largest finite float64 so that any real
        distance replaces the placeholder.
    indices : (N, k) int64 array
        Initialised to zero.
    """
    distances = np.full((N, k), np.finfo(np.float64).max)
    indices = np.zeros((N, k), dtype=np.int64)
    return distances, indices


def heap_sort(distances, indices):
    """Return copies of both arrays with each row sorted by distance.

    Rows are sorted in ascending order of ``distances``; ``indices`` is
    reordered with the same permutation.
    """
    i = np.arange(len(distances), dtype=int)[:, None]
    j = np.argsort(distances, 1)
    return distances[i, j], indices[i, j]


@numba.jit(nopython=True)
def heap_push(row, val, i_val, distances, indices):
    """Push ``(val, i_val)`` onto the max-heap stored in ``row``.

    The heap root (column 0) always holds the largest distance kept so
    far.  If ``val`` is larger than the root the pair is discarded;
    otherwise it replaces the root and is sifted down to restore the
    max-heap property.  Both arrays are modified in place.
    """
    size = distances.shape[1]

    # check if val should be in heap
    if val > distances[row, 0]:
        return

    # insert val at position zero
    distances[row, 0] = val
    indices[row, 0] = i_val

    # descend the heap, swapping values until the max heap criterion is met
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1

        if ic1 >= size:
            # no children: val has reached a leaf of the heap
            break
        elif ic2 >= size:
            # only one child exists
            if distances[row, ic1] > val:
                i_swap = ic1
            else:
                break
        elif distances[row, ic1] >= distances[row, ic2]:
            # first child is the larger one
            if val < distances[row, ic1]:
                i_swap = ic1
            else:
                break
        else:
            # second child is the larger one
            if val < distances[row, ic2]:
                i_swap = ic2
            else:
                break

        # move the larger child up and continue from its old slot
        distances[row, i] = distances[row, i_swap]
        indices[row, i] = indices[row, i_swap]

        i = i_swap

    distances[row, i] = val
    indices[row, i] = i_val


#----------------------------------------------------------------------
# Tools for building the tree

@numba.jit(nopython=True)
def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    """Partially sort ``idx_array[idx_start:idx_end]`` around ``split_index``.

    The dimension with the largest spread over the selected points is
    chosen, then a quickselect-style partition reorders ``idx_array`` in
    place so that, along that dimension, the point at ``split_index`` is
    in its sorted position: entries before it are not larger and entries
    after it are not smaller.
    """
    # Find the split dimension
    n_features = data.shape[1]

    split_dim = 0
    max_spread = 0

    for j in range(n_features):
        max_val = -np.inf
        min_val = np.inf
        for i in range(idx_start, idx_end):
            val = data[idx_array[i], j]
            max_val = max(max_val, val)
            min_val = min(min_val, val)
        if max_val - min_val > max_spread:
            max_spread = max_val - min_val
            split_dim = j

    # Partition using the split dimension
    left = idx_start
    right = idx_end - 1

    while True:
        # partition [left, right] using the element at ``right`` as pivot
        midindex = left
        for i in range(left, right):
            d1 = data[idx_array[i], split_dim]
            d2 = data[idx_array[right], split_dim]
            if d1 < d2:
                tmp = idx_array[i]
                idx_array[i] = idx_array[midindex]
                idx_array[midindex] = tmp
                midindex += 1
        # put the pivot in its final position
        tmp = idx_array[midindex]
        idx_array[midindex] = idx_array[right]
        idx_array[right] = tmp
        if midindex == split_index:
            break
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1


@numba.jit(nopython=True)
def _recursive_build(i_node, idx_start, idx_end,
                     data, node_centroids, node_radius, idx_array,
                     node_idx_start, node_idx_end, node_is_leaf,
                     n_nodes, leaf_size):
    """Build node ``i_node`` over ``idx_array[idx_start:idx_end]``.

    Fills in the centroid, radius, index range and leaf flag of the
    node, then (unless it is a leaf) splits its points at the median of
    the widest dimension and recurses into children ``2 * i_node + 1``
    and ``2 * i_node + 2``.  All output arrays are modified in place.
    """
    # determine Node centroid
    for j in range(data.shape[1]):
        node_centroids[i_node, j] = 0
        for i in range(idx_start, idx_end):
            node_centroids[i_node, j] += data[idx_array[i], j]
        node_centroids[i_node, j] /= (idx_end - idx_start)

    # determine Node radius (largest distance from centroid to a point)
    sq_radius = 0.0
    for i in range(idx_start, idx_end):
        sq_dist = rdist(node_centroids, i_node, data, idx_array[i])
        if sq_dist > sq_radius:
            sq_radius = sq_dist

    # set node properties
    node_radius[i_node] = np.sqrt(sq_radius)
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    i_child = 2 * i_node + 1

    # recursively create subnodes
    if i_child >= n_nodes:
        # No room for children: this is a leaf.  If it holds more than
        # 2 * leaf_size points, too few nodes were allocated; that
        # shouldn't happen if the memory allocation is correct.  No
        # warning is raised here because the function is compiled in
        # nopython mode.
        node_is_leaf[i_node] = True

    elif idx_end - idx_start < 2:
        # Too many nodes were allocated; again, this shouldn't happen
        # if the memory allocation is correct.  Stop splitting here.
        node_is_leaf[i_node] = True

    else:
        # split node and recursively construct child nodes.
        node_is_leaf[i_node] = False
        n_mid = int((idx_end + idx_start) // 2)
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
        _recursive_build(i_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(i_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)


#----------------------------------------------------------------------
# Tools for querying the tree

@numba.jit(nopython=True)
def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices,
                     sq_dist_LB, data, idx_array, node_centroids,
                     node_radius, node_is_leaf, node_idx_start,
                     node_idx_end):
    """Update the neighbor heap of query point ``X[i_pt]`` from a subtree.

    ``sq_dist_LB`` is the lower bound on the squared distance between
    the query point and node ``i_node``.  The heap row ``i_pt`` of
    ``heap_distances`` / ``heap_indices`` is modified in place.
    """
    #------------------------------------------------------------
    # Case 1: query point is outside node radius:
    #         trim it from the query
    if sq_dist_LB > heap_distances[i_pt, 0]:
        pass

    #------------------------------------------------------------
    # Case 2: this is a leaf node.  Update set of nearby points
    elif node_is_leaf[i_node]:
        for i in range(node_idx_start[i_node], node_idx_end[i_node]):
            dist_pt = rdist(data, idx_array[i], X, i_pt)
            if dist_pt < heap_distances[i_pt, 0]:
                heap_push(i_pt, dist_pt, idx_array[i],
                          heap_distances, heap_indices)

    #------------------------------------------------------------
    # Case 3: Node is not a leaf.  Recursively query subnodes
    #         starting with the closest
    else:
        i1 = 2 * i_node + 1
        i2 = i1 + 1
        sq_dist_LB_1 = min_rdist(node_centroids, node_radius, i1, X, i_pt)
        sq_dist_LB_2 = min_rdist(node_centroids, node_radius, i2, X, i_pt)

        # recursively query subnodes
        if sq_dist_LB_1 <= sq_dist_LB_2:
            _query_recursive(i1, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_1, data, idx_array, node_centroids,
                             node_radius, node_is_leaf, node_idx_start,
                             node_idx_end)
            _query_recursive(i2, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_2, data, idx_array, node_centroids,
                             node_radius, node_is_leaf, node_idx_start,
                             node_idx_end)
        else:
            _query_recursive(i2, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_2, data, idx_array, node_centroids,
                             node_radius, node_is_leaf, node_idx_start,
                             node_idx_end)
            _query_recursive(i1, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_1, data, idx_array, node_centroids,
                             node_radius, node_is_leaf, node_idx_start,
                             node_idx_end)


@numba.jit(nopython=True, parallel=True)
def _query_parallel(i_node, X, heap_distances, heap_indices,
                    data, idx_array, node_centroids, node_radius,
                    node_is_leaf, node_idx_start, node_idx_end):
    """Run ``_query_recursive`` from node ``i_node`` for every row of ``X``.

    The loop over query points uses ``numba.prange``; iteration ``i_pt``
    only touches heap row ``i_pt``.
    """
    for i_pt in numba.prange(X.shape[0]):
        sq_dist_LB = min_rdist(node_centroids, node_radius, i_node, X, i_pt)
        _query_recursive(i_node, X, i_pt, heap_distances, heap_indices,
                         sq_dist_LB, data, idx_array, node_centroids,
                         node_radius, node_is_leaf, node_idx_start,
                         node_idx_end)


#----------------------------------------------------------------------
# The Ball Tree object

class BallTree(object):
    """Ball tree over a fixed set of points for k-nearest-neighbor queries.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
        Training points.  Converted to a float array.
    leaf_size : int, default 40
        Target number of points per leaf; must be at least 1.

    Raises
    ------
    ValueError
        If ``data`` is empty or not 2-dimensional, or ``leaf_size < 1``.
    """

    def __init__(self, data, leaf_size=40):
        # Coerce up front (as ``query`` does for X) so that lists and
        # integer arrays are accepted.
        self.data = np.asarray(data, dtype=float)
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")

        if self.data.ndim != 2:
            raise ValueError("X must be a 2-dimensional array")

        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # determine number of levels in the tree, and from this
        # the number of nodes in the tree.  This results in leaf nodes
        # with numbers of points between leaf_size and 2 * leaf_size.
        #
        # The level count is floor(log2(ratio)) + 1, computed exactly
        # with integer arithmetic.  It must be an integer so that
        # n_nodes has the form 2 ** n_levels - 1 (always odd): the
        # build and query code assumes that whenever child 2 * i + 1
        # exists, child 2 * i + 2 exists as well, and the compiled code
        # does not bounds-check array accesses.
        ratio = max(1, int((self.n_samples - 1) // self.leaf_size))
        self.n_levels = ratio.bit_length()
        self.n_nodes = 2 ** self.n_levels - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # build the tree in place, starting from the root node
        _recursive_build(0, 0, self.n_samples, self.data,
                         self.node_centroids, self.node_radius,
                         self.idx_array, self.node_idx_start,
                         self.node_idx_end, self.node_is_leaf,
                         self.n_nodes, self.leaf_size)

    def query(self, X, k=1, sort_results=True):
        """Find the ``k`` nearest training points of each query point.

        Parameters
        ----------
        X : array-like, shape (..., n_features)
            Query points.
        k : int, default 1
            Number of neighbors; must satisfy ``1 <= k <= n_samples``.
        sort_results : bool, default True
            Accepted for interface compatibility; results are always
            returned sorted by increasing distance.

        Returns
        -------
        distances : float array, shape ``X.shape[:-1] + (k,)``
            Euclidean distances to the neighbors.
        indices : int array, shape ``X.shape[:-1] + (k,)``
            Row indices of the neighbors in the training data.

        Raises
        ------
        ValueError
            If the last dimension of ``X`` differs from the training
            data, or ``k`` is outside ``[1, n_samples]``.
        """
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")

        # A heap with zero columns would be indexed out of bounds by
        # the compiled query code, so reject it here.
        if k < 1:
            raise ValueError("k must be greater than or equal to 1")

        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap_distances, heap_indices = heap_create(X.shape[0], k)

        # query every point, starting from the root node
        _query_parallel(0, X, heap_distances, heap_indices,
                        self.data, self.idx_array, self.node_centroids,
                        self.node_radius, self.node_is_leaf,
                        self.node_idx_start, self.node_idx_end)

        distances, indices = heap_sort(heap_distances, heap_indices)
        # heaps hold squared distances; convert to true distances
        distances = np.sqrt(distances)

        # deflatten results
        return (distances.reshape(Xshape[:-1] + (k,)),
                indices.reshape(Xshape[:-1] + (k,)))


#----------------------------------------------------------------------
# Testing function

def test_tree(N=1000, D=3, K=5, LS=40):
    """Compare this BallTree with scikit-learn's on random data.

    Builds both trees on ``N`` random points in ``D`` dimensions with
    leaf size ``LS``, queries the ``K`` nearest neighbors of every
    point, and prints whether the results agree along with build and
    query timings.
    """
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    print("-------------------------------------------------------")
    print("Numba version: " + numba.__version__)

    rseed = np.random.randint(10000)
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    dist1, ind1 = bt1.query(X, K)
    t2 = time()

    bt2 = BallTree(X, leaf_size=LS)
    t3 = time()
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                          np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.3g} sec".format(t1 - t0))
    print("numba build  : {0:.3g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.3g} sec".format(t2 - t1))
    print("numba query  : {0:.3g} sec".format(t4 - t3))


if __name__ == '__main__':
    test_tree()