Boilerplate to train a Sparse Embedding model (SPLADE architecture) using Sentence Transformers
import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
    "distilbert/distilbert-base-uncased",
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="DistilBERT base trained on Natural-Questions tuples",
    ),
)
# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,  # Weight of the sparsity regularization applied to query embeddings
    document_regularizer_weight=3e-5,  # Weight of the sparsity regularization applied to document embeddings
)
# 5. (Optional) Specify training arguments
run_name = "splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
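
After training, the saved model can be loaded back for sparse retrieval. The snippet below is a minimal inference sketch, separate from the training script; it assumes a recent sentence-transformers release with SparseEncoder.encode_query / encode_document, and the local path, query, and documents are only illustrative.

from sentence_transformers import SparseEncoder

# Load the trained model from the local save path used above (adjust if you changed run_name)
model = SparseEncoder("models/splade-distilbert-base-uncased-nq/final")

# Encode a query and candidate documents into sparse, vocabulary-sized embeddings
query_embedding = model.encode_query("who wrote the declaration of independence")
document_embeddings = model.encode_document([
    "The Declaration of Independence was primarily drafted by Thomas Jefferson.",
    "The Eiffel Tower is a wrought-iron lattice tower in Paris.",
])

# Score the candidate documents against the query; higher scores indicate better matches
scores = model.similarity(query_embedding, document_embeddings)
print(scores)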