GRPO from DeepSeek-R1 for summarization using encoder-decoder models
import gc
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import evaluate
import nltk
import torch
from datasets import Dataset, DatasetDict, load_dataset
from torch import FloatTensor, LongTensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    PreTrainedModel,
    PreTrainedTokenizer,
)

import wandb

# nltk.sent_tokenize requires the punkt models; fetch them once up front
# (newer nltk releases may also need the "punkt_tab" resource).
nltk.download("punkt", quiet=True)

rouge_eval = evaluate.load("rouge")

@dataclass
class TrainingArguments:
    epochs: int
    batch_size: int
    learning_rate: float
    update_old_after: int
    group_size: int
    logging_steps: int
    max_new_tokens: int
    max_document_length: int
    max_summary_length: int
    grpo_epsilon: float
    grpo_beta: float
    gradient_max_norm: float
    save_steps: int
    save_dir: str

@dataclass
class BatchRewards:
    rewards: FloatTensor


@dataclass
class GRPOOutput:
    loss: FloatTensor
    reward: FloatTensor
    kl: FloatTensor

def load_model(model_name: str) -> PreTrainedModel:
    """
    Loads a pre-trained encoder-decoder model in bfloat16 and moves it to the GPU.
    Args:
        model_name (str): The name or path of the pre-trained model to load.
    Returns:
        PreTrainedModel: The loaded pre-trained model.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    model = model.to("cuda")
    return model

def load_tokenizer(model_name: str) -> PreTrainedTokenizer:
    """
    Load a pre-trained tokenizer.
    Args:
        model_name (str): The name or path of the pre-trained model.
    Returns:
        PreTrainedTokenizer: The tokenizer associated with the specified model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer

def load_cnndm() -> DatasetDict:
    """
    Load and preprocess the CNN/DailyMail dataset.
    Returns:
        DatasetDict: A dictionary containing the preprocessed dataset splits.
    """
    dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
    for split in dataset.column_names:
        dataset[split] = dataset[split].remove_columns(["id"])
        dataset[split] = dataset[split].rename_columns(
            {"article": "document", "highlights": "summary"}
        )
    return dataset

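# After this preprocessing, each split exposes two text columns: "document"
# (the article body) and "summary" (the reference highlights).
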
def tokenize_dataset(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizer,
    max_document_length: int,
    max_summary_length: int,
) -> DatasetDict:
    """
    Tokenizes the documents and summaries in the dataset.
    Args:
        dataset (DatasetDict): The dataset containing documents and summaries to be tokenized.
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the text.
        max_document_length (int): The maximum length for tokenized documents.
        max_summary_length (int): The maximum length for tokenized summaries.
    Returns:
        DatasetDict: The tokenized dataset with documents and summaries replaced by their tokenized versions.
    """

    def tokenize_function(example):
        model_inputs = tokenizer(
            example["document"],
            max_length=max_document_length,
            truncation=True,
        )
        labels = tokenizer(
            example["summary"],
            max_length=max_summary_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    return dataset.map(
        tokenize_function, batched=True, remove_columns=["document", "summary"]
    )

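# Note: only truncation happens here; padding is deferred to
# DataCollatorForSeq2Seq, which pads each batch dynamically at collation time.
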
def rouge_reward(predictions: list[str], references: list[str]) -> float:
    """
    Calculate the average ROUGE (1, 2, Lsum) score for a set of predictions and references.
    Args:
        predictions (list[str]): A list of predicted text strings.
        references (list[str]): A list of reference text strings.
    Returns:
        float: The average of the ROUGE-1, ROUGE-2, and ROUGE-Lsum scores.
    """
    scores = rouge_eval.compute(predictions=predictions, references=references)
    return (scores["rouge1"] + scores["rouge2"] + scores["rougeLsum"]) / 3.0

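# Illustrative usage (the exact value depends on the rouge settings):
#   rouge_reward(["the cat sat on the mat"], ["a cat sat on the mat"])
#   -> roughly 0.82, the mean of the three F-measures
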
def postprocess_text(
    preds: list[str], labels: list[str]
) -> Tuple[list[str], list[str]]:
    """
    Post-processes the predicted and label texts,
    formatting them for ROUGE-Lsum evaluation.
    Args:
        preds (list[str]): List of predicted text strings.
        labels (list[str]): List of label text strings.
    Returns:
        Tuple[list[str], list[str]]: A tuple containing two lists:
            - The first list contains the post-processed predicted texts.
            - The second list contains the post-processed label texts.
    """
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    # rougeLsum expects a newline after each sentence
    preds = [
        "\n".join(nltk.sent_tokenize(pred.replace("<n>", " ")))
        for pred in preds
    ]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
    return preds, labels

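# For example, a prediction like "First point. Second point." becomes
# "First point.\nSecond point.", since rougeLsum scores summaries at the
# sentence level.
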
def compute_rewards(
    token_ids: LongTensor, labels: LongTensor, tokenizer: PreTrainedTokenizer
) -> BatchRewards:
    """
    Compute rewards as the average ROUGE score between generated completions and reference summaries.
    Args:
        token_ids (LongTensor): Tensor containing token IDs of the generated completions.
        labels (LongTensor): Tensor containing token IDs of the reference summaries.
        tokenizer (PreTrainedTokenizer): Tokenizer used to decode the token IDs.
    Returns:
        BatchRewards: A tensor containing the computed rewards for each completion.
    """
    # Replace the ignore index (-100) with the pad token so labels can be decoded
    labels[labels == -100] = tokenizer.pad_token_id
    completions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    summaries = tokenizer.batch_decode(labels, skip_special_tokens=True)
    completions, summaries = postprocess_text(completions, summaries)
    rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    for idx, (completion, summary) in enumerate(zip(completions, summaries)):
        rouge_score = rouge_reward(
            predictions=[completion], references=[summary]
        )
        rewards[idx] = rouge_score
    return BatchRewards(rewards)

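# Since each reward is the mean of three ROUGE F-measures, it lies in [0, 1].
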
def selective_log_softmax(
    logits: FloatTensor, index: LongTensor
) -> FloatTensor:
    """
    Computes the log softmax of the input logits selectively based on the provided indices.
    This function performs the same operation as applying `log_softmax` on the logits tensor
    along the last dimension and then gathering the results based on the provided indices.
    However, it processes the logits row by row to save memory.
    Taken from https://www.tylerromero.com/posts/2025-02-selective-log-softmax/
    Args:
        logits (FloatTensor): A tensor of shape (batch_size, seq_len, vocab_size)
            containing the raw logits for each token.
        index (LongTensor): A tensor of shape (batch_size, seq_len) containing the
            indices of the tokens whose log probabilities are gathered.
    Returns:
        FloatTensor: A tensor of shape (batch_size, seq_len) containing the log softmax
            values at the specified indices.
    """
    token_logprobs = []
    for logits_row, index_row in zip(logits, index):
        logprobs_row = logits_row.log_softmax(dim=-1)
        token_logprobs_row = torch.gather(
            logprobs_row, dim=-1, index=index_row.unsqueeze(-1)
        ).squeeze(-1)
        token_logprobs.append(token_logprobs_row)
    return torch.stack(token_logprobs)

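# The row-by-row loop above is equivalent to the vectorised version below,
# which is simpler but materialises the full log-softmax tensor at once:
#
#   logprobs = logits.log_softmax(dim=-1)                     # (B, T, V)
#   token_logprobs = logprobs.gather(-1, index.unsqueeze(-1)).squeeze(-1)
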
def gather_token_scores(
    logits: FloatTensor, generated_ids: LongTensor
) -> FloatTensor:
    """
    Gathers token scores from logits based on generated token IDs.
    Args:
        logits (FloatTensor): The logits output from the model. It can be a tuple of tensors or a single tensor.
        generated_ids (LongTensor): The IDs of the generated tokens.
    Returns:
        FloatTensor: The token scores after applying a selective log softmax on the logits.
    """
    if isinstance(logits, tuple):
        # Stack the logits into (batch_size * group_size, output_length, vocab)
        logits = torch.stack(logits, dim=0).permute((1, 0, 2))
    # Log-softmax the logits and gather the scores of the generated tokens
    token_scores = selective_log_softmax(logits, generated_ids)
    return token_scores

def compute_token_scores(
    model: PreTrainedModel,
    encoder_input_ids: LongTensor,
    encoder_attention_mask: LongTensor,
    decoder_input_ids: LongTensor,
    decoder_attention_mask: LongTensor,
    batch_size: int,
    group_size: int,
) -> FloatTensor:
    """
    Computes per-token log probabilities for a batch of sequences using a pre-trained model.
    Args:
        model (PreTrainedModel): The pre-trained model used to compute the logits.
        encoder_input_ids (LongTensor): Tensor containing input IDs for the encoder.
        encoder_attention_mask (LongTensor): Tensor containing attention masks for the encoder inputs.
        decoder_input_ids (LongTensor): Tensor containing input IDs for the decoder.
        decoder_attention_mask (LongTensor): Tensor containing attention masks for the decoder inputs.
        batch_size (int): The number of documents in the batch.
        group_size (int): The number of completions generated per document.
    Returns:
        FloatTensor: A tensor containing the computed token scores, reshaped to (batch_size, group_size, -1).
    """
    logits = model(
        input_ids=encoder_input_ids,
        attention_mask=encoder_attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    ).logits
    # Score each generated token: logits at position t predict the token at t + 1
    scores = gather_token_scores(logits[:, :-1], decoder_input_ids[:, 1:])
    scores = scores.view(batch_size, group_size, -1)
    del logits
    torch.cuda.empty_cache()
    return scores

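# Shape walkthrough (illustrative): with batch_size=2, group_size=5 and
# generated sequences of length L, decoder_input_ids is (10, L), the logits
# are (10, L, vocab), gather_token_scores drops the last position and yields
# (10, L - 1), and the final view regroups the scores to (2, 5, L - 1).
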
def grpo(
    generated_ids: LongTensor,
    old_scores: FloatTensor,
    current_scores: FloatTensor,
    reference_scores: FloatTensor,
    labels: LongTensor,
    tokenizer: PreTrainedTokenizer,
    epsilon: float,
    beta: float,
) -> GRPOOutput:
    """
    Compute the Group Relative Policy Optimization (GRPO) loss on the given inputs.
    Args:
        generated_ids (LongTensor): Tensor of generated token IDs.
        old_scores (FloatTensor): Tensor of old policy scores.
        current_scores (FloatTensor): Tensor of current policy scores.
        reference_scores (FloatTensor): Tensor of reference policy scores.
        labels (LongTensor): Tensor of ground truth token IDs.
        tokenizer (PreTrainedTokenizer): Tokenizer used for encoding/decoding.
        epsilon (float): Clipping parameter for policy ratios.
        beta (float): Weighting factor for the Kullback-Leibler divergence term.
    Returns:
        GRPOOutput: A dataclass containing the mean loss, reward and KL divergence.
    """
    losses = torch.zeros(generated_ids.shape[0], device=generated_ids.device)
    rewards = torch.zeros(generated_ids.shape[0], device=generated_ids.device)
    kls = torch.zeros(generated_ids.shape[0], device=generated_ids.device)
    for idx, (
        group_ids,
        group_labels,
        group_old_scores,
        group_current_scores,
        group_reference_scores,
    ) in enumerate(
        zip(generated_ids, labels, old_scores, current_scores, reference_scores)
    ):
        # Compute group-normalized advantages from the rewards
        group_rewards = compute_rewards(group_ids, group_labels, tokenizer)
        mean = group_rewards.rewards.mean()
        centered = group_rewards.rewards - mean
        std = group_rewards.rewards.std()
        if std < 1e-8:
            advantages = torch.zeros_like(centered)
        else:
            advantages = centered / (std + 1e-8)
        # Store the mean reward of the group
        rewards[idx] = mean
        # Compute the probability ratios between the current and old policies
        ratios = torch.exp(group_current_scores - group_old_scores)
        # Compute the clipped ratios
        clipped_ratios = torch.clamp(
            ratios, min=1.0 - epsilon, max=1.0 + epsilon
        )
        # Estimate the Kullback-Leibler divergence between the reference
        # and the current policy
        kl = (
            torch.exp(group_reference_scores - group_current_scores)
            - (group_reference_scores - group_current_scores)
            - 1
        )
        kls[idx] = kl.mean()
        # Compute the mean loss of the group, masking out padding positions
        completion_mask = group_ids[:, 1:] != tokenizer.pad_token_id
        loss = (
            torch.min(
                ratios * advantages.unsqueeze(-1),
                clipped_ratios * advantages.unsqueeze(-1),
            )
            - beta * kl
        )
        loss = -(loss * completion_mask).sum() / completion_mask.sum()
        losses[idx] = loss
    return GRPOOutput(
        loss=losses.mean(),
        reward=rewards.mean(),
        kl=kls.mean(),
    )

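# Worked example of the advantage computation (made-up rewards): for a group
# with rewards [0.2, 0.4, 0.6], mean = 0.4 and std = 0.2, so the advantages
# are [-1.0, 0.0, 1.0]; every token of a completion then shares its
# sequence-level advantage through advantages.unsqueeze(-1).
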
def train(
    dataset: Dataset,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    training_args: TrainingArguments,
) -> None:
    """
    Train a language model using the GRPO (Group Relative Policy Optimization) objective.
    Args:
        dataset (Dataset): The dataset containing training data.
        model (PreTrainedModel): The model to be trained.
        tokenizer (PreTrainedTokenizer): The tokenizer used for encoding the data.
        training_args (TrainingArguments): The training arguments containing hyperparameters and configurations.
    """
    # Prepare the dataloader
    train_dataloader = DataLoader(
        dataset["train"],
        collate_fn=DataCollatorForSeq2Seq(tokenizer),
        batch_size=training_args.batch_size,
    )
    # Prepare the policies
    reference_model = deepcopy(model)
    old_model = deepcopy(model)
    reference_model.eval()
    old_model.eval()
    model.train()
    # Prepare the optimizer and lr scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=training_args.learning_rate,
    )
    scheduler = LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.1,
        total_iters=training_args.epochs * len(train_dataloader),
    )
    # Prepare the running metrics
    running_metrics = {
        "loss": 0.0,
        "reward": 0.0,
        "completion_length": 0.0,
        "kl": 0.0,
    }
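    # Note: the two deepcopy calls keep three full copies of the weights on
    # the GPU (trainable, old and reference policies), roughly tripling the
    # memory needed for the model parameters.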
    # Let's train
    training_step = 0
    best_reward = 0.0
    for _ in range(training_args.epochs):
        # Update the old policy at the start of each epoch
        old_model.load_state_dict(model.state_dict(), strict=False)
        for batch in tqdm(train_dataloader, total=len(train_dataloader)):
            # Prepare the batch data
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            labels = batch["labels"].to(model.device)
            effective_batch_size = input_ids.shape[0]
            # Generate a group of completions per document with the old policy
            generated_ids = old_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=training_args.max_new_tokens,
                do_sample=True,
                num_beams=training_args.group_size,
                num_return_sequences=training_args.group_size,
            )
            # Prepare the attention mask for computing current
            # and reference logits on the generated ids
            decoder_attention_mask = generated_ids != tokenizer.pad_token_id
            # Interleave input_ids and attention_mask to have
            # the same shape as the generated completions
            repeated_input_ids = input_ids.repeat_interleave(
                repeats=training_args.group_size, dim=0
            )
            repeated_attention_mask = attention_mask.repeat_interleave(
                repeats=training_args.group_size, dim=0
            )
            # Compute the sequence scores of the old policy
            with torch.inference_mode(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                old_scores = compute_token_scores(
                    old_model,
                    encoder_input_ids=repeated_input_ids,
                    encoder_attention_mask=repeated_attention_mask,
                    decoder_input_ids=generated_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    batch_size=effective_batch_size,
                    group_size=training_args.group_size,
                )
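            # The three score sets play different roles: old_scores come from
            # the sampling policy (no gradients), current_scores from the
            # trainable policy (gradients flow through them), and
            # reference_scores from the frozen initial model, used only in
            # the KL penalty.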
            # Compute the sequence scores of the current policy
            with torch.autocast("cuda", dtype=torch.bfloat16):
                model.eval()
                current_scores = compute_token_scores(
                    model,
                    encoder_input_ids=repeated_input_ids,
                    encoder_attention_mask=repeated_attention_mask,
                    decoder_input_ids=generated_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    batch_size=effective_batch_size,
                    group_size=training_args.group_size,
                )
                model.train()
            # Compute the sequence scores of the reference model
            with torch.inference_mode(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                reference_scores = compute_token_scores(
                    reference_model,
                    encoder_input_ids=repeated_input_ids,
                    encoder_attention_mask=repeated_attention_mask,
                    decoder_input_ids=generated_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    batch_size=effective_batch_size,
                    group_size=training_args.group_size,
                )
            # Group the generated ids (batch_size, group_size, output_length)
            generated_ids = generated_ids.view(
                effective_batch_size, training_args.group_size, -1
            )
            # Repeat the labels and group them (batch_size, group_size, label_length)
            labels = labels.repeat_interleave(
                repeats=training_args.group_size, dim=0
            ).view(effective_batch_size, training_args.group_size, -1)
            # Compute the GRPO objective
            with torch.autocast("cuda", dtype=torch.bfloat16):
                grpo_output = grpo(
                    generated_ids,
                    old_scores,
                    current_scores,
                    reference_scores,
                    labels,
                    tokenizer,
                    training_args.grpo_epsilon,
                    training_args.grpo_beta,
                )
            # Update the current policy
            grpo_output.loss.backward()
            clip_grad_norm_(
                model.parameters(),
                training_args.gradient_max_norm,
            )
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            # Update the old policy periodically
            if (training_step + 1) % training_args.update_old_after == 0:
                old_model.load_state_dict(model.state_dict(), strict=False)
                torch.cuda.empty_cache()
            # Update the running metrics
            batch_metrics = {
                "loss": grpo_output.loss.item(),
                "reward": grpo_output.reward.item(),
                "kl": grpo_output.kl.item(),
                "completion_length": decoder_attention_mask.sum(-1)
                .float()
                .mean()
                .item(),
            }
            running_metrics = {
                key: running_metrics[key] + batch_metrics.get(key, 0)
                for key in running_metrics
            }
            # And report them periodically
            if (training_step + 1) % training_args.logging_steps == 0:
                wandb.log(
                    {
                        **{
                            key: val / (training_step + 1)
                            for key, val in running_metrics.items()
                        },
                        **{"lr": scheduler.get_last_lr()[0]},
                    }
                )
            # Save the model periodically, keeping the checkpoint
            # with the best running reward
            if (training_step + 1) % training_args.save_steps == 0:
                last_reward = running_metrics["reward"] / (training_step + 1)
                if last_reward > best_reward:
                    model.save_pretrained(f"{training_args.save_dir}")
                    best_reward = last_reward
                    print(
                        "Saving model with reward:",
                        best_reward,
                        f"step: {training_step + 1}",
                    )
                else:
                    print(
                        f"Model not saved: reward did not improve at step {training_step + 1}"
                    )
            # Free GPU memory at the end of the step
            del (
                generated_ids,
                old_scores,
                input_ids,
                attention_mask,
                repeated_input_ids,
                repeated_attention_mask,
                current_scores,
                reference_scores,
                grpo_output,
                labels,
            )
            torch.cuda.empty_cache()
            gc.collect()
            training_step += 1

def main():
    """
    Main function for training an encoder-decoder model with
    GRPO to optimize ROUGE on the CNN/DailyMail dataset.
    """
    # Instantiate the current policy and its tokenizer
    model_name = "google/pegasus-cnn_dailymail"
    model = load_model(model_name)
    tokenizer = load_tokenizer(model_name)
    # Define the training arguments
    training_args = TrainingArguments(
        epochs=1,
        batch_size=8,
        learning_rate=1e-5,
        update_old_after=1000,
        group_size=5,
        logging_steps=10,
        max_new_tokens=128,
        max_document_length=512,
        max_summary_length=128,
        grpo_epsilon=0.1,
        grpo_beta=0.04,
        gradient_max_norm=0.2,
        save_steps=100,
        save_dir="./grpo_pegasus-cnn-dailymail",
    )
    # Load and tokenize the dataset
    dataset = load_cnndm()
    dataset = tokenize_dataset(
        dataset,
        tokenizer,
        training_args.max_document_length,
        training_args.max_summary_length,
    )
    # Initialize wandb
    wandb.login()
    wandb.init(
        project="GRPO-Summarization",
        config={
            "model": model_name,
            "dataset": "abisee/cnn_dailymail",
            "training_args": training_args.__dict__,
        },
    )
    # Let's train!
    train(dataset, model, tokenizer, training_args)
    # Save the model and finish logging
    model.save_pretrained(f"grpo_{model_name.replace('/', '_')}_cnn_dailymail")
    wandb.finish()

if __name__ == "__main__":
    main()
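
# A rough way to run this gist (assuming it is saved as grpo_summarization.py,
# a hypothetical name, on a machine with a bf16-capable CUDA GPU and a
# configured wandb account):
#
#   python grpo_summarization.py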