Created
May 15, 2025 11:29
-
-
Save NohTow/248e2695c73bf4b47cf3fb3c96b80ddc to your computer and use it in GitHub Desktop.
BEIR Plaid PyLate boilerplate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Evaluation script for the SciFact dataset using the Beir library.""" | |
from __future__ import annotations | |
import argparse | |
import os | |
import srsly | |
from pylate import evaluation, indexes, models, retrieve | |
if __name__ == "__main__": | |
query_len = { | |
"quora": 32, | |
"climate-fever": 64, | |
"nq": 32, | |
"msmarco": 32, | |
"hotpotqa": 32, | |
"nfcorpus": 32, | |
"scifact": 48, | |
"trec-covid": 48, | |
"fiqa": 32, | |
"arguana": 64, | |
"scidocs": 48, | |
"dbpedia-entity": 32, | |
"webis-touche2020": 32, | |
"fever": 32, | |
"cqadupstack/android": 32, | |
"cqadupstack/english": 32, | |
"cqadupstack/gaming": 32, | |
"cqadupstack/gis": 32, | |
"cqadupstack/mathematica": 32, | |
"cqadupstack/physics": 32, | |
"cqadupstack/programmers": 32, | |
"cqadupstack/stats": 32, | |
"cqadupstack/tex": 32, | |
"cqadupstack/unix": 32, | |
"cqadupstack/webmasters": 32, | |
"cqadupstack/wordpress": 32, | |
} | |
# Parse dataset_name from command line arguments | |
parser = argparse.ArgumentParser(description="Dataset name") | |
parser.add_argument( | |
"--dataset_name", | |
type=str, | |
default="nfcorpus", | |
help="Name of the dataset to evaluate on (default: 'fiqa')", | |
) | |
args = parser.parse_args() | |
dataset_name = args.dataset_name | |
model_name = "lightonai/GTE-ModernColBERT-v1" | |
model = models.ColBERT( | |
model_name_or_path=model_name, | |
document_length=300, | |
query_length=query_len.get(dataset_name), | |
) | |
if "cqadupstack" in dataset_name: | |
# Download dataset if not already downloaded | |
from beir import util | |
data_path = util.download_and_unzip( | |
url="https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/cqadupstack.zip", | |
out_dir="./evaluation_datasets/", | |
) | |
documents, queries, qrels = evaluation.load_custom_dataset( | |
f"evaluation_datasets/{dataset_name}", | |
split="test", | |
) | |
dataset_name = dataset_name.replace("/", "_") | |
else: | |
documents, queries, qrels = evaluation.load_beir( | |
dataset_name=dataset_name, | |
split="dev" if "msmarco" in dataset_name else "test", | |
) | |
total = 0 | |
for key, values in qrels.items(): | |
total += len(values) | |
index = indexes.PLAID( | |
override=True, | |
nbits=4, | |
index_name=f"{dataset_name}_{model_name.split('/')[-1]}_4bits", | |
) | |
retriever = retrieve.ColBERT(index=index) | |
documents_embeddings = model.encode( | |
sentences=[document["text"] for document in documents], | |
batch_size=2000, | |
is_query=False, | |
show_progress_bar=True, | |
) | |
index.add_documents( | |
documents_ids=[document["id"] for document in documents], | |
documents_embeddings=documents_embeddings, | |
) | |
queries_embeddings = model.encode( | |
sentences=list(queries.values()), | |
is_query=True, | |
show_progress_bar=True, | |
batch_size=32, | |
) | |
scores = retriever.retrieve(queries_embeddings=queries_embeddings, k=20) | |
for (query_id, query), query_scores in zip(queries.items(), scores): | |
for score in query_scores: | |
if score["id"] == query_id: | |
# Remove the query_id from the score | |
query_scores.remove(score) | |
evaluation_scores = evaluation.evaluate( | |
scores=scores, | |
qrels=qrels, | |
queries=list(queries.values()), | |
# queries=queries, | |
metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"], | |
) | |
print(evaluation_scores) | |
# Save the evaluation scores to a json file using srsly | |
output_dir = f"final_scores_{model_name.split('/')[-1]}" | |
os.makedirs(output_dir, exist_ok=True) | |
srsly.write_json( | |
os.path.join(output_dir, f"{dataset_name}_evaluation_scores.json"), | |
evaluation_scores, | |
) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment