"""
Tune search_k for annoy library.
This is for people who want their nearest-neighbors to be above a
certain threshold of precision, but otherwise want it as fast as
possible.
AUTHOR: Joseph Turian
LICENSE: Apache License 2.0
"""
import random
import sys

import numpy as np

def choose_search_k(t, n, desired_precision=0.99,
                    seed=None,
                    max_search_k_factor=1,
                    samples_for_evaluation=100,
                    multiplicative_step_factor=2,
                    verbose=False):
"""
Choose a search_k in a data-driven way, using statistical tests.
Given an annoy index t, if we know that we are searching for n
nearest neighbors, find the smallest search_k that is likely
to give better than desired_precision against the "optimal"
search_k.
The algorithm is that we first try to find the 'max_search_k'
where using search_k=max_search_k and
search_k=max_search_k/(multiplicative_step_factor**2) are highly
correlated. This will rarely happen for low search_k.
Once we find this max_search_k, we then try decreasing search_k
by multiplicative_step_factor and finding the lowest search_k
that still has high precision with max_search_k.
Seed is the RNG seed.
max_search_k_factor is the starting factor for the max_search_k.
You probably shouldn't touch this an end-user.
samples_for_evaluation is how many random documents are sampled
for evaluation. Don't go lower than 30.
multiplicative_step_factor is how big each *downward* step size is.
verbose is useful for really understanding the effect of search_k
on performance.
"""
    # Our current guess of the "optimal" search_k value
    max_search_k = int(n * t.get_n_trees() * max_search_k_factor)

    # Pick random item indices to do the search_k evaluation
    rng = random.Random()
    rng.seed(seed)
    idxs_to_evaluate = [rng.randint(0, t.get_n_items() - 1)
                        for i in range(samples_for_evaluation)]
    def _all_nns(search_k):
        for idx in idxs_to_evaluate:
            yield t.get_nns_by_item(idx, n, search_k=search_k,
                                    include_distances=False)

    max_search_k_nns = list(_all_nns(max_search_k))
    def _precision_of_search_k_with_max(search_k):
        """
        For a particular search_k, find its average precision
        (over the sampled items) versus using max_search_k.
        """
        search_k_nns = list(_all_nns(search_k))
        precisions = []
        for r1, r2 in zip(max_search_k_nns, search_k_nns):
            precisions.append(len(set(r1) & set(r2)) / len(r1))
        return np.mean(precisions)
    current_search_k = int(max_search_k / (multiplicative_step_factor ** 2))
    current_precision = _precision_of_search_k_with_max(current_search_k)
    if verbose:
        print("precision=%.3f with search_k=%d (max_search_k=%d)"
              % (current_precision, current_search_k, max_search_k),
              file=sys.stderr)
    if current_precision < desired_precision:
        # Uh oh, max_search_k is too low! So let's increase it.
        if verbose:
            print("max_search_k of %d (max_search_k_factor = %d) is too low, "
                  "precision was %.3f"
                  % (max_search_k, max_search_k_factor, current_precision),
                  file=sys.stderr)
        # We jump up by multiplicative_step_factor squared, to make
        # sure we really capture what max_search_k should be.
        return choose_search_k(t, n, desired_precision,
                               seed,
                               max_search_k_factor=max_search_k_factor
                               * (multiplicative_step_factor ** 2),
                               samples_for_evaluation=samples_for_evaluation,
                               multiplicative_step_factor=multiplicative_step_factor,
                               verbose=verbose)
    else:
        lowest_acceptable_search_k = current_search_k
        below_threshold = False
        while current_search_k > 1:
            current_search_k = int(current_search_k / multiplicative_step_factor)
            current_precision = _precision_of_search_k_with_max(current_search_k)
            if verbose:
                print("precision=%.3f with search_k=%d (max_search_k=%d)"
                      % (current_precision, current_search_k, max_search_k),
                      file=sys.stderr)
            if current_precision < desired_precision:
                # Once precision falls below the threshold, stop
                # lowering lowest_acceptable_search_k. In verbose mode,
                # keep looping so the whole precision curve is printed.
                below_threshold = True
                if not verbose:
                    break
            elif not below_threshold:
                lowest_acceptable_search_k = current_search_k
        return lowest_acceptable_search_k


if __name__ == "__main__":
    # Here's some code to demo how it works
    from annoy import AnnoyIndex

    f = 40   # Length of item vector that will be indexed
    trees = 10
    t = AnnoyIndex(f, 'angular')
    random.seed(0)
    for i in range(100000):
        v = [random.gauss(0, 1) for z in range(f)]
        t.add_item(i, v)
    t.build(trees)

    chosen_search_k = choose_search_k(t, 32, seed=0, verbose=True)
    print("Use this search_k=%d" % chosen_search_k)