Skip to content

Instantly share code, notes, and snippets.

@kwang2049
Last active February 28, 2022 00:34
Show Gist options
  • Save kwang2049/d23550604059ed1576ac6cffb7e09fb2 to your computer and use it in GitHub Desktop.
Example script that shows how to index the DPR single-nq (https://github.com/facebookresearch/DPR) embeddings with Faiss IndexIVFScalarQuantizer index.
import json
import logging
import os
import pickle
import subprocess
import time
from collections import defaultdict
from typing import List, Tuple

import faiss
import numpy as np
import pytrec_eval
import tqdm
# Module-level logger; basicConfig attaches a stderr handler that prints
# a timestamp prefix and shows INFO-and-above records.
logger = logging.getLogger(__name__)

_LOG_FORMAT = '%(asctime)s - %(message)s'
_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_DATE_FORMAT, level=logging.INFO)
################## download data (embeddings) ##################
# Fetch the 50 DPR single-nq wikipedia passage embedding shards.
embedding_urls = [
    "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}".format(i)
    for i in range(50)
]
for url in embedding_urls:
    # subprocess.run with an argument list avoids handing an interpolated
    # string to the shell (the original os.system(f'wget {url}') did).
    # check=False keeps the original best-effort behavior: a failed
    # download does not abort the script, just like os.system's ignored
    # return code.
    subprocess.run(['wget', url], check=False)
################## load pre-computed data ##################
# Each pickle shard holds a list of (passage_id, embedding) pairs.
logger.info('>>> Loading document embeddings')
xb = []   # all passage embeddings; stacked into one ndarray below
ids = []  # passage ids (as strings), row-aligned with xb
findex_base = './wikipedia_passages_<i>.pkl'
# NOTE(review): the download step above saves files named 'wiki_passages_<i>'
# (no .pkl, different prefix) -- presumably the shards are renamed between
# the two steps; confirm the paths actually match on disk.
findexes = [findex_base.replace('<i>', str(i)) for i in range(50)]  # It should be 50!!!
for findex in tqdm.tqdm(findexes):
    with open(findex, 'rb') as f:
        # NOTE(review): pickle.load can execute arbitrary code; these
        # shards come from the official DPR release, so they are treated
        # as trusted input here.
        context_embs: List[Tuple[int, np.ndarray]] = pickle.load(f)
    # 'pid' rather than 'id' so the builtin id() is not shadowed.
    for pid, emb in context_embs:
        xb.append(emb)
        ids.append(str(pid))
xb = np.stack(xb)
################## ANN indexing ##################
nlist = 262144  # number of IVF cells: 2^18
# nprobes = [2 ** i for i in range(int(np.log(262144) / np.log(2)))]
nprobe = 512  # cells visited at search time: 128 * 2^2
## fixed parameters
d = 768
buffer_size = 50000
batch_size = 1
qtype_str = 'QT_8bit_uniform'
index_name = f'nq-{qtype_str}-ivf{nlist}.index'

# Persist the passage-id list next to the index (same name, .txt suffix),
# one id per line, in the same order as the rows of xb.
ids_fname = index_name.replace('index', 'txt')
with open(ids_fname, 'w') as ids_file:
    ids_file.writelines(pid + '\n' for pid in ids)
# Build and persist the IVF scalar-quantizer index only if it is not
# already on disk; either way, reload it afterwards as a sanity check.
if not os.path.exists(index_name):
    # Coarse quantizer: exact inner-product search over the nlist centroids.
    quantizer = faiss.IndexFlatIP(d)
    # Resolve the quantizer-type name (e.g. 'QT_8bit_uniform') to the
    # corresponding faiss.ScalarQuantizer constant.
    qtype = getattr(faiss.ScalarQuantizer, qtype_str)
    index = faiss.IndexIVFScalarQuantizer(quantizer, d, nlist, qtype, faiss.METRIC_INNER_PRODUCT)
    logger.info(f'#GPUs: {faiss.get_num_gpus()}') # Please use GPUs, since it is really large (21M * 768D/60GB embeddings)
    # Spread the index over all visible GPUs for training and adding.
    index = faiss.index_cpu_to_all_gpus(index)
    logger.info('>>> Doing training')
    # Train (learns IVF centroids and quantizer ranges) before add.
    index.train(xb)
    index.add(xb)
    # Move back to CPU so the index can be serialized to disk.
    index = faiss.index_gpu_to_cpu(index)
    # Set nprobe before writing; the assert below relies on nprobe being
    # serialized together with the index -- TODO confirm for this faiss version.
    index.nprobe = nprobe
    faiss.write_index(index, index_name)
logger.info('>>> Loading the trained index for checking')
index = faiss.read_index(index_name)
# NOTE(review): if the index file pre-existed with a different nprobe,
# this assert will fire -- confirm that is the intended behavior.
assert index.nprobe == nprobe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment