@NohTow
Created May 15, 2025 11:29
BEIR PLAID PyLate boilerplate
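
Usage (assuming the script below is saved as beir_plaid.py; the filename is arbitrary, and --dataset_name can be any key of the query_len dictionary):

python beir_plaid.py --dataset_name scifact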
"""Evaluation script for the SciFact dataset using the Beir library."""
from __future__ import annotations
import argparse
import os
import srsly
from pylate import evaluation, indexes, models, retrieve
if __name__ == "__main__":
    query_len = {
        "quora": 32,
        "climate-fever": 64,
        "nq": 32,
        "msmarco": 32,
        "hotpotqa": 32,
        "nfcorpus": 32,
        "scifact": 48,
        "trec-covid": 48,
        "fiqa": 32,
        "arguana": 64,
        "scidocs": 48,
        "dbpedia-entity": 32,
        "webis-touche2020": 32,
        "fever": 32,
        "cqadupstack/android": 32,
        "cqadupstack/english": 32,
        "cqadupstack/gaming": 32,
        "cqadupstack/gis": 32,
        "cqadupstack/mathematica": 32,
        "cqadupstack/physics": 32,
        "cqadupstack/programmers": 32,
        "cqadupstack/stats": 32,
        "cqadupstack/tex": 32,
        "cqadupstack/unix": 32,
        "cqadupstack/webmasters": 32,
        "cqadupstack/wordpress": 32,
    }
    # Parse dataset_name from command line arguments
    parser = argparse.ArgumentParser(description="Dataset name")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="nfcorpus",
        help="Name of the dataset to evaluate on (default: 'nfcorpus')",
    )
    args = parser.parse_args()
    dataset_name = args.dataset_name
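
    # Load the ColBERT model with a 300-token document budget and the
    # dataset-specific query budget (None for datasets missing from query_len).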
    model_name = "lightonai/GTE-ModernColBERT-v1"
    model = models.ColBERT(
        model_name_or_path=model_name,
        document_length=300,
        query_length=query_len.get(dataset_name),
    )
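
    # CQADupStack is distributed as a single archive containing twelve
    # sub-forums, so it is downloaded manually and loaded as a custom dataset;
    # every other dataset goes through the standard BEIR loader.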
if "cqadupstack" in dataset_name:
# Download dataset if not already downloaded
from beir import util
data_path = util.download_and_unzip(
url="https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/cqadupstack.zip",
out_dir="./evaluation_datasets/",
)
documents, queries, qrels = evaluation.load_custom_dataset(
f"evaluation_datasets/{dataset_name}",
split="test",
)
dataset_name = dataset_name.replace("/", "_")
else:
documents, queries, qrels = evaluation.load_beir(
dataset_name=dataset_name,
split="dev" if "msmarco" in dataset_name else "test",
)
    # Total number of relevance judgments in the qrels (informational only)
    total = sum(len(values) for values in qrels.values())
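
    # Build a 4-bit PLAID index; override=True replaces any existing index
    # with the same name.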
    index = indexes.PLAID(
        override=True,
        nbits=4,
        index_name=f"{dataset_name}_{model_name.split('/')[-1]}_4bits",
    )
    retriever = retrieve.ColBERT(index=index)
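
    # Encode the corpus once and add the document embeddings to the index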
    documents_embeddings = model.encode(
        sentences=[document["text"] for document in documents],
        batch_size=2000,
        is_query=False,
        show_progress_bar=True,
    )
    index.add_documents(
        documents_ids=[document["id"] for document in documents],
        documents_embeddings=documents_embeddings,
    )
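
    # Encode the queries; is_query=True switches to query-side encoding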
    queries_embeddings = model.encode(
        sentences=list(queries.values()),
        is_query=True,
        show_progress_bar=True,
        batch_size=32,
    )
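
    # Retrieve the top-20 candidates for every query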
    scores = retriever.retrieve(queries_embeddings=queries_embeddings, k=20)
    # Filter out self-retrieval: drop any hit whose document id matches the
    # query id (mutating the list while iterating over it would skip
    # elements, so the list is rebuilt in place instead)
    for query_id, query_scores in zip(queries, scores):
        query_scores[:] = [
            score for score in query_scores if score["id"] != query_id
        ]
    evaluation_scores = evaluation.evaluate(
        scores=scores,
        qrels=qrels,
        queries=list(queries.values()),
        metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"],
    )
    print(evaluation_scores)
    # Save the evaluation scores to a JSON file using srsly
    output_dir = f"final_scores_{model_name.split('/')[-1]}"
    os.makedirs(output_dir, exist_ok=True)
    srsly.write_json(
        os.path.join(output_dir, f"{dataset_name}_evaluation_scores.json"),
        evaluation_scores,
    )
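
To evaluate the full BEIR suite, the script can be driven from a small wrapper. A minimal sketch, assuming the file above is saved as beir_plaid.py (the filename and run_all.py are hypothetical):

# run_all.py: hypothetical wrapper that evaluates every dataset in turn,
# assuming the script above is saved as beir_plaid.py
import subprocess

DATASETS = [
    "quora", "climate-fever", "nq", "msmarco", "hotpotqa", "nfcorpus",
    "scifact", "trec-covid", "fiqa", "arguana", "scidocs", "dbpedia-entity",
    "webis-touche2020", "fever",
] + [
    f"cqadupstack/{forum}" for forum in (
        "android", "english", "gaming", "gis", "mathematica", "physics",
        "programmers", "stats", "tex", "unix", "webmasters", "wordpress",
    )
]

for name in DATASETS:
    # Each run builds its own index and writes its own score file
    subprocess.run(
        ["python", "beir_plaid.py", "--dataset_name", name], check=True
    )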