Last active
June 19, 2019 05:16
-
-
Save turian/b298e41792ec5d402e8df73985bc5bf4 to your computer and use it in GitHub Desktop.
Tune search_k for annoy library.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tune search_k for annoy library. | |
This is for people who want their nearest-neighbors to be above a | |
certain threshold of precision, but otherwise want it as fast as | |
possible. | |
AUTHOR: Joseph Turian | |
LICENSE: Apache License 2.0 | |
""" | |
import sys | |
import random | |
import numpy as np | |
import scipy.stats | |
def choose_search_k(t, n, desired_precision=0.99,
        seed=None,
        max_search_k_factor=1,
        samples_for_evaluation=100,
        multiplicative_step_factor=2,
        verbose=False
        ):
    """
    Choose a search_k in a data-driven way, using statistical tests.

    Given an annoy index t, if we know that we are searching for n
    nearest neighbors, find the smallest search_k that is likely
    to give better than desired_precision against the "optimal"
    search_k.

    The algorithm is that we first try to find the 'max_search_k'
    where using search_k=max_search_k and
    search_k=max_search_k/(multiplicative_step_factor**2) are highly
    correlated. This will rarely happen for low search_k.
    Once we find this max_search_k, we then try decreasing search_k
    by multiplicative_step_factor and finding the lowest search_k
    that still has high precision with max_search_k.

    Parameters:
        t: an annoy index (anything exposing get_n_trees,
            get_n_items, and get_nns_by_item).
        n: number of nearest neighbors that will be requested.
        desired_precision: minimum acceptable mean precision versus
            the max_search_k results.
        seed: RNG seed for picking the evaluation sample.
        max_search_k_factor: starting factor for max_search_k.
            You probably shouldn't touch this as an end-user.
        samples_for_evaluation: how many random documents are sampled
            for evaluation. Don't go lower than 30.
        multiplicative_step_factor: how big each *downward* step size is.
        verbose: print the precision of each search_k tried to stderr;
            useful for really understanding the effect of search_k on
            performance. Does NOT change the returned value.

    Returns:
        The lowest search_k (int) whose precision against
        max_search_k met desired_precision.
    """
    # Our current guess of the "optimal" search_k value
    max_search_k = int(n * t.get_n_trees() * max_search_k_factor)
    # Pick random item indices to do the search_k evaluation.
    # A private Random instance keeps this reproducible without
    # disturbing the global random state.
    rng = random.Random()
    rng.seed(seed)
    idxs_to_evaluate = [rng.randint(0, t.get_n_items() - 1)
                        for _ in range(samples_for_evaluation)]

    def _all_nns(search_k):
        # Yield the n nearest neighbors of each sampled item at this search_k.
        for idx in idxs_to_evaluate:
            yield t.get_nns_by_item(idx, n, search_k=search_k,
                                    include_distances=False)

    max_search_k_nns = list(_all_nns(max_search_k))

    def _precision_of_search_k_with_max(search_k):
        """
        For a particular search_k, find its average precision
        (over the sampled items) versus using max_search_k.
        """
        precisions = []
        for reference, candidate in zip(max_search_k_nns, _all_nns(search_k)):
            if not reference:
                # An empty reference result (e.g. a near-empty index) would
                # otherwise divide by zero; treat it as perfect agreement.
                precisions.append(1.0)
            else:
                precisions.append(len(set(reference) & set(candidate))
                                  / len(reference))
        return np.mean(precisions)

    current_search_k = int(max_search_k / (multiplicative_step_factor ** 2))
    current_precision = _precision_of_search_k_with_max(current_search_k)
    if verbose:
        print("precision=%.3f with search_k=%d (max_search_k=%d)" %
              (current_precision, current_search_k, max_search_k),
              file=sys.stderr)
    if current_precision < desired_precision:
        # Uh oh, max_search_k is too low! So let's increase it and retry.
        if verbose:
            print("max_search_k of %d (max_search_k_factor = %d) is too low, precision was %.3f"
                  % (max_search_k, max_search_k_factor, current_precision),
                  file=sys.stderr)
        return choose_search_k(t, n, desired_precision,
                               seed,
                               # We jump up by multiplicative_step_factor
                               # squared, to make sure we really capture
                               # what max_search_k should be.
                               max_search_k_factor=max_search_k_factor *
                               (multiplicative_step_factor ** 2),
                               samples_for_evaluation=samples_for_evaluation,
                               multiplicative_step_factor=multiplicative_step_factor,
                               verbose=verbose)

    lowest_acceptable_search_k = current_search_k
    # BUG FIX: previously, in verbose mode the loop kept scanning after
    # precision first dropped below the threshold (to print the whole
    # curve) but ALSO kept updating lowest_acceptable_search_k, so a
    # non-monotone precision curve made the result depend on `verbose`.
    # Track the first dip explicitly so verbose only affects printing.
    below_threshold = False
    while current_search_k > 1:
        current_search_k = int(current_search_k / multiplicative_step_factor)
        current_precision = _precision_of_search_k_with_max(current_search_k)
        if verbose:
            print("precision=%.3f with search_k=%d (max_search_k=%d)" %
                  (current_precision, current_search_k, max_search_k),
                  file=sys.stderr)
        if current_precision < desired_precision:
            below_threshold = True
            if not verbose:
                break
            # In verbose mode keep scanning downward so the whole
            # precision curve is printed, but never accept a search_k
            # seen after the first dip below the threshold.
        elif not below_threshold:
            lowest_acceptable_search_k = current_search_k
    return lowest_acceptable_search_k
if __name__ == "__main__":
    # Demo: build an annoy index over random Gaussian vectors, then
    # tune search_k for 32-nearest-neighbor queries against it.
    from annoy import AnnoyIndex

    dim = 40      # Length of item vector that will be indexed
    n_trees = 10
    index = AnnoyIndex(dim)
    random.seed(0)
    for item_id in range(100000):
        vector = [random.gauss(0, 1) for _ in range(dim)]
        index.add_item(item_id, vector)
    index.build(n_trees)
    best_search_k = choose_search_k(index, 32, seed=0, verbose=True)
    print("Use this search_k=%d" % best_search_k)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment