@jogonba2
Last active March 7, 2025 18:04
Group Relative Policy Optimization from DeepSeek R1-zero just for fun
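For reference, a sketch of the per-group objective that the grpo() function below implements (my notation, reconstructed from the code; r_i is the total reward of the i-th completion o_i in a group of size G, pi_theta the current policy, pi_old the sampling policy and pi_ref the reference policy):

\[
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)},
\qquad
\rho_{i,t} = \exp\big(\log \pi_\theta(o_{i,t}) - \log \pi_{\text{old}}(o_{i,t})\big),
\]
\[
d_{i,t} = \exp\big(\log \pi_{\text{ref}}(o_{i,t}) - \log \pi_\theta(o_{i,t})\big)
- \big(\log \pi_{\text{ref}}(o_{i,t}) - \log \pi_\theta(o_{i,t})\big) - 1,
\]
\[
\mathcal{L}_{\text{group}} = -\,\frac{\sum_{i=1}^{G}\sum_{t=1}^{|o_i|}
\Big[\min\big(\rho_{i,t}\hat{A}_i,\ \operatorname{clip}(\rho_{i,t},\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_i\big)
- \beta\, d_{i,t}\Big]}{\sum_{i=1}^{G}|o_i|},
\]
with advantages set to zero when the group reward std is near zero (as in the code), and the batch loss taken as the mean of the group losses.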
import gc
import re
from copy import deepcopy
from dataclasses import dataclass
import torch
from datasets import Dataset, DatasetDict, load_dataset
from math_verify import parse, verify
from peft import (
LoraConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
)
from torch import FloatTensor, LongTensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorWithPadding,
PreTrainedModel,
PreTrainedTokenizer,
)
import wandb
@dataclass
class TrainingArguments:
epochs: int
batch_size: int
learning_rate: float
update_old_after: int
group_size: int
logging_steps: int
max_new_tokens: int
prompt_max_length: int
temperature: float
grpo_epsilon: float
grpo_beta: float
gradient_max_norm: float
save_steps: int
save_dir: str
@dataclass
class BatchRewards:
format_rewards: FloatTensor
accuracy_rewards: FloatTensor
total_rewards: FloatTensor
@dataclass
class GRPOOutput:
loss: FloatTensor
format_reward: FloatTensor
accuracy_reward: FloatTensor
total_reward: FloatTensor
kl: FloatTensor
def load_model(model_name: str) -> PreTrainedModel:
"""
Loads a pre-trained language model with quantization and attention configs.
Args:
model_name (str): The name or path of the pre-trained model to load.
Returns:
PreTrainedModel: The loaded pre-trained model.
"""
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto",
attn_implementation="flash_attention_2",
)
model.gradient_checkpointing_enable()
return model
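# Note: attn_implementation="flash_attention_2" assumes the flash-attn package is
# installed and the GPU supports it; otherwise, dropping that argument (or passing
# "sdpa") should work as a fallback.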
def get_lora_model(
model: PreTrainedModel, lora_config: LoraConfig
) -> PeftModel:
"""
Prepares and returns a LoRA model for training.
Args:
model (PreTrainedModel): The pre-trained model to be adapted with LoRA.
lora_config (LoraConfig): The configuration for the LoRA adaptation.
Returns:
PeftModel: The model adapted with LoRA, ready for training.
"""
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model
def load_tokenizer(model_name: str) -> PreTrainedTokenizer:
"""
Loads the tokenizer of a pre-trained model.
Args:
model_name (str): The name or path of the pre-trained model.
Returns:
PreTrainedTokenizer: The tokenizer associated with the specified model.
"""
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
def normalize_number(x: str) -> str:
"""
Normalize a numeric string by removing commas and periods,
and stripping any leading or trailing whitespace.
Dirty trick to simplify the task for math_verify.
Args:
x (str): The numeric string to normalize.
Returns:
str: The normalized numeric string
"""
return x.strip().replace(",", "").replace(".", "")
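# Illustrative examples (hypothetical inputs):
#   normalize_number(" 1,234 ") -> "1234"
#   normalize_number("3.50")    -> "350"   (acceptable here: GSM8K answers are integers)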
def load_gsm8k() -> DatasetDict:
"""
Loads the GSM8K dataset and preprocesses the answers.
Returns:
DatasetDict: A dictionary containing the processed GSM8K dataset.
"""
dataset = load_dataset("openai/gsm8k", "main")
return dataset.map(
lambda answers: {
"answer": [
int(normalize_number(answer.split("####")[-1]))
for answer in answers
]
},
input_columns=["answer"],
batched=True,
)
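# Illustrative example of the mapping above (made-up GSM8K-style answer field):
#   "She sold 48 + 24 = 72 clips. #### 72"  ->  the "answer" column becomes the int 72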
def prompt_dataset(
dataset: DatasetDict, tokenizer: PreTrainedTokenizer, instruction: str
) -> DatasetDict:
"""
Applies a chat template to each question in the dataset
using the provided tokenizer and instruction.
Args:
dataset (DatasetDict): The dataset containing questions to be processed.
tokenizer (PreTrainedTokenizer): The tokenizer used to apply the chat template.
instruction (str): The instruction to be included in the system role of the chat template.
Returns:
DatasetDict: The dataset with the processed prompts.
"""
return dataset.map(
lambda questions: {
"prompt": tokenizer.apply_chat_template(
[
[
{"role": "system", "content": instruction},
{"role": "user", "content": question},
]
for question in questions
],
tokenize=False,
add_generation_prompt=True,
)
},
input_columns=["question"],
batched=True,
remove_columns=["question"],
)
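# Roughly, for Qwen2.5-*-Instruct the resulting prompt looks like this (a sketch; the
# exact layout is defined by the tokenizer's chat template):
#   <|im_start|>system\n{instruction}<|im_end|>\n
#   <|im_start|>user\n{question}<|im_end|>\n
#   <|im_start|>assistant\n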
def tokenize_dataset(
dataset: DatasetDict, tokenizer: PreTrainedTokenizer, max_prompt_length: int
):
"""
Tokenizes the prompts in the given dataset using the specified tokenizer.
Args:
dataset (DatasetDict): The dataset containing the prompts to be tokenized.
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the prompts.
max_prompt_length (int): The maximum length of the tokenized prompts.
Returns:
DatasetDict: The dataset with tokenized prompts.
"""
return dataset.map(
lambda prompts: tokenizer(
prompts,
truncation=True,
padding=False,
max_length=max_prompt_length,
),
batched=True,
input_columns=["prompt"],
remove_columns=["prompt"],
)
def compute_format_reward(completion: str) -> float:
"""
Computes a reward based on the format of the given completion string.
Args:
completion (str): The completion string to be evaluated.
Returns:
float: The reward value, 1.0 if the pattern is matched, otherwise 0.0.
"""
pattern = r"^<think>(.*?)</think>\s*<answer>.*?</answer>$"
matched = re.findall(pattern, completion)
return 1.0 if matched else 0.0
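# Illustrative examples (hypothetical completions):
#   compute_format_reward("<think>2 + 2 = 4</think> <answer>4</answer>")  -> 1.0
#   compute_format_reward("The answer is 4")                              -> 0.0
# Note: without re.DOTALL, "." does not match newlines, so a completion whose
# <think> block spans several lines gets a format reward of 0.0 as written.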
def compute_accuracy_reward(completion: str, truth: str) -> float:
"""
Computes the accuracy reward based on the provided completion and truth strings.
Args:
completion (str): The completion string containing the answer within <answer> tags.
truth (str): The ground truth string to verify the answer against.
Returns:
float: The accuracy reward, which is 1.0 if the answer matches the truth, otherwise 0.0.
"""
pattern = r"<answer>(.*?)</answer>"
matched = re.findall(pattern, completion)
reward = 0.0
if matched:
parsed_truth = parse(str(truth))
parsed_answer = parse(normalize_number(matched[0]))
if verify(parsed_truth, parsed_answer):
reward = 1.0
return reward
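# Illustrative examples (hypothetical completions):
#   compute_accuracy_reward("<think>6 * 12 = 72</think> <answer>72</answer>", 72) -> 1.0
#   compute_accuracy_reward("<answer>71</answer>", 72)                            -> 0.0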
def compute_rewards(
token_ids: LongTensor, truths: LongTensor, tokenizer: PreTrainedTokenizer
) -> BatchRewards:
"""
Compute rewards for a batch of tokenized completions.
Args:
token_ids (LongTensor): Tensor containing token IDs for each completion.
truths (LongTensor): Tensor containing the ground truth values.
tokenizer (PreTrainedTokenizer): Tokenizer used to decode token IDs into strings.
Returns:
BatchRewards: A dataclass containing format rewards, accuracy rewards,
and total rewards for each completion.
"""
completions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
format_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
accuracy_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
total_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
for idx, (completion, truth) in enumerate(zip(completions, truths)):
format_reward = compute_format_reward(completion)
accuracy_reward = compute_accuracy_reward(completion, truth.item())
format_rewards[idx] = format_reward
accuracy_rewards[idx] = accuracy_reward
total_rewards[idx] = format_reward + accuracy_reward
return BatchRewards(format_rewards, accuracy_rewards, total_rewards)
def get_mask_after_eos(ids: LongTensor, tokenizer: PreTrainedTokenizer):
"""
Generates a mask for a sequence of token IDs, masking all tokens
that appear after the end-of-sequence (eos) token.
Args:
ids (LongTensor): A tensor of token IDs with shape (batch_size, sequence_length).
tokenizer (PreTrainedTokenizer): A tokenizer object that provides the eos_token_id attribute.
Returns:
LongTensor: A tensor of the same shape as `ids` with 1s for tokens up to and including
the eos token, and 0s for tokens after the eos token.
"""
is_eos = ids == tokenizer.eos_token_id
eos_idx = torch.full(
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=ids.device
)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=ids.device).expand(
is_eos.size(0), -1
)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
return completion_mask
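# Toy example (assuming tokenizer.eos_token_id == 2):
#   ids  = [[5, 7, 2, 9],        mask = [[1, 1, 1, 0],
#           [5, 7, 9, 8]]                [1, 1, 1, 1]]
# i.e. everything after the first eos is masked out; rows without an eos keep all tokens.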
def selective_log_softmax(
logits: FloatTensor, index: LongTensor
) -> FloatTensor:
"""
Computes the log softmax of the input logits selectively based on the provided indices.
This function performs the same operation as applying `log_softmax` on the logits tensor
along the last dimension and then gathering the results based on the provided indices.
However, it processes the logits row by row to save memory by leveraging PyTorch internals.
Taken from https://www.tylerromero.com/posts/2025-02-selective-log-softmax/
Args:
logits (FloatTensor): A tensor of shape (batch_size, num_classes) containing the raw
logits for each class.
index (LongTensor): A tensor of shape (batch_size, num_indices) containing the indices
of the classes for which to compute the log softmax.
Returns:
FloatTensor: A tensor of shape (batch_size, num_indices) containing the log softmax
values for the specified indices.
"""
token_logprobs = []
for logits_row, index_row in zip(logits, index):
logprobs_row = logits_row.log_softmax(dim=-1)
token_logprobs_row = torch.gather(
logprobs_row, dim=-1, index=index_row.unsqueeze(-1)
).squeeze(-1)
token_logprobs.append(token_logprobs_row)
return torch.stack(token_logprobs)
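# For reference, the memory-hungrier one-liner this row-by-row loop replaces is:
#   logits.log_softmax(dim=-1).gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)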
def gather_token_scores(
logits: FloatTensor, generated_ids: LongTensor
) -> FloatTensor:
"""
Gathers token scores from logits based on generated token IDs.
Args:
logits (FloatTensor): The logits output from the model. It can be a tuple of tensors or a single tensor.
generated_ids (LongTensor): The IDs of the generated tokens.
Returns:
FloatTensor: The token scores after applying a selective log softmax on the logits.
"""
if isinstance(logits, tuple):
# Stack the logits (batch_size*group_size, output_length, vocab)
logits = torch.stack(logits, dim=0).permute((1, 0, 2))
# Logsoftmax the logits
token_scores = selective_log_softmax(logits, generated_ids)
return token_scores
def compute_token_scores(
model: PreTrainedModel,
all_ids: LongTensor,
all_attention_mask: LongTensor,
generated_ids: LongTensor,
input_length: int,
batch_size: int,
group_size: int,
) -> FloatTensor:
"""
Compute token scores for a given model and input data.
Args:
model (PreTrainedModel): The pre-trained model to use for generating logits.
all_ids (LongTensor): Tensor containing input token IDs.
all_attention_mask (LongTensor): Tensor containing attention masks for the input IDs.
generated_ids (LongTensor): Tensor containing generated token IDs.
input_length (int): The length of the input sequence.
batch_size (int): The size of the batch.
group_size (int): The size of the group.
Returns:
FloatTensor: A tensor containing the computed token scores, reshaped to (batch_size, group_size, -1).
"""
logits = model(input_ids=all_ids, attention_mask=all_attention_mask).logits
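# The logit at position i predicts token i + 1, so the slice below
# ([input_length - 1 : -1]) keeps exactly the positions that score the generated ids.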
logits = logits[:, input_length - 1 : -1]
scores = gather_token_scores(logits, generated_ids)
scores = scores.view(batch_size, group_size, -1)
del logits
torch.cuda.empty_cache()
return scores
def grpo(
generated_ids: LongTensor,
old_scores: FloatTensor,
current_scores: FloatTensor,
reference_scores: FloatTensor,
truths: LongTensor,
tokenizer: PreTrainedTokenizer,
epsilon: float,
beta: float,
) -> GRPOOutput:
"""
Compute the loss of Group Relative Policy Optimization (GRPO) on the given inputs.
Args:
generated_ids (LongTensor): Tensor of generated token IDs.
old_scores (FloatTensor): Tensor of old policy scores.
current_scores (FloatTensor): Tensor of current policy scores.
reference_scores (FloatTensor): Tensor of reference policy scores.
truths (LongTensor): Tensor of ground truth token IDs.
tokenizer (PreTrainedTokenizer): Tokenizer used for encoding/decoding.
epsilon (float): Clipping parameter for policy ratios.
beta (float): Weighting factor for the Kullback-Leibler divergence term.
Returns:
GRPOOutput: A dataclass containing the mean loss, format reward, accuracy reward, total reward, and KL divergence.
"""
losses = torch.zeros(generated_ids.shape[0])
kls = torch.zeros(generated_ids.shape[0])
format_rewards = torch.zeros(generated_ids.shape[0])
accuracy_rewards = torch.zeros(generated_ids.shape[0])
total_rewards = torch.zeros(generated_ids.shape[0])
for idx, (
group_ids,
group_truths,
group_old_scores,
group_current_scores,
group_reference_scores,
) in enumerate(
zip(generated_ids, truths, old_scores, current_scores, reference_scores)
):
# Compute advantages
group_rewards = compute_rewards(group_ids, group_truths, tokenizer)
mean = group_rewards.total_rewards.mean()
centered = group_rewards.total_rewards - mean
std = group_rewards.total_rewards.std()
if std < 1e-8:
advantages = torch.zeros_like(centered)
else:
advantages = centered / (std + 1e-8)
# Store the mean of each reward type for the group
format_rewards[idx] = group_rewards.format_rewards.mean()
accuracy_rewards[idx] = group_rewards.accuracy_rewards.mean()
total_rewards[idx] = group_rewards.total_rewards.mean()
# Compute the ratios
ratios = torch.exp(group_current_scores - group_old_scores)
# Compute the clipped ratios
clipped_ratios = torch.clamp(
ratios, min=1.0 - epsilon, max=1.0 + epsilon
)
# Compute kullback-leibler divergence between reference and current policy
kl = (
torch.exp(group_reference_scores - group_current_scores)
- (group_reference_scores - group_current_scores)
- 1
)
kls[idx] = kl.mean()
# Compute mean loss of the group
completion_mask = get_mask_after_eos(group_ids, tokenizer)
loss = (
torch.min(
ratios * advantages.unsqueeze(-1),
clipped_ratios * advantages.unsqueeze(-1),
)
- beta * kl
)
loss = -(loss * completion_mask).sum() / completion_mask.sum()
losses[idx] = loss
return GRPOOutput(
loss=losses.mean(),
format_reward=format_rewards.mean(),
accuracy_reward=accuracy_rewards.mean(),
total_reward=total_rewards.mean(),
kl=kls.mean(),
)
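# Worked advantage example (illustrative rewards for one group of 4 completions):
#   total_rewards = [2.0, 1.0, 0.0, 1.0]  ->  mean = 1.0, std ≈ 0.816
#   advantages    ≈ [1.22, 0.0, -1.22, 0.0]
# so completions that beat the group average are reinforced and the rest are penalized.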
def train(
dataset: Dataset,
model: PeftModel,
tokenizer: PreTrainedTokenizer,
training_args: TrainingArguments,
) -> None:
"""
Train a language model using the GRPO (Group Relative Policy Optimization) objective.
Args:
dataset (Dataset): The dataset to be used for training.
model (PeftModel): The model to be trained.
tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
training_args (TrainingArguments): The training arguments containing hyperparameters and configurations.
"""
# Prepare the dataloader
train_dataloader = DataLoader(
dataset,
collate_fn=DataCollatorWithPadding(tokenizer),
batch_size=training_args.batch_size,
)
# Prepare the old policy
old_model = deepcopy(model)
old_model.eval()
# Prepare the metrics
running_metrics = {
"loss": 0.0,
"format_reward": 0.0,
"accuracy_reward": 0.0,
"total_reward": 0.0,
"completion_length": 0.0,
"kl": 0.0,
}
# Prepare optimizer and lr scheduler
optimizer = AdamW(
model.parameters(),
lr=training_args.learning_rate,
)
scheduler = LinearLR(
optimizer,
start_factor=1,
end_factor=0.1,
total_iters=training_args.epochs * len(train_dataloader),
)
# Let's train!
training_step = 0
for _ in range(training_args.epochs):
# Update the old policy
old_model.load_state_dict(model.state_dict(), strict=False)
for batch in tqdm(train_dataloader, total=len(train_dataloader)):
# Prepare the batch data
input_ids = batch["input_ids"].to(model.device)
attention_mask = batch["attention_mask"].to(model.device)
truths = batch["answer"].to(model.device)
effective_batch_size, input_length = (
input_ids.shape[0],
input_ids.shape[1],
)
# Generate with the old policy
all_ids = old_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=training_args.max_new_tokens,
do_sample=True,
num_return_sequences=training_args.group_size,
temperature=training_args.temperature,
)
# Pick only the completion ids (batch_size*group_size, output_length)
generated_ids = all_ids[:, input_length:]
# Prepare attention mask for computing current
# and reference logits on the generated ids
# (batch_size * group_size, input_length + output_length)
completion_mask = get_mask_after_eos(generated_ids, tokenizer)
all_attention_mask = torch.hstack(
(
attention_mask.repeat_interleave(
training_args.group_size, dim=0
),
completion_mask,
)
)
# Compute the token scores of the old policy
with torch.inference_mode(), torch.autocast(
"cuda", dtype=torch.bfloat16
):
old_scores = compute_token_scores(
old_model,
all_ids,
all_attention_mask,
generated_ids,
input_length,
effective_batch_size,
training_args.group_size,
)
# Compute the token scores of the current policy
with torch.autocast("cuda", dtype=torch.bfloat16):
model.eval()
current_scores = compute_token_scores(
model,
all_ids,
all_attention_mask,
generated_ids,
input_length,
effective_batch_size,
training_args.group_size,
)
model.train()
# Compute the token scores of the reference model.
# Note that the reference model is the current policy
# with the LoRA adapters disabled, since the only
# trainable parameters are the LoRA ones
with torch.inference_mode(), model.disable_adapter(), torch.autocast(
"cuda", dtype=torch.bfloat16
):
model.eval()
reference_scores = compute_token_scores(
model,
all_ids,
all_attention_mask,
generated_ids,
input_length,
effective_batch_size,
training_args.group_size,
)
model.train()
# Group the generated ids (batch_size, group_size, output_length)
generated_ids = generated_ids.view(
effective_batch_size, training_args.group_size, -1
)
# Repeat the truths and group (batch_size, group_size)
truths = truths.repeat_interleave(training_args.group_size).view(
effective_batch_size, training_args.group_size
)
# Compute GRPO objective
with torch.autocast("cuda", dtype=torch.bfloat16):
grpo_output = grpo(
generated_ids,
old_scores,
current_scores,
reference_scores,
truths,
tokenizer,
training_args.grpo_epsilon,
training_args.grpo_beta,
)
# Update the current policy
grpo_output.loss.backward()
clip_grad_norm_(
model.parameters(),
training_args.gradient_max_norm,
)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
# Update old policy periodically
if (training_step + 1) % training_args.update_old_after == 0:
old_model.load_state_dict(model.state_dict(), strict=False)
torch.cuda.empty_cache()
# Update log metrics
batch_metrics = {
"loss": grpo_output.loss.item(),
"format_reward": grpo_output.format_reward.item(),
"accuracy_reward": grpo_output.accuracy_reward.item(),
"total_reward": grpo_output.total_reward.item(),
"kl": grpo_output.kl.item(),
"completion_length": completion_mask.sum(-1)
.float()
.mean()
.item(),
}
running_metrics = {
key: running_metrics[key] + batch_metrics.get(key, 0)
for key in running_metrics
}
# And report them periodically
if (training_step + 1) % training_args.logging_steps == 0:
wandb.log(
{
**{
key: val / (training_step + 1)
for key, val in running_metrics.items()
},
**{"lr": scheduler.get_last_lr()[0]},
}
)
# Save the model periodically
if (training_step + 1) % training_args.save_steps == 0:
model.save_pretrained(
f"{training_args.save_dir}_step-{training_step}"
)
# Free GPU memory at the end
del (
all_ids,
all_attention_mask,
generated_ids,
old_scores,
current_scores,
reference_scores,
grpo_output,
truths,
)
torch.cuda.empty_cache()
gc.collect()
training_step += 1
def main():
"""
Main function for training a LoRA-adapted decoder model with
GRPO to enhance reasoning skills on the GSM8K dataset.
"""
# Instantiate current policy and reference model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "down_proj", "up_proj"],
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
)
model = get_lora_model(load_model(model_name), lora_config)
tokenizer = load_tokenizer(model_name)
# Define training arguments
training_args = TrainingArguments(
epochs=1,
batch_size=4,
learning_rate=1e-5,
update_old_after=3000,
group_size=4,
logging_steps=10,
max_new_tokens=512,
prompt_max_length=256,
temperature=1.0,
grpo_epsilon=0.1,
grpo_beta=0.04,
gradient_max_norm=0.2,
save_steps=100,
save_dir="./grpo_qwen-2.5-1.5b_gsm8k",
)
# Load and format the dataset
instruction = (
"A conversation between User and Assistant. The user asks a question, "
"and the Assistant solves it. The assistant first thinks about the "
"reasoning process in the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> and "
"<answer> </answer> tags, respectively, i.e., <think> reasoning process here "
"</think> <answer> answer here </answer>."
)
dataset = load_gsm8k()
dataset = prompt_dataset(dataset, tokenizer, instruction)
dataset = tokenize_dataset(
dataset, tokenizer, training_args.prompt_max_length
)
# Initialize wandb
wandb.login()
wandb.init(
project="GRPO-LLM-Reasoning",
config={
"model": model_name,
"dataset": "openai/gsm8k",
"lora_config": lora_config.__dict__,
"training_args": training_args.__dict__,
},
)
# Let's train!
train(dataset["train"], model, tokenizer, training_args)
# Save the model and finish logging
model.save_pretrained(f"grpo_{model_name.replace('/', '_')}_gsm8k")
wandb.finish()
if __name__ == "__main__":
main()
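To try it out (a sketch, assuming a single CUDA GPU, a wandb account, and that torch, transformers, datasets, peft, bitsandbytes, flash-attn, math-verify, tqdm and wandb, or equivalents, are installed; the filename is hypothetical): save the script as grpo_gsm8k.py and run "python grpo_gsm8k.py".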
jogonba2 commented Feb 23, 2025

"Aha moment" found after 146 steps:

Question: If one woman has a son in 9 months, how many sons have 9 women in one month?

Llama-3.2-1B-instruct: <think> In 9 months, there will be 9 * 12 = 108 women. Since each woman can have only one son, there will be 108 sons.</think><answer> The answer is 108. </answer>

GRPO-Llama-3.2-1B-instruct: <think> Since one woman has a son in 9 months, it means that in 9 months, she will have 9 sons. However, the question asks how many sons have 9 women in one month. This is a bit confusing, as it seems like the question is asking about the number of sons, but the answer is actually about the number of women. Since there are 9 women, and each woman gives birth to a son, the answer is 9 sons. But wait, the question asks about sons, not women. So, the correct answer is 9 sons. However, if we interpret the question as asking about the number of women, then the answer is 9 women. But if we interpret the question as asking about the number of sons, then the answer is 9 sons. Since the question is ambiguous, I will provide two possible answers.</think><answer>9 sons 9 women</answer>
