from functools import partial

import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)

set_seed(42)

def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch):
    # The Trainer passes the full model output, so pull the logits tensor out of it
    logits = logits["logits"]
    # Upcast to float to avoid potential precision issues when computing the loss
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    if num_items_in_batch is not None:
        # Sum over tokens and normalize by the number of items in the whole
        # accumulated batch, so gradient accumulation does not skew the loss
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
        loss = loss / num_items_in_batch
    else:
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    # Print the per-step loss so the effect of the normalization is easy to see
    print(loss)
    return loss
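
# Optional sanity check (illustrative): when num_items_in_batch equals the number
# of non-ignored shifted labels, the sum-then-normalize path should agree with
# the default mean reduction.
import torch

_demo_outputs = {"logits": torch.randn(2, 8, 32)}
_demo_labels = torch.randint(0, 32, (2, 8))
_num_items = (_demo_labels[..., 1:] != -100).sum()
assert torch.allclose(
    ForCausalLMLoss(_demo_outputs, _demo_labels, vocab_size=32, num_items_in_batch=_num_items),
    ForCausalLMLoss(_demo_outputs, _demo_labels, vocab_size=32, num_items_in_batch=None),
)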

# Constants
model_name = "distilgpt2"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"

# Load a small slice of the dataset and carve out a test split
dataset = load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)

# Tokenize the dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
block_size = 128


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; we could pad instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal language modeling, the labels are the inputs themselves
    result["labels"] = result["input_ids"].copy()
    return result
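
# Toy illustration (hypothetical block_size of 4 for readability): group_texts
# concatenates the token lists, keeps full blocks, and drops the remainder once
# the total reaches at least one block. For example,
#     {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8, 9]]}
# would become
#     {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]], "labels": [[1, 2, 3, 4], [5, 6, 7, 8]]}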

# Apply the grouping
tokenized_dataset = tokenized_dataset.map(group_texts, batched=True)

# distilgpt2's tokenizer has no pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
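
# Quick look (illustrative): collating two grouped examples should yield
# input_ids, attention_mask, and labels tensors of shape (2, block_size).
_peek = data_collator([tokenized_dataset["train"][i] for i in range(2)])
print({k: v.shape for k, v in _peek.items()})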

model = AutoModelForCausalLM.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir="results/causal_language_modeling",  # Where weights are stored
    learning_rate=2e-5,  # The learning rate during training
    per_device_train_batch_size=2,  # Number of samples per batch during training
    gradient_accumulation_steps=4,  # Number of batches to accumulate before each optimizer step
    per_device_eval_batch_size=8,  # Number of samples per batch during evaluation
    max_steps=18,  # Total number of optimizer steps to run
    disable_tqdm=True,  # Keep the printed losses readable
)

# Bind the vocab size; the Trainer supplies the model outputs, labels, and num_items_in_batch
loss_fn = partial(ForCausalLMLoss, vocab_size=model.config.vocab_size)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
    # Custom loss; in transformers >= 4.46 the Trainer keyword for this is compute_loss_func
    compute_loss_func=loss_fn,
)
trainer.train()
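
# Optional follow-up (illustrative): rerunning with gradient_accumulation_steps=1
# and per_device_train_batch_size=8 keeps the same effective batch size, so the
# printed per-step losses can be compared across the two runs. The
# num_items_in_batch normalization matters most when micro-batches contain
# different numbers of non-ignored label tokens (e.g. with padding).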