Created September 11, 2025 15:52
contrastive boilerplate msmarco v1.1
from __future__ import annotations

import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import evaluation, losses, models, utils
# Define model parameters for contrastive training
model_name = "bert-base-uncased"  # Choose the pre-trained model you want to use as base
batch_size = 32  # Larger batch size often improves results, but requires more memory
num_train_epochs = 1  # Adjust based on your requirements

# Set the run name for logging and the output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/{run_name}"

# Define the ColBERT model. If the base model is not already a ColBERT checkpoint,
# PyLate adds a linear projection layer on top of the encoder.
model = models.ColBERT(model_name_or_path=model_name)
# Compiling the model (PyTorch >= 2.0) speeds up training
model = torch.compile(model)

# Load the MS MARCO v1.1 training split
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
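# Each row contains a "query" string and an "answers" list (possibly empty),
# alongside other columns that are dropped below.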
# Keep only rows with at least one answer, then use "query" as the query and the
# first answer as the positive passage.
train_dataset = dataset.filter(lambda x: len(x["answers"]) > 0)
train_dataset = train_dataset.map(
    lambda x: {"query": x["query"], "positive": x["answers"][0]}
)
# Drop any columns that are not "query" or "positive"
train_dataset = train_dataset.remove_columns(
    [col for col in train_dataset.column_names if col not in ["query", "positive"]]
)
print(train_dataset[0])

# Define the loss function
train_loss = losses.Contrastive(model=model)
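# Evaluate during training on NanoBEIR, a lightweight subset of the BEIR retrieval benchmark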
dev_evaluator = evaluation.NanoBEIREvaluator()

# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="steps",
    eval_steps=250,
    logging_steps=1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    learning_rate=3e-6,
)

# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=utils.ColBERTCollator(model.tokenize),
)

# Start the training process
trainer.train()
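# Optional follow-up step, not in the original gist: persist the fine-tuned weights.
# Assumes save_pretrained(), inherited from SentenceTransformer, is available on the
# PyLate ColBERT model.
model.save_pretrained(f"{output_dir}/final")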