contrastive boilerplate msmarco v1.1
from __future__ import annotations
import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from pylate import evaluation, losses, models, utils
# Define model parameters for contrastive training
model_name = "bert-base-uncased" # Choose the pre-trained model to use as the base encoder
batch_size = 32 # Larger batch size often improves results, but requires more memory
num_train_epochs = 1 # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/{run_name}"
# 1. Define the ColBERT model. If the checkpoint is not already a ColBERT model, a linear projection layer is added on top of the base encoder.
model = models.ColBERT(model_name_or_path=model_name)
# Compiling the model (requires PyTorch 2.0+) can speed up training
model = torch.compile(model)
# Load dataset
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
# For each row, use the "query" column as the anchor and the first entry of
# "answers" as the positive. Rows without any answer are dropped.
train_dataset = dataset.filter(lambda x: len(x["answers"]) > 0)
train_dataset = train_dataset.map(
    lambda x: {"query": x["query"], "positive": x["answers"][0]}
)
# Drop any columns that are not "query" or "positive"
train_dataset = train_dataset.remove_columns(
    [col for col in train_dataset.column_names if col not in ["query", "positive"]]
)
print(train_dataset[0])
# Define the loss function
train_loss = losses.Contrastive(model=model)
# Evaluate periodically during training on the NanoBEIR benchmark subsets
dev_evaluator = evaluation.NanoBEIREvaluator()
# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="steps",
    eval_steps=250,
    logging_steps=1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    learning_rate=3e-6,
)
# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=utils.ColBERTCollator(model.tokenize),
)
# Start the training process
trainer.train()
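
# Minimal usage sketch: encode queries and documents with the trained model.
# This is an added illustration, assuming the in-memory `model` is used directly
# after training; the example strings are placeholders. ColBERT produces
# per-token embeddings, and `is_query` switches between query and document
# encoding modes.
queries_embeddings = model.encode(
    ["what is late interaction retrieval?"],
    is_query=True,  # Encode as queries
)
documents_embeddings = model.encode(
    ["ColBERT compares per-token query and document embeddings with MaxSim."],
    is_query=False,  # Encode as documents
)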