@muellerzr · Last active October 16, 2024 18:51
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
from functools import partial
set_seed(42)


def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch):
    logits = logits["logits"]
    # Upcast to float when computing the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    if num_items_in_batch is not None:
        # Sum the per-token losses and normalize by the total number of label tokens
        # across all gradient-accumulation micro-batches, so accumulated training
        # matches a single large batch.
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
        loss = loss / num_items_in_batch
    else:
        # Fall back to the default per-batch mean when no token count is provided
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    print(loss)
    return loss
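
# Why sum / num_items_in_batch instead of a per-micro-batch mean? A worked example
# (hypothetical numbers, not from the original gist): suppose two gradient-accumulation
# micro-batches contain 10 and 30 label tokens, with summed losses loss_a and loss_b.
#   mean-of-means:       (loss_a / 10 + loss_b / 30) / 2
#   sum / total tokens:  (loss_a + loss_b) / 40   <- what ForCausalLMLoss computes
# Only the second matches the loss of a single batch containing all 40 tokens.
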
# Constants
model_name = "distilgpt2"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
# Load dataset
dataset = load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
# Tokenize the dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
block_size = 128


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could pad instead if the model supported it.
    # Customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# And apply
tokenized_dataset = tokenized_dataset.map(group_texts, batched=True)
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_name)
training_args = TrainingArguments(
    output_dir="results/causal_language_modeling",  # Where weights are stored
    learning_rate=2e-5,  # The learning rate during training
    per_device_train_batch_size=2,  # Number of samples per micro-batch during training
    gradient_accumulation_steps=4,  # Accumulate 4 micro-batches per optimizer step (effective batch size of 8 per device)
    per_device_eval_batch_size=8,  # Number of samples per batch during evaluation
    max_steps=18,  # Total number of optimizer update steps to run
    disable_tqdm=True,
)
loss_fn = partial(ForCausalLMLoss, vocab_size=model.config.vocab_size)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
    # Pass the custom loss; depending on your transformers version this keyword
    # may be named `compute_loss_func` instead of `compute_loss`.
    compute_loss=loss_fn,
)
trainer.train()
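
# Optional follow-up (not part of the original gist; a minimal sketch using standard
# Trainer APIs, with an arbitrary output path): evaluate on the held-out split and
# save the fine-tuned weights.
metrics = trainer.evaluate()
print(metrics)
trainer.save_model("results/causal_language_modeling/final")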