Skip to content

Instantly share code, notes, and snippets.

@jakevdp
Last active September 30, 2023 13:25

Revisions

  1. jakevdp revised this gist May 19, 2015. 3 changed files with 74 additions and 51 deletions.
    20 changes: 11 additions & 9 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -12,19 +12,21 @@ that's not obvious to me.

    Timings
    -------
    These are the results on my Macbook Pro:
    These are the results on my Macbook Pro, running Python 3.4:


    ```
    jakesmac $ python ball_tree_python.py
    -------------------------------------------------------
    5 neighbors of 1000 points in 3 dimensions
    random seed = 9742
    results match: True True
    sklearn build: 0.00058 sec
    python build : 0.11 sec
    sklearn build: 0.00033 sec
    python build : 0.053 sec
    sklearn query: 0.0049 sec
    python query : 2.3 sec
    sklearn query: 0.004 sec
    python query : 1 sec
    jakesmac $ python ball_tree_numba.py
    @@ -33,9 +35,9 @@ jakesmac $ python ball_tree_numba.py
    random seed = 2772
    results match: True True
    sklearn build: 0.00032 sec
    numba build : 0.0039 sec
    sklearn build: 0.0003 sec
    numba build : 0.00045 sec
    sklearn query: 0.0048 sec
    numba query : 0.029 sec
    sklearn query: 0.0041 sec
    numba query : 0.041 sec
    ```
    73 changes: 46 additions & 27 deletions ball_tree_numba.py
    Original file line number Diff line number Diff line change
    @@ -1,12 +1,26 @@
    import warnings
    import numpy as np
    from numba import autojit, jit, void, double, int_


    class FakeJit(object):
    def __call__(self, *args, **kwargs):
    if kwargs:
    if args:
    raise ValueError()
    else:
    return self
    else:
    return args[0]


    from numba import jit as numba_jit
    #numba_jit = FakeJit()


    #----------------------------------------------------------------------
    # Distance computations

    @jit(double(double[:,:], int_, double[:,:], int_))
    @numba_jit
    def rdist(X1, i1, X2, i2):
    d = 0
    for k in range(X1.shape[1]):
    @@ -15,17 +29,17 @@ def rdist(X1, i1, X2, i2):
    return d


    @jit(double(double[:,:], double[:], int_, double[:,:], int_))
    @numba_jit
    def min_rdist(node_centroids, node_radius, i_node, X, j):
    d = rdist(node_centroids, i_node, X, j)
    return max(0, np.sqrt(d) - node_radius[i_node]) ** 2


    #----------------------------------------------------------------------
    # Heap for distances and neighbors

    def heap_create(N, k):
    distances = np.zeros((N, k), dtype=float)
    distances.fill(np.inf)
    distances = np.full((N, k), np.inf, dtype=float)
    indices = np.zeros((N, k), dtype=int)
    return distances, indices

    @@ -36,7 +50,7 @@ def heap_sort(distances, indices):
    return distances[i, j], indices[i, j]


    @jit(void(int_, double, int_, double[:,:], int_[:,:]))
    @numba_jit
    def heap_push(row, val, i_val, distances, indices):
    size = distances.shape[1]

    @@ -83,7 +97,7 @@ def heap_push(row, val, i_val, distances, indices):
    #----------------------------------------------------------------------
    # Tools for building the tree

    @jit(void(double[:,:], int_[:], int_, int_, int_))
    @numba_jit
    def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    # Find the split dimension
    n_features = data.shape[1]
    @@ -127,8 +141,7 @@ def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    right = midindex - 1


    @jit(void(int_, int_, int_, double[:,:], double[:,:], double[:], int_[:],
    int_[:], int_[:], int_[:], int_, int_))
    @numba_jit
    def _recursive_build(i_node, idx_start, idx_end,
    data, node_centroids, node_radius, idx_array,
    node_idx_start, node_idx_end, node_is_leaf,
    @@ -152,8 +165,10 @@ def _recursive_build(i_node, idx_start, idx_end,
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    i_child = 2 * i_node + 1

    # recursively create subnodes
    if 2 * i_node + 1 >= n_nodes:
    if i_child >= n_nodes:
    node_is_leaf[i_node] = True
    if idx_end - idx_start > 2 * leaf_size:
    # this shouldn't happen if our memory allocation is correct.
    @@ -172,22 +187,21 @@ def _recursive_build(i_node, idx_start, idx_end,
    else:
    # split node and recursively construct child nodes.
    node_is_leaf[i_node] = False
    n_mid = int((idx_end + idx_start) / 2)
    n_mid = int((idx_end + idx_start) // 2)
    _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
    _recursive_build(2 * i_node + 1, idx_start, n_mid,
    _recursive_build(i_child, idx_start, n_mid,
    data, node_centroids, node_radius, idx_array,
    node_idx_start, node_idx_end, node_is_leaf,
    n_nodes, leaf_size)
    _recursive_build(2 * i_node + 2, n_mid, idx_end,
    _recursive_build(i_child + 1, n_mid, idx_end,
    data, node_centroids, node_radius, idx_array,
    node_idx_start, node_idx_end, node_is_leaf,
    n_nodes, leaf_size)


    #----------------------------------------------------------------------
    # Tools for querying the tree
    @jit(void(int_, double[:,:], int_, double[:,:], int_[:,:], double, double[:,:],
    int_[:], double[:,:], double[:], int_[:], int_[:], int_[:]))
    @numba_jit
    def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
    data, idx_array, node_centroids, node_radius,
    node_is_leaf, node_idx_start, node_idx_end):
    @@ -261,8 +275,8 @@ def __init__(self, data, leaf_size=40):
    # determine number of levels in the tree, and from this
    # the number of nodes in the tree. This results in leaf nodes
    # with numbers of points betweeen leaf_size and 2 * leaf_size
    self.n_levels = (np.log2(max(1, (self.n_samples - 1) / self.leaf_size))
    + 1)
    self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1)
    // self.leaf_size)))
    self.n_nodes = int(2 ** self.n_levels) - 1

    # allocate arrays for storage
    @@ -318,17 +332,21 @@ def query(self, X, k=1, sort_results=True):

    #----------------------------------------------------------------------
    # Testing function

    def test_tree(N=1000, D=3, K=5, LS=40):
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    rseed = np.random.randint(10000)
    print "-------------------------------------------------------"
    print "%i neighbors of %i points in %i dimensions" % (K, N, D)
    print "random seed =", rseed
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    @@ -340,13 +358,14 @@ def test_tree(N=1000, D=3, K=5, LS=40):
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print "results match:", np.allclose(dist1, dist2), np.allclose(ind1, ind2)
    print
    print "sklearn build: %.2g sec" % (t1 - t0)
    print "numba build : %.2g sec" % (t3 - t2)
    print
    print "sklearn query: %.2g sec" % (t2 - t1)
    print "numba query : %.2g sec" % (t4 - t3)
    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
    np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.2g} sec".format(t1 - t0))
    print("numba build : {0:.2g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.2g} sec".format(t2 - t1))
    print("numba query : {0:.2g} sec".format(t4 - t3))


    if __name__ == '__main__':
    32 changes: 17 additions & 15 deletions ball_tree_python.py
    Original file line number Diff line number Diff line change
    @@ -1,3 +1,4 @@
    from __future__ import division, print_function
    import numpy as np


    @@ -19,8 +20,8 @@ def __init__(self, data, leaf_size=40):
    # determine number of levels in the tree, and from this
    # the number of nodes in the tree. This results in leaf nodes
    # with numbers of points betweeen leaf_size and 2 * leaf_size
    self.n_levels = (np.log2(max(1, (self.n_samples - 1) / self.leaf_size))
    + 1)
    self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1)
    // self.leaf_size)))
    self.n_nodes = int(2 ** self.n_levels) - 1

    # allocate arrays for storage
    @@ -60,7 +61,7 @@ def _recursive_build(self, i_node, idx_start, idx_end):
    else:
    # split node and recursively construct child nodes.
    self.node_is_leaf[i_node] = False
    n_mid = int((idx_end + idx_start) / 2)
    n_mid = int((idx_end + idx_start) // 2)
    _partition_indices(self.data, self.idx_array,
    idx_start, idx_end, n_mid)
    self._recursive_build(2 * i_node + 1, idx_start, n_mid)
    @@ -71,8 +72,8 @@ def init_node(self, i_node, idx_start, idx_end):
    for j in range(self.n_features):
    self.node_centroids[i_node, j] = 0
    for i in range(idx_start, idx_end):
    self.node_centroids[i_node, j] +=\
    self.data[self.idx_array[i], j]
    self.node_centroids[i_node, j] += self.data[self.idx_array[i],
    j]
    self.node_centroids[i_node, j] /= (idx_end - idx_start)

    # determine Node radius
    @@ -270,9 +271,9 @@ def test_tree(N=1000, D=3, K=5, LS=40):
    from sklearn.neighbors import BallTree as skBallTree

    rseed = np.random.randint(10000)
    print "-------------------------------------------------------"
    print "%i neighbors of %i points in %i dimensions" % (K, N, D)
    print "random seed =", rseed
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    @@ -287,13 +288,14 @@ def test_tree(N=1000, D=3, K=5, LS=40):
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print "results match:", np.allclose(dist1, dist2), np.allclose(ind1, ind2)
    print
    print "sklearn build: %.2g sec" % (t1 - t0)
    print "python build : %.2g sec" % (t3 - t2)
    print
    print "sklearn query: %.2g sec" % (t2 - t1)
    print "python query : %.2g sec" % (t4 - t3)
    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
    np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.2g} sec".format(t1 - t0))
    print("python build : {0:.2g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.2g} sec".format(t2 - t1))
    print("python query : {0:.2g} sec".format(t4 - t3))


    if __name__ == '__main__':
  2. jakevdp revised this gist Mar 21, 2013. 3 changed files with 4 additions and 4 deletions.
    4 changes: 2 additions & 2 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -16,7 +16,7 @@ These are the results on my Macbook Pro:
    ```
    jakesmac $ python ball_tree_python.py
    -------------------------------------------------------
    1000 neighbors of 3 points in 5 dimensions
    5 neighbors of 1000 points in 3 dimensions
    random seed = 9742
    results match: True True
    @@ -29,7 +29,7 @@ python query : 2.3 sec
    jakesmac $ python ball_tree_numba.py
    -------------------------------------------------------
    1000 neighbors of 3 points in 5 dimensions
    5 neighbors of 1000 points in 3 dimensions
    random seed = 2772
    results match: True True
    2 changes: 1 addition & 1 deletion ball_tree_numba.py
    Original file line number Diff line number Diff line change
    @@ -324,7 +324,7 @@ def test_tree(N=1000, D=3, K=5, LS=40):

    rseed = np.random.randint(10000)
    print "-------------------------------------------------------"
    print "%i neighbors of %i points in %i dimensions" % (N, D, K)
    print "%i neighbors of %i points in %i dimensions" % (K, N, D)
    print "random seed =", rseed
    np.random.seed(rseed)
    X = np.random.random((N, D))
    2 changes: 1 addition & 1 deletion ball_tree_python.py
    Original file line number Diff line number Diff line change
    @@ -271,7 +271,7 @@ def test_tree(N=1000, D=3, K=5, LS=40):

    rseed = np.random.randint(10000)
    print "-------------------------------------------------------"
    print "%i neighbors of %i points in %i dimensions" % (N, D, K)
    print "%i neighbors of %i points in %i dimensions" % (K, N, D)
    print "random seed =", rseed
    np.random.seed(rseed)
    X = np.random.random((N, D))
  3. jakevdp revised this gist Mar 21, 2013. 1 changed file with 353 additions and 0 deletions.
    353 changes: 353 additions & 0 deletions ball_tree_numba.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,353 @@
    import warnings
    import numpy as np
    from numba import autojit, jit, void, double, int_


    #----------------------------------------------------------------------
    # Distance computations

    @jit(double(double[:,:], int_, double[:,:], int_))
    def rdist(X1, i1, X2, i2):
        """Return the squared Euclidean distance between X1[i1] and X2[i2].

        Both arguments are indexed as 2-D arrays (row, feature); the number
        of features is taken from ``X1.shape[1]``.  No square root is
        applied here.
        """
        sq_dist = 0
        n_dims = X1.shape[1]
        for dim in range(n_dims):
            delta = X1[i1, dim] - X2[i2, dim]
            sq_dist += delta * delta
        return sq_dist


    @jit(double(double[:,:], double[:], int_, double[:,:], int_))
    def min_rdist(node_centroids, node_radius, i_node, X, j):
        """Lower bound on the squared distance from X[j] to node ``i_node``.

        The bound is the distance from the point to the node centroid minus
        the node radius, clipped at zero, then squared.
        """
        centroid_sq_dist = rdist(node_centroids, i_node, X, j)
        gap = np.sqrt(centroid_sq_dist) - node_radius[i_node]
        return max(0, gap) ** 2

    #----------------------------------------------------------------------
    # Heap for distances and neighbors

    def heap_create(N, k):
        """Allocate empty neighbor heaps for N query points with k slots each.

        Returns
        -------
        (distances, indices) : tuple of ndarray, each of shape (N, k)
            ``distances`` is float and filled with +inf; ``indices`` is int
            and filled with zeros.
        """
        shape = (N, k)
        return np.full(shape, np.inf, dtype=float), np.zeros(shape, dtype=int)


    def heap_sort(distances, indices):
        """Return copies of both arrays with each row ordered by distance.

        Each row of ``distances`` is sorted in ascending order and the same
        permutation is applied to the matching row of ``indices``.
        """
        order = np.argsort(distances, 1)
        # column vector of row numbers so fancy indexing pairs row r with order[r]
        rows = np.arange(len(distances), dtype=int)[:, None]
        sorted_distances = distances[rows, order]
        sorted_indices = indices[rows, order]
        return sorted_distances, sorted_indices


    @jit(void(int_, double, int_, double[:,:], int_[:,:]))
    def heap_push(row, val, i_val, distances, indices):
        """Push the pair (val, i_val) onto the max-heap stored in row ``row``.

        ``distances[row]`` is laid out as a binary heap whose largest entry
        sits at position 0, with ``indices[row]`` kept in lockstep.  The new
        value replaces the current largest entry unless it is larger still,
        in which case nothing changes.
        """
        n_slots = distances.shape[1]

        # a value beyond the current largest entry does not belong in the heap
        if val > distances[row, 0]:
            return

        # overwrite the root with the new value
        distances[row, 0] = val
        indices[row, 0] = i_val

        # sift the new value down until the max-heap criterion is met
        pos = 0
        while True:
            left = 2 * pos + 1
            right = left + 1

            if left >= n_slots:
                break

            # pick the larger child; a lone left child or a tie selects left
            if right >= n_slots:
                child = left
            elif distances[row, left] >= distances[row, right]:
                child = left
            else:
                child = right

            if val < distances[row, child]:
                # pull the child up one level and continue from its slot
                distances[row, pos] = distances[row, child]
                indices[row, pos] = indices[row, child]
                pos = child
            else:
                break

        distances[row, pos] = val
        indices[row, pos] = i_val

    #----------------------------------------------------------------------
    # Tools for building the tree

    @jit(void(double[:,:], int_[:], int_, int_, int_))
    def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
        """Partially sort idx_array[idx_start:idx_end] around ``split_index``.

        The dimension with the largest spread over the selected points is
        used as the sort key.  ``idx_array`` is reordered in place so that
        the entry at ``split_index`` is where a full sort would put it, with
        smaller keys before it and the remaining keys after it.
        """
        # Find the split dimension: the feature with the greatest spread
        n_dims = data.shape[1]
        split_dim = 0
        widest = 0

        for dim in range(n_dims):
            hi = -np.inf
            lo = np.inf
            for pos in range(idx_start, idx_end):
                value = data[idx_array[pos], dim]
                if value > hi:
                    hi = value
                if value < lo:
                    lo = value
            if hi - lo > widest:
                widest = hi - lo
                split_dim = dim

        # Partition using the split dimension; the pivot is the rightmost
        # entry of the shrinking window [first, last]
        first = idx_start
        last = idx_end - 1

        while True:
            store = first
            pivot_val = data[idx_array[last], split_dim]
            for pos in range(first, last):
                if data[idx_array[pos], split_dim] < pivot_val:
                    held = idx_array[pos]
                    idx_array[pos] = idx_array[store]
                    idx_array[store] = held
                    store += 1

            # drop the pivot into its final slot
            held = idx_array[store]
            idx_array[store] = idx_array[last]
            idx_array[last] = held

            if store == split_index:
                break
            if store < split_index:
                first = store + 1
            else:
                last = store - 1


    @jit(void(int_, int_, int_, double[:,:], double[:,:], double[:], int_[:],
              int_[:], int_[:], int_[:], int_, int_))
    def _recursive_build(i_node, idx_start, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size):
        """Fill in node ``i_node`` and, if it is not a leaf, its subtree.

        The node covers the points ``idx_array[idx_start:idx_end]``.  Its
        centroid, radius and index range are written into the node arrays;
        children live at ``2 * i_node + 1`` and ``2 * i_node + 2``.
        """
        n_points = idx_end - idx_start

        # centroid: mean of the node's points, one feature at a time
        for dim in range(data.shape[1]):
            node_centroids[i_node, dim] = 0
            for pos in range(idx_start, idx_end):
                node_centroids[i_node, dim] += data[idx_array[pos], dim]
            node_centroids[i_node, dim] /= n_points

        # radius: distance from the centroid to the farthest point
        sq_radius = 0.0
        for pos in range(idx_start, idx_end):
            sq_dist = rdist(node_centroids, i_node, data, idx_array[pos])
            if sq_dist > sq_radius:
                sq_radius = sq_dist

        node_radius[i_node] = np.sqrt(sq_radius)
        node_idx_start[i_node] = idx_start
        node_idx_end[i_node] = idx_end

        i_child = 2 * i_node + 1

        # no room left for children: this node has to be a leaf
        if i_child >= n_nodes:
            node_is_leaf[i_node] = True
            if n_points > 2 * leaf_size:
                # this shouldn't happen if our memory allocation is correct.
                # We'll proactively prevent memory errors, but raise a
                # warning saying we're doing so.
                warnings.warn("Internal: memory layout is flawed: "
                              "not enough nodes allocated")
            return

        # too few points to split, although child slots exist
        if n_points < 2:
            # again, this shouldn't happen if our memory allocation is correct.
            warnings.warn("Internal: memory layout is flawed: "
                          "too many nodes allocated")
            node_is_leaf[i_node] = True
            return

        # split node and recursively construct child nodes.
        node_is_leaf[i_node] = False
        n_mid = int((idx_end + idx_start) / 2)
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
        _recursive_build(i_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(i_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)


    #----------------------------------------------------------------------
    # Tools for querying the tree
    @jit(void(int_, double[:,:], int_, double[:,:], int_[:,:], double, double[:,:],
              int_[:], double[:,:], double[:], int_[:], int_[:], int_[:]))
    def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                         data, idx_array, node_centroids, node_radius,
                         node_is_leaf, node_idx_start, node_idx_end):
        """Update the neighbor heap of query point X[i_pt] from node ``i_node``.

        ``sq_dist_LB`` is the caller-supplied lower bound on the squared
        distance from the query point to this node.
        """
        #------------------------------------------------------------
        # Case 1: query point is outside node radius:
        #         trim it from the query
        if sq_dist_LB > heap_distances[i_pt, 0]:
            return

        #------------------------------------------------------------
        # Case 2: this is a leaf node.  Update set of nearby points
        if node_is_leaf[i_node]:
            for pos in range(node_idx_start[i_node], node_idx_end[i_node]):
                i_data = idx_array[pos]
                sq_dist = rdist(data, i_data, X, i_pt)
                if sq_dist < heap_distances[i_pt, 0]:
                    heap_push(i_pt, sq_dist, i_data,
                              heap_distances, heap_indices)
            return

        #------------------------------------------------------------
        # Case 3: Node is not a leaf.  Recursively query subnodes
        #         starting with the closest
        i_left = 2 * i_node + 1
        i_right = i_left + 1
        lb_left = min_rdist(node_centroids, node_radius, i_left, X, i_pt)
        lb_right = min_rdist(node_centroids, node_radius, i_right, X, i_pt)

        # on a tie the left child is visited first
        if lb_left <= lb_right:
            i_near = i_left
            lb_near = lb_left
            i_far = i_right
            lb_far = lb_right
        else:
            i_near = i_right
            lb_near = lb_right
            i_far = i_left
            lb_far = lb_left

        _query_recursive(i_near, X, i_pt, heap_distances,
                         heap_indices, lb_near,
                         data, idx_array, node_centroids, node_radius,
                         node_is_leaf, node_idx_start, node_idx_end)
        _query_recursive(i_far, X, i_pt, heap_distances,
                         heap_indices, lb_far,
                         data, idx_array, node_centroids, node_radius,
                         node_is_leaf, node_idx_start, node_idx_end)


    #----------------------------------------------------------------------
    # The Ball Tree object
    class BallTree(object):
        """Ball tree for k-nearest-neighbor queries built on the jit helpers.

        All tree state lives in flat numpy arrays that are handed to the
        stand-alone functions ``_recursive_build`` and ``_query_recursive``.
        Nodes use the implicit binary-heap layout: the children of node
        ``i`` are nodes ``2 * i + 1`` and ``2 * i + 2``.

        Parameters
        ----------
        data : ndarray, shape (n_samples, n_features)
            Training points.  Stored by reference, not copied.
        leaf_size : int, default 40
            Controls how many nodes are allocated (see ``n_nodes`` below).

        Raises
        ------
        ValueError
            If ``data`` is empty or ``leaf_size`` is smaller than 1.
        """

        def __init__(self, data, leaf_size=40):
            self.data = data
            self.leaf_size = leaf_size

            # validate data
            if self.data.size == 0:
                raise ValueError("X is an empty array")

            if leaf_size < 1:
                raise ValueError("leaf_size must be greater than or equal to 1")

            self.n_samples = self.data.shape[0]
            self.n_features = self.data.shape[1]

            # determine number of levels in the tree, and from this
            # the number of nodes in the tree.  This results in leaf nodes
            # with numbers of points between leaf_size and 2 * leaf_size.
            #
            # Floor division is required: with true division the node count
            # can come out even, which leaves one parent with a child index
            # one past the end of the node arrays.  The node count is also
            # derived with integer arithmetic so float rounding in
            # 2 ** log2(...) cannot drop a node.
            n_leaves = int(max(1, (self.n_samples - 1) // self.leaf_size))
            self.n_levels = 1 + np.log2(n_leaves)
            self.n_nodes = 2 * n_leaves - 1

            # allocate arrays for storage
            self.idx_array = np.arange(self.n_samples, dtype=int)
            self.node_radius = np.zeros(self.n_nodes, dtype=float)
            self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
            self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
            self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
            self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                           dtype=float)

            # build the tree in place, starting from the root node
            _recursive_build(0, 0, self.n_samples,
                             self.data, self.node_centroids,
                             self.node_radius, self.idx_array,
                             self.node_idx_start, self.node_idx_end,
                             self.node_is_leaf, self.n_nodes, self.leaf_size)

        def query(self, X, k=1, sort_results=True):
            """Find the ``k`` nearest training points for every point in X.

            Parameters
            ----------
            X : array-like, last dimension of size n_features
                Query points; converted to a float array.
            k : int, default 1
                Number of neighbors to return per query point.
            sort_results : bool, default True
                If True, each row of the result is ordered by increasing
                distance.  If False, rows are returned in internal heap
                order.

            Returns
            -------
            (distances, indices) : tuple of ndarray
                Both have shape ``X.shape[:-1] + (k,)``.  ``distances`` holds
                Euclidean distances; ``indices`` holds row numbers into the
                training data.

            Raises
            ------
            ValueError
                If the feature dimension of X does not match the training
                data, or ``k`` exceeds the number of training points.
            """
            X = np.asarray(X, dtype=float)

            if X.shape[-1] != self.n_features:
                raise ValueError("query data dimension must "
                                 "match training data dimension")

            if self.data.shape[0] < k:
                raise ValueError("k must be less than or equal "
                                 "to the number of training points")

            # flatten X, and save original shape information
            Xshape = X.shape
            X = X.reshape((-1, self.data.shape[1]))

            # initialize heap for neighbors
            heap_distances, heap_indices = heap_create(X.shape[0], k)

            for i in range(X.shape[0]):
                sq_dist_LB = min_rdist(self.node_centroids,
                                       self.node_radius,
                                       0, X, i)
                _query_recursive(0, X, i, heap_distances, heap_indices,
                                 sq_dist_LB,
                                 self.data, self.idx_array,
                                 self.node_centroids,
                                 self.node_radius, self.node_is_leaf,
                                 self.node_idx_start, self.node_idx_end)

            # honor sort_results instead of always sorting
            if sort_results:
                distances, indices = heap_sort(heap_distances, heap_indices)
            else:
                distances, indices = heap_distances, heap_indices

            # the heaps hold squared distances
            distances = np.sqrt(distances)

            # deflatten results
            return (distances.reshape(Xshape[:-1] + (k,)),
                    indices.reshape(Xshape[:-1] + (k,)))


    #----------------------------------------------------------------------
    # Testing function
    def test_tree(N=1000, D=3, K=5, LS=40):
        """Compare this BallTree with scikit-learn's on random data.

        Builds both trees on ``N`` uniform random points in ``D`` dimensions,
        queries the ``K`` nearest neighbors of every point with leaf size
        ``LS``, and prints whether the results agree along with build and
        query timings.  Requires scikit-learn.
        """
        from time import time
        from sklearn.neighbors import BallTree as skBallTree

        rseed = np.random.randint(10000)
        # single-argument print(...) calls behave the same on Python 2 and 3
        print("-------------------------------------------------------")
        # argument order must follow the wording: neighbors, points, dimensions
        print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
        print("random seed = {0}".format(rseed))
        np.random.seed(rseed)
        X = np.random.random((N, D))

        t0 = time()
        bt1 = skBallTree(X, leaf_size=LS)
        t1 = time()
        dist1, ind1 = bt1.query(X, K)
        t2 = time()

        bt2 = BallTree(X, leaf_size=LS)
        t3 = time()
        dist2, ind2 = bt2.query(X, K)
        t4 = time()

        print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                              np.allclose(ind1, ind2)))
        print("")
        print("sklearn build: {0:.2g} sec".format(t1 - t0))
        print("numba build  : {0:.2g} sec".format(t3 - t2))
        print("")
        print("sklearn query: {0:.2g} sec".format(t2 - t1))
        print("numba query  : {0:.2g} sec".format(t4 - t3))


    if __name__ == '__main__':
        test_tree()
  4. jakevdp created this gist Mar 21, 2013.
    41 changes: 41 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,41 @@
    Numba Ball Tree
    ================
    This is a quick attempt at writing a ball tree for nearest neighbor
    searches using numba. I've included a pure Python version, and a
    version with numba jit decorators. Because class support in numba
    is not yet complete, all the code is factored out to stand-alone
    functions in the numba version. The resulting code produced by
    numba is about 10 times slower than the Cython ball tree in
    scikit-learn. My guess is that part of this stems from lack of
    inlining in numba, while the rest is due to some sort of overhead
    that's not obvious to me.

    Timings
    -------
    These are the results on my Macbook Pro:
    ```
    jakesmac $ python ball_tree_python.py
    -------------------------------------------------------
    1000 neighbors of 3 points in 5 dimensions
    random seed = 9742
    results match: True True
    sklearn build: 0.00058 sec
    python build : 0.11 sec
    sklearn query: 0.0049 sec
    python query : 2.3 sec
    jakesmac $ python ball_tree_numba.py
    -------------------------------------------------------
    1000 neighbors of 3 points in 5 dimensions
    random seed = 2772
    results match: True True
    sklearn build: 0.00032 sec
    numba build : 0.0039 sec
    sklearn query: 0.0048 sec
    numba query : 0.029 sec
    ```
    300 changes: 300 additions & 0 deletions ball_tree_python.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,300 @@
    import numpy as np


    class BallTree(object):
        """Pure-Python ball tree supporting k-nearest-neighbor queries.

        Nodes are stored in flat numpy arrays using the implicit binary-heap
        layout: the children of node ``i`` are nodes ``2 * i + 1`` and
        ``2 * i + 2``.  Each node records its centroid, its radius and the
        slice of ``idx_array`` holding the points it covers.

        Parameters
        ----------
        data : ndarray, shape (n_samples, n_features)
            Training points.  Stored by reference, not copied.
        leaf_size : int, default 40
            Controls how many nodes are allocated (see ``n_nodes`` below).

        Raises
        ------
        ValueError
            If ``data`` is empty or ``leaf_size`` is smaller than 1.
        """

        def __init__(self, data, leaf_size=40):
            self.data = data
            self.leaf_size = leaf_size

            # validate data
            if self.data.size == 0:
                raise ValueError("X is an empty array")

            if leaf_size < 1:
                raise ValueError("leaf_size must be greater than or equal to 1")

            self.n_samples = self.data.shape[0]
            self.n_features = self.data.shape[1]

            # determine number of levels in the tree, and from this
            # the number of nodes in the tree.  This results in leaf nodes
            # with numbers of points between leaf_size and 2 * leaf_size.
            #
            # Floor division is required: with true division the node count
            # can come out even, which leaves one parent with a child index
            # one past the end of the node arrays.  The node count is also
            # derived with integer arithmetic so float rounding in
            # 2 ** log2(...) cannot drop a node.
            n_leaves = int(max(1, (self.n_samples - 1) // self.leaf_size))
            self.n_levels = 1 + np.log2(n_leaves)
            self.n_nodes = 2 * n_leaves - 1

            # allocate arrays for storage
            self.idx_array = np.arange(self.n_samples, dtype=int)
            self.node_radius = np.zeros(self.n_nodes, dtype=float)
            self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
            self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
            self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
            self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                           dtype=float)

            # build the tree in place, starting from the root node
            self._recursive_build(0, 0, self.n_samples)

        def _recursive_build(self, i_node, idx_start, idx_end):
            """Initialize node ``i_node`` over idx_array[idx_start:idx_end]
            and recursively build its children when it is not a leaf."""
            import warnings

            # initialize node data
            self.init_node(i_node, idx_start, idx_end)

            i_child = 2 * i_node + 1

            if i_child >= self.n_nodes:
                self.node_is_leaf[i_node] = True
                if idx_end - idx_start > 2 * self.leaf_size:
                    # this shouldn't happen if our memory allocation is correct
                    # we'll proactively prevent memory errors, but raise a
                    # warning saying we're doing so.
                    warnings.warn("Internal: memory layout is flawed: "
                                  "not enough nodes allocated")

            elif idx_end - idx_start < 2:
                # again, this shouldn't happen if our memory allocation
                # is correct.  Raise a warning.
                warnings.warn("Internal: memory layout is flawed: "
                              "too many nodes allocated")
                self.node_is_leaf[i_node] = True

            else:
                # split node and recursively construct child nodes.
                self.node_is_leaf[i_node] = False
                n_mid = (idx_end + idx_start) // 2
                _partition_indices(self.data, self.idx_array,
                                   idx_start, idx_end, n_mid)
                self._recursive_build(i_child, idx_start, n_mid)
                self._recursive_build(i_child + 1, n_mid, idx_end)

        def init_node(self, i_node, idx_start, idx_end):
            """Compute and store centroid, radius and index range of a node.

            The node covers the points ``idx_array[idx_start:idx_end]``.
            """
            # determine Node centroid
            for j in range(self.n_features):
                self.node_centroids[i_node, j] = 0
                for i in range(idx_start, idx_end):
                    self.node_centroids[i_node, j] += self.data[
                        self.idx_array[i], j]
                self.node_centroids[i_node, j] /= (idx_end - idx_start)

            # determine Node radius: distance to the farthest covered point
            sq_radius = 0
            for i in range(idx_start, idx_end):
                sq_dist = self.rdist(self.node_centroids, i_node,
                                     self.data, self.idx_array[i])
                sq_radius = max(sq_radius, sq_dist)

            self.node_radius[i_node] = np.sqrt(sq_radius)
            self.node_idx_start[i_node] = idx_start
            self.node_idx_end[i_node] = idx_end

        def rdist(self, X1, i1, X2, i2):
            """Return the squared Euclidean distance between X1[i1] and
            X2[i2], summed over ``self.n_features`` features."""
            d = 0
            for k in range(self.n_features):
                tmp = (X1[i1, k] - X2[i2, k])
                d += tmp * tmp
            return d

        def min_rdist(self, i_node, X, j):
            """Lower bound on the squared distance from X[j] to any point of
            node ``i_node``: centroid distance minus radius, clipped at zero,
            then squared."""
            d = self.rdist(self.node_centroids, i_node, X, j)
            return max(0, np.sqrt(d) - self.node_radius[i_node]) ** 2

        def query(self, X, k=1, sort_results=True):
            """Find the ``k`` nearest training points for every point in X.

            Parameters
            ----------
            X : array-like, last dimension of size n_features
                Query points; converted to a float array.
            k : int, default 1
                Number of neighbors to return per query point.
            sort_results : bool, default True
                If True, each row of the result is ordered by increasing
                distance; otherwise rows are in internal heap order.

            Returns
            -------
            (distances, indices) : tuple of ndarray
                Both have shape ``X.shape[:-1] + (k,)``.  ``distances`` holds
                Euclidean distances; ``indices`` holds row numbers into the
                training data.

            Raises
            ------
            ValueError
                If the feature dimension of X does not match the training
                data, or ``k`` exceeds the number of training points.
            """
            X = np.asarray(X, dtype=float)

            if X.shape[-1] != self.n_features:
                raise ValueError("query data dimension must "
                                 "match training data dimension")

            if self.data.shape[0] < k:
                raise ValueError("k must be less than or equal "
                                 "to the number of training points")

            # flatten X, and save original shape information
            Xshape = X.shape
            X = X.reshape((-1, self.data.shape[1]))

            # initialize heap for neighbors
            heap = NeighborsHeap(X.shape[0], k)

            for i in range(X.shape[0]):
                sq_dist_LB = self.min_rdist(0, X, i)
                self._query_recursive(0, X, i, heap, sq_dist_LB)

            distances, indices = heap.get_arrays(sort=sort_results)
            # the heap holds squared distances
            distances = np.sqrt(distances)

            # deflatten results
            return (distances.reshape(Xshape[:-1] + (k,)),
                    indices.reshape(Xshape[:-1] + (k,)))

        def _query_recursive(self, i_node, X, i_pt, heap, sq_dist_LB):
            """Update ``heap`` row ``i_pt`` with neighbors of X[i_pt] found
            in the subtree rooted at ``i_node``.

            ``sq_dist_LB`` is the lower bound on the squared distance from
            the query point to this node.
            """
            #------------------------------------------------------------
            # Case 1: query point is outside node radius:
            #         trim it from the query
            if sq_dist_LB > heap.largest(i_pt):
                return

            #------------------------------------------------------------
            # Case 2: this is a leaf node.  Update set of nearby points
            if self.node_is_leaf[i_node]:
                for i in range(self.node_idx_start[i_node],
                               self.node_idx_end[i_node]):
                    dist_pt = self.rdist(self.data, self.idx_array[i], X, i_pt)
                    if dist_pt < heap.largest(i_pt):
                        heap.push(i_pt, dist_pt, self.idx_array[i])
                return

            #------------------------------------------------------------
            # Case 3: Node is not a leaf.  Recursively query subnodes
            #         starting with the closest
            i1 = 2 * i_node + 1
            i2 = i1 + 1
            sq_dist_LB_1 = self.min_rdist(i1, X, i_pt)
            sq_dist_LB_2 = self.min_rdist(i2, X, i_pt)

            # recursively query subnodes
            if sq_dist_LB_1 <= sq_dist_LB_2:
                self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1)
                self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2)
            else:
                self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2)
                self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1)


    def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
        """Partially sort idx_array[idx_start:idx_end] around ``split_index``.

        The dimension with the largest spread over the selected points is
        used as the sort key.  ``idx_array`` is reordered in place so that
        the entry at ``split_index`` is where a full sort would put it, with
        smaller keys before it and the remaining keys after it.
        """
        # Find the split dimension: the feature with the greatest spread
        split_dim = 0
        widest = 0

        for dim in range(data.shape[1]):
            hi = -np.inf
            lo = np.inf
            for pos in range(idx_start, idx_end):
                value = data[idx_array[pos], dim]
                if value > hi:
                    hi = value
                if value < lo:
                    lo = value
            spread = hi - lo
            if spread > widest:
                widest = spread
                split_dim = dim

        # Partition using the split dimension; the pivot is the rightmost
        # entry of the shrinking window [first, last]
        first = idx_start
        last = idx_end - 1

        while True:
            store = first
            pivot_val = data[idx_array[last], split_dim]
            for pos in range(first, last):
                if data[idx_array[pos], split_dim] < pivot_val:
                    idx_array[pos], idx_array[store] = (idx_array[store],
                                                        idx_array[pos])
                    store += 1

            # drop the pivot into its final slot
            idx_array[store], idx_array[last] = (idx_array[last],
                                                 idx_array[store])

            if store == split_index:
                break
            if store < split_index:
                first = store + 1
            else:
                last = store - 1


    class NeighborsHeap:
        """One fixed-size max-heap of (distance, index) pairs per query point.

        Row ``r`` of ``distances`` is a binary heap whose largest entry sits
        at position 0; ``indices`` is kept in lockstep.  Heaps start out
        filled with +inf distances and zero indices.
        """

        def __init__(self, n_pts, n_nbrs):
            shape = (n_pts, n_nbrs)
            self.distances = np.full(shape, np.inf, dtype=float)
            self.indices = np.zeros(shape, dtype=int)

        def get_arrays(self, sort=True):
            """Return (distances, indices).

            With ``sort`` true, each row is ordered by increasing distance
            and new arrays are returned; otherwise the internal arrays are
            returned as they are.
            """
            if not sort:
                return self.distances, self.indices

            order = np.argsort(self.distances, 1)
            # column vector of row numbers, paired with each row's ordering
            rows = np.arange(len(self.distances), dtype=int)[:, None]
            return self.distances[rows, order], self.indices[rows, order]

        def largest(self, row):
            """Return the largest distance currently held in ``row``."""
            return self.distances[row, 0]

        def push(self, row, val, i_val):
            """Insert (val, i_val) into ``row``, evicting the largest entry.

            Nothing happens when ``val`` exceeds the current largest entry.
            """
            dist_row = self.distances
            idx_row = self.indices
            n_slots = dist_row.shape[1]

            # a value beyond the current largest entry does not belong here
            if val > dist_row[row, 0]:
                return

            # overwrite the root with the new value
            dist_row[row, 0] = val
            idx_row[row, 0] = i_val

            # sift the new value down until the max-heap criterion is met
            pos = 0
            while True:
                left = 2 * pos + 1
                right = left + 1

                if left >= n_slots:
                    break

                # pick the larger child; a lone left child or a tie picks left
                if right >= n_slots:
                    child = left
                elif dist_row[row, left] >= dist_row[row, right]:
                    child = left
                else:
                    child = right

                if val < dist_row[row, child]:
                    dist_row[row, pos] = dist_row[row, child]
                    idx_row[row, pos] = idx_row[row, child]
                    pos = child
                else:
                    break

            dist_row[row, pos] = val
            idx_row[row, pos] = i_val


    def test_tree(N=1000, D=3, K=5, LS=40):
        """Compare this BallTree with scikit-learn's on random data.

        Builds both trees on ``N`` uniform random points in ``D`` dimensions,
        queries the ``K`` nearest neighbors of every point with leaf size
        ``LS``, and prints whether the results agree along with build and
        query timings.  Requires scikit-learn.
        """
        from time import time
        from sklearn.neighbors import BallTree as skBallTree

        rseed = np.random.randint(10000)
        # single-argument print(...) calls behave the same on Python 2 and 3
        print("-------------------------------------------------------")
        # argument order must follow the wording: neighbors, points, dimensions
        print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
        print("random seed = {0}".format(rseed))
        np.random.seed(rseed)
        X = np.random.random((N, D))

        t0 = time()
        bt1 = skBallTree(X, leaf_size=LS)
        t1 = time()
        dist1, ind1 = bt1.query(X, K)
        t2 = time()

        bt2 = BallTree(X, leaf_size=LS)
        t3 = time()
        dist2, ind2 = bt2.query(X, K)
        t4 = time()

        print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                              np.allclose(ind1, ind2)))
        print("")
        print("sklearn build: {0:.2g} sec".format(t1 - t0))
        print("python build  : {0:.2g} sec".format(t3 - t2))
        print("")
        print("sklearn query: {0:.2g} sec".format(t2 - t1))
        print("python query  : {0:.2g} sec".format(t4 - t3))


    if __name__ == '__main__':
        test_tree()