Last active
February 28, 2022 00:34
-
-
Save kwang2049/d23550604059ed1576ac6cffb7e09fb2 to your computer and use it in GitHub Desktop.
Example script that shows how to index the DPR single-nq (https://github.com/facebookresearch/DPR) embeddings with Faiss IndexIVFScalarQuantizer index.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import os | |
import json | |
import faiss | |
import tqdm | |
import numpy as np | |
import pytrec_eval | |
import time | |
from typing import List, Tuple | |
from collections import defaultdict | |
import logging | |
# Module-level logger plus one-time root configuration: timestamped INFO
# messages used for progress reporting throughout this script.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
################## download data (embeddings) ##################
# Fetch the 50 pre-computed DPR single-nq passage-embedding shards.
# Improvements over the original `os.system(f'wget {url}')` loop:
#   - subprocess.run with an argument list avoids shell string interpolation;
#   - check=True surfaces download failures instead of silently ignoring them;
#   - shards already on disk are skipped (a second run of plain wget would
#     otherwise save duplicates as `wiki_passages_0.1`, ...).
# NOTE(review): the downloaded files are named `wiki_passages_<i>` while the
# loading step below opens `wikipedia_passages_<i>.pkl` — confirm the files
# are renamed in between, or align one of the two patterns.
import subprocess

embedding_urls = [
    "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}".format(i)
    for i in range(50)
]
for url in embedding_urls:
    fname = url.rsplit('/', 1)[-1]
    if os.path.exists(fname):
        logger.info('Skipping already-downloaded file %s', fname)
        continue
    subprocess.run(['wget', url], check=True)
################## load pre-computed data ##################
# Read the 50 pickled shards; each shard is a list of (passage_id, embedding)
# tuples.  Accumulate every embedding into `xb` — stacked at the end into a
# single (N, d) matrix — and the matching stringified passage ids into `ids`
# (row i of `xb` corresponds to ids[i]).
# NOTE(review): pickle.load executes arbitrary code if a file is tampered
# with — only load shards obtained from the trusted DPR release.
logger.info('>>> Loading document embeddings')
xb = []
ids = []
findexes = ['./wikipedia_passages_{}.pkl'.format(i) for i in range(50)]  # It should be 50!!!
for findex in tqdm.tqdm(findexes):
    with open(findex, 'rb') as f:
        context_embs: List[Tuple[int, np.ndarray]] = pickle.load(f)
    # `doc_id`, not `id`: the original shadowed the builtin `id()`.
    for doc_id, emb in context_embs:
        xb.append(emb)
        ids.append(str(doc_id))
xb = np.stack(xb)
################## ANN indexing ##################
# Tunable search/recall trade-off parameters.
nlist = 262144  # 2^18
# nprobes = [2 ** i for i in range(int(np.log(262144) / np.log(2)))]
nprobe = 512  # 128 * 2^2
## fixed parameters
d = 768
buffer_size = 50000
batch_size = 1
qtype_str = 'QT_8bit_uniform'
index_name = f'nq-{qtype_str}-ivf{nlist}.index'
# Persist the passage ids next to the index file: line i of the .txt maps to
# row i of the index.
with open(index_name.replace('index', 'txt'), 'w') as f:
    f.writelines(line + '\n' for line in ids)
# Build the IVF scalar-quantizer index once and cache it on disk; later runs
# skip straight to reloading the saved file.
if not os.path.exists(index_name):
    # Coarse quantizer: exact inner-product search over the nlist centroids.
    quantizer = faiss.IndexFlatIP(d)
    # Resolve the string 'QT_8bit_uniform' to the faiss.ScalarQuantizer enum.
    qtype = getattr(faiss.ScalarQuantizer, qtype_str)
    index = faiss.IndexIVFScalarQuantizer(quantizer, d, nlist, qtype, faiss.METRIC_INNER_PRODUCT)
    logger.info(f'#GPUs: {faiss.get_num_gpus()}')  # Please use GPUs, since it is really large (21M * 768D/60GB embeddings)
    # Move the empty index onto all visible GPUs for training and adding.
    index = faiss.index_cpu_to_all_gpus(index)
    logger.info('>>> Doing training')
    index.train(xb)
    index.add(xb)
    # Bring the populated index back to CPU so it can be serialized.
    index = faiss.index_gpu_to_cpu(index)
    # Set nprobe before writing; the assert below relies on nprobe being
    # serialized with the index — presumably it is, TODO confirm for your
    # faiss version.
    index.nprobe = nprobe
    faiss.write_index(index, index_name)
logger.info('>>> Loading the trained index for checking')
index = faiss.read_index(index_name)
# Sanity check that the reloaded index kept the nprobe set above.
# NOTE(review): if a stale index file built with a different nprobe already
# exists, the build branch is skipped and this assert fails — delete the old
# file or set index.nprobe after loading.  Also, `assert` is stripped under
# `python -O`; raise explicitly if this check must always run.
assert index.nprobe == nprobe
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment