import gc
import re
from copy import deepcopy
from dataclasses import dataclass

import torch
from datasets import Dataset, DatasetDict, load_dataset
from math_verify import parse, verify
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from torch import FloatTensor, LongTensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    PreTrainedModel,
    PreTrainedTokenizer,
)

import wandb


@dataclass
class TrainingArguments:
    epochs: int
    batch_size: int
    learning_rate: float
    update_old_after: int
    group_size: int
    logging_steps: int
    max_new_tokens: int
    prompt_max_length: int
    temperature: float
    grpo_epsilon: float
    grpo_beta: float
    gradient_max_norm: float
    save_steps: int
    save_dir: str
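# Annotation (added for clarity, not in the original gist): `group_size` is the number
# of completions sampled per prompt (the "G" in GRPO), and `update_old_after` is the
# number of training steps between refreshes of the frozen "old" policy that the
# importance ratios are computed against.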


@dataclass
class BatchRewards:
    format_rewards: FloatTensor
    accuracy_rewards: FloatTensor
    total_rewards: FloatTensor


@dataclass
class GRPOOutput:
    loss: FloatTensor
    format_reward: FloatTensor
    accuracy_reward: FloatTensor
    total_reward: FloatTensor
    kl: FloatTensor


def load_model(model_name: str) -> PreTrainedModel:
    """
    Loads a pre-trained language model with quantization and attention configs.

    Args:
        model_name (str): The name or path of the pre-trained model to load.

    Returns:
        PreTrainedModel: The loaded pre-trained model.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    model.gradient_checkpointing_enable()
    return model


def get_lora_model(
    model: PreTrainedModel, lora_config: LoraConfig
) -> PeftModel:
    """
    Prepares and returns a LoRA model for training.

    Args:
        model (PreTrainedModel): The pre-trained model to be adapted with LoRA.
        lora_config (LoraConfig): The configuration for the LoRA adaptation.

    Returns:
        PeftModel: The model adapted with LoRA, ready for training.
    """
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def load_tokenizer(model_name: str) -> PreTrainedTokenizer:
    """
    Loads the tokenizer of a pre-trained model.

    Args:
        model_name (str): The name or path of the pre-trained model.

    Returns:
        PreTrainedTokenizer: The tokenizer associated with the specified model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def normalize_number(x: str) -> str:
    """
    Normalize a numeric string by removing commas and periods,
    and stripping any leading or trailing whitespace.
    Dirty trick to simplify the task for math_verify.

    Args:
        x (str): The numeric string to normalize.

    Returns:
        str: The normalized numeric string.
    """
    return x.strip().replace(",", "").replace(".", "")


def load_gsm8k() -> DatasetDict:
    """
    Loads the GSM8K dataset and preprocesses the answers.

    Returns:
        DatasetDict: A dictionary containing the processed GSM8K dataset.
    """
dataset = load_dataset("openai/gsm8k", "main") | |
return dataset.map( | |
lambda answers: { | |
"answer": [ | |
int(normalize_number(answer.split("####")[-1])) | |
for answer in answers | |
] | |
}, | |
input_columns=["answer"], | |
batched=True, | |
) | |


def prompt_dataset(
    dataset: DatasetDict, tokenizer: PreTrainedTokenizer, instruction: str
) -> DatasetDict:
    """
    Applies a chat template to each question in the dataset
    using the provided tokenizer and instruction.

    Args:
        dataset (DatasetDict): The dataset containing questions to be processed.
        tokenizer (PreTrainedTokenizer): The tokenizer used to apply the chat template.
        instruction (str): The instruction to be included in the system role of the chat template.

    Returns:
        DatasetDict: The dataset with the processed prompts.
    """
    return dataset.map(
        lambda questions: {
            "prompt": tokenizer.apply_chat_template(
                [
                    [
                        {"role": "system", "content": instruction},
                        {"role": "user", "content": question},
                    ]
                    for question in questions
                ],
                tokenize=False,
                add_generation_prompt=True,
            )
        },
        input_columns=["question"],
        batched=True,
        remove_columns=["question"],
    )


def tokenize_dataset(
    dataset: DatasetDict, tokenizer: PreTrainedTokenizer, max_prompt_length: int
):
    """
    Tokenizes the prompts in the given dataset using the specified tokenizer.

    Args:
        dataset (DatasetDict): The dataset containing the prompts to be tokenized.
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the prompts.
        max_prompt_length (int): The maximum length of the tokenized prompts.

    Returns:
        DatasetDict: The dataset with tokenized prompts.
    """
    return dataset.map(
        lambda prompts: tokenizer(
            prompts,
            truncation=True,
            padding=False,
            max_length=max_prompt_length,
        ),
        batched=True,
        input_columns=["prompt"],
        remove_columns=["prompt"],
    )


def compute_format_reward(completion: str) -> float:
    """
    Computes a reward based on the format of the given completion string.

    Args:
        completion (str): The completion string to be evaluated.

    Returns:
        float: The reward value, 1.0 if the pattern is matched, otherwise 0.0.
    """
pattern = r"^<think>(.*?)</think>\s*<answer>.*?</answer>$" | |
matched = re.findall(pattern, completion) | |
return 1.0 if matched else 0.0 | |


def compute_accuracy_reward(completion: str, truth: str) -> float:
    """
    Computes the accuracy reward based on the provided completion and truth strings.

    Args:
        completion (str): The completion string containing the answer within <answer> tags.
        truth (str): The ground truth string to verify the answer against.

    Returns:
        float: The accuracy reward, which is 1.0 if the answer matches the truth, otherwise 0.0.
    """
pattern = r"<answer>(.*?)</answer>" | |
matched = re.findall(pattern, completion) | |
reward = 0.0 | |
if matched: | |
parsed_truth = parse(str(truth)) | |
parsed_answer = parse(normalize_number(matched[0])) | |
if verify(parsed_truth, parsed_answer): | |
reward = 1.0 | |
return reward | |


def compute_rewards(
    token_ids: LongTensor, truths: LongTensor, tokenizer: PreTrainedTokenizer
) -> BatchRewards:
    """
    Compute rewards for a batch of tokenized completions.

    Args:
        token_ids (LongTensor): Tensor containing token IDs for each completion.
        truths (LongTensor): Tensor containing the ground truth values.
        tokenizer (PreTrainedTokenizer): Tokenizer used to decode token IDs into strings.

    Returns:
        BatchRewards: A dataclass containing format rewards, accuracy rewards,
        and total rewards for each completion.
    """
    completions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    format_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    accuracy_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    total_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    for idx, (completion, truth) in enumerate(zip(completions, truths)):
        format_reward = compute_format_reward(completion)
        accuracy_reward = compute_accuracy_reward(completion, truth.item())
        format_rewards[idx] = format_reward
        accuracy_rewards[idx] = accuracy_reward
        total_rewards[idx] = format_reward + accuracy_reward
    return BatchRewards(format_rewards, accuracy_rewards, total_rewards)


def get_mask_after_eos(ids: LongTensor, tokenizer: PreTrainedTokenizer):
    """
    Generates a mask for a sequence of token IDs, masking all tokens
    that appear after the end-of-sequence (eos) token.

    Args:
        ids (LongTensor): A tensor of token IDs with shape (batch_size, sequence_length).
        tokenizer (PreTrainedTokenizer): A tokenizer object that provides the eos_token_id attribute.

    Returns:
        LongTensor: A tensor of the same shape as `ids` with 1s for tokens up to and including
        the eos token, and 0s for tokens after the eos token.
    """
    is_eos = ids == tokenizer.eos_token_id
    eos_idx = torch.full(
        (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=ids.device
    )
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=ids.device).expand(
        is_eos.size(0), -1
    )
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    return completion_mask


def selective_log_softmax(
    logits: FloatTensor, index: LongTensor
) -> FloatTensor:
    """
    Computes the log softmax of the input logits selectively based on the provided indices.

    This function performs the same operation as applying `log_softmax` on the logits tensor
    along the last dimension and then gathering the results based on the provided indices.
    However, it processes the logits row by row to save memory by leveraging PyTorch internals.
    Taken from https://www.tylerromero.com/posts/2025-02-selective-log-softmax/

    Args:
        logits (FloatTensor): A tensor of shape (batch_size, num_classes) containing the raw
            logits for each class.
        index (LongTensor): A tensor of shape (batch_size, num_indices) containing the indices
            of the classes for which to compute the log softmax.

    Returns:
        FloatTensor: A tensor of shape (batch_size, num_indices) containing the log softmax
        values for the specified indices.
    """
    token_logprobs = []
    for logits_row, index_row in zip(logits, index):
        logprobs_row = logits_row.log_softmax(dim=-1)
        token_logprobs_row = torch.gather(
            logprobs_row, dim=-1, index=index_row.unsqueeze(-1)
        ).squeeze(-1)
        token_logprobs.append(token_logprobs_row)
    return torch.stack(token_logprobs)


def gather_token_scores(
    logits: FloatTensor, generated_ids: LongTensor
) -> FloatTensor:
    """
    Gathers token scores from logits based on generated token IDs.

    Args:
        logits (FloatTensor): The logits output from the model. It can be a tuple of tensors or a single tensor.
        generated_ids (LongTensor): The IDs of the generated tokens.

    Returns:
        FloatTensor: The token scores after applying a selective log softmax on the logits.
    """
    if isinstance(logits, tuple):
        # Stack the logits (batch_size*group_size, output_length, vocab)
        logits = torch.stack(logits, axis=0).permute((1, 0, 2))
    # Log-softmax the logits and keep the scores of the generated tokens
    token_scores = selective_log_softmax(logits, generated_ids)
    return token_scores


def compute_token_scores(
    model: PreTrainedModel,
    all_ids: LongTensor,
    all_attention_mask: LongTensor,
    generated_ids: LongTensor,
    input_length: int,
    batch_size: int,
    group_size: int,
) -> FloatTensor:
    """
    Compute token scores for a given model and input data.

    Args:
        model (PreTrainedModel): The pre-trained model to use for generating logits.
        all_ids (LongTensor): Tensor containing prompt and completion token IDs.
        all_attention_mask (LongTensor): Tensor containing attention masks for the input IDs.
        generated_ids (LongTensor): Tensor containing generated token IDs.
        input_length (int): The length of the (padded) prompt.
        batch_size (int): The number of prompts in the batch.
        group_size (int): The number of completions generated per prompt.

    Returns:
        FloatTensor: A tensor containing the computed token scores, reshaped to (batch_size, group_size, -1).
    """
    logits = model(input_ids=all_ids, attention_mask=all_attention_mask).logits
    logits = logits[:, input_length - 1 : -1]
    scores = gather_token_scores(logits, generated_ids)
    scores = scores.view(batch_size, group_size, -1)
    del logits
    torch.cuda.empty_cache()
    return scores


def grpo(
    generated_ids: LongTensor,
    old_scores: FloatTensor,
    current_scores: FloatTensor,
    reference_scores: FloatTensor,
    truths: LongTensor,
    tokenizer: PreTrainedTokenizer,
    epsilon: float,
    beta: float,
) -> GRPOOutput:
    """
    Compute the loss of Group Relative Policy Optimization (GRPO) on the given inputs.

    Args:
        generated_ids (LongTensor): Tensor of generated token IDs.
        old_scores (FloatTensor): Tensor of old policy scores.
        current_scores (FloatTensor): Tensor of current policy scores.
        reference_scores (FloatTensor): Tensor of reference policy scores.
        truths (LongTensor): Tensor of ground-truth numeric answers.
        tokenizer (PreTrainedTokenizer): Tokenizer used for encoding/decoding.
        epsilon (float): Clipping parameter for policy ratios.
        beta (float): Weighting factor for the Kullback-Leibler divergence term.

    Returns:
        GRPOOutput: A dataclass containing the mean loss, format reward, accuracy reward, total reward, and KL divergence.
    """
    losses = torch.zeros(generated_ids.shape[0])
    kls = torch.zeros(generated_ids.shape[0])
    format_rewards = torch.zeros(generated_ids.shape[0])
    accuracy_rewards = torch.zeros(generated_ids.shape[0])
    total_rewards = torch.zeros(generated_ids.shape[0])
    for idx, (
        group_ids,
        group_truths,
        group_old_scores,
        group_current_scores,
        group_reference_scores,
    ) in enumerate(
        zip(generated_ids, truths, old_scores, current_scores, reference_scores)
    ):
        # Compute advantages
        group_rewards = compute_rewards(group_ids, group_truths, tokenizer)
        mean = group_rewards.total_rewards.mean()
        centered = group_rewards.total_rewards - mean
        std = group_rewards.total_rewards.std()
        if std < 1e-8:
            advantages = torch.zeros_like(centered)
        else:
            advantages = centered / (std + 1e-8)
        # Store the mean of each reward for the group
        format_rewards[idx] = group_rewards.format_rewards.mean()
        accuracy_rewards[idx] = group_rewards.accuracy_rewards.mean()
        total_rewards[idx] = group_rewards.total_rewards.mean()
        # Compute the ratios
        ratios = torch.exp(group_current_scores - group_old_scores)
        # Compute the clipped ratios
        clipped_ratios = torch.clamp(
            ratios, min=1.0 - epsilon, max=1.0 + epsilon
        )
        # Compute the Kullback-Leibler divergence between the reference and current policies
        kl = (
            torch.exp(group_reference_scores - group_current_scores)
            - (group_reference_scores - group_current_scores)
            - 1
        )
        kls[idx] = kl.mean()
        # Compute the mean loss of the group, masking tokens after the eos
        completion_mask = get_mask_after_eos(group_ids, tokenizer)
        loss = (
            torch.min(
                ratios * advantages.unsqueeze(-1),
                clipped_ratios * advantages.unsqueeze(-1),
            )
            - beta * kl
        )
        loss = -(loss * completion_mask).sum() / completion_mask.sum()
        losses[idx] = loss
    return GRPOOutput(
        loss=losses.mean(),
        format_reward=format_rewards.mean(),
        accuracy_reward=accuracy_rewards.mean(),
        total_reward=total_rewards.mean(),
        kl=kls.mean(),
    )


def train(
    dataset: Dataset,
    model: PeftModel,
    tokenizer: PreTrainedTokenizer,
    training_args: TrainingArguments,
) -> None:
    """
    Train a language model using the GRPO (Group Relative Policy Optimization) objective.

    Args:
        dataset (Dataset): The dataset to be used for training.
        model (PeftModel): The model to be trained.
        tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
        training_args (TrainingArguments): The training arguments containing hyperparameters and configurations.
    """
    # Prepare the dataloader
    train_dataloader = DataLoader(
        dataset,
        collate_fn=DataCollatorWithPadding(tokenizer),
        batch_size=training_args.batch_size,
    )
    # Prepare the old policy
    old_model = deepcopy(model)
    old_model.eval()
    # Prepare the metrics
    running_metrics = {
        "loss": 0.0,
        "format_reward": 0.0,
        "accuracy_reward": 0.0,
        "total_reward": 0.0,
        "completion_length": 0.0,
        "kl": 0.0,
    }
    # Prepare optimizer and lr scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=training_args.learning_rate,
    )
    scheduler = LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.1,
        total_iters=training_args.epochs * len(train_dataloader),
    )
    # Let's train!
    training_step = 0
    for _ in range(training_args.epochs):
        # Update the old policy
        old_model.load_state_dict(model.state_dict(), strict=False)
        for batch in tqdm(train_dataloader, total=len(train_dataloader)):
            # Prepare the batch data
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            truths = batch["answer"].to(model.device)
            effective_batch_size, input_length = (
                input_ids.shape[0],
                input_ids.shape[1],
            )
            # Generate with the old policy
            all_ids = old_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=training_args.max_new_tokens,
                do_sample=True,
                num_return_sequences=training_args.group_size,
                temperature=training_args.temperature,
            )
            # Pick only the completion ids (batch_size*group_size, output_length)
            generated_ids = all_ids[:, input_length:]
            # Prepare the attention mask for computing current
            # and reference logits on the generated ids
            # (batch_size * group_size, input_length + output_length)
            completion_mask = get_mask_after_eos(generated_ids, tokenizer)
            all_attention_mask = torch.hstack(
                (
                    attention_mask.repeat_interleave(
                        training_args.group_size, dim=0
                    ),
                    completion_mask,
                )
            )
            # Compute the token scores of the old policy
            with torch.inference_mode(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                old_scores = compute_token_scores(
                    old_model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
            # Compute the token scores of the current policy
            with torch.autocast("cuda", dtype=torch.bfloat16):
                model.eval()
                current_scores = compute_token_scores(
                    model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
                model.train()
            # Compute the token scores of the reference model.
            # Note that the reference model is the current policy
            # with LoRA disabled, since the only trainable parameters
            # are those from LoRA
            with torch.inference_mode(), model.disable_adapter(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                model.eval()
                reference_scores = compute_token_scores(
                    model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
                model.train()
            # Group the generated ids (batch_size, group_size, output_length)
            generated_ids = generated_ids.view(
                effective_batch_size, training_args.group_size, -1
            )
            # Repeat the truths and group them (batch_size, group_size)
            truths = truths.repeat_interleave(training_args.group_size).view(
                effective_batch_size, training_args.group_size
            )
            # Compute the GRPO objective
            with torch.autocast("cuda", dtype=torch.bfloat16):
                grpo_output = grpo(
                    generated_ids,
                    old_scores,
                    current_scores,
                    reference_scores,
                    truths,
                    tokenizer,
                    training_args.grpo_epsilon,
                    training_args.grpo_beta,
                )
            # Update the current policy
            grpo_output.loss.backward()
            clip_grad_norm_(
                model.parameters(),
                training_args.gradient_max_norm,
            )
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            # Update the old policy periodically
            if (training_step + 1) % training_args.update_old_after == 0:
                old_model.load_state_dict(model.state_dict(), strict=False)
                torch.cuda.empty_cache()
            # Update the logged metrics
            batch_metrics = {
                "loss": grpo_output.loss.item(),
                "format_reward": grpo_output.format_reward.item(),
                "accuracy_reward": grpo_output.accuracy_reward.item(),
                "total_reward": grpo_output.total_reward.item(),
                "kl": grpo_output.kl.item(),
                "completion_length": completion_mask.sum(-1)
                .float()
                .mean()
                .item(),
            }
            running_metrics = {
                key: running_metrics[key] + batch_metrics.get(key, 0)
                for key in running_metrics
            }
            # And report them periodically
            if (training_step + 1) % training_args.logging_steps == 0:
                wandb.log(
                    {
                        **{
                            key: val / (training_step + 1)
                            for key, val in running_metrics.items()
                        },
                        **{"lr": scheduler.get_last_lr()[0]},
                    }
                )
            # Save the model periodically
            if (training_step + 1) % training_args.save_steps == 0:
                model.save_pretrained(
                    f"{training_args.save_dir}_step-{training_step}"
                )
            # Free GPU memory at the end of the step
            del (
                all_ids,
                all_attention_mask,
                generated_ids,
                old_scores,
                current_scores,
                reference_scores,
                grpo_output,
                truths,
            )
            torch.cuda.empty_cache()
            gc.collect()
            training_step += 1


def main():
    """
    Main function for training a LoRA-adapted decoder model with
    GRPO to enhance reasoning skills on the GSM8K dataset.
    """
    # Instantiate the current policy (the reference policy is this
    # same model with its LoRA adapters disabled)
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "down_proj", "up_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_lora_model(load_model(model_name), lora_config)
    tokenizer = load_tokenizer(model_name)
    # Define training arguments
    training_args = TrainingArguments(
        epochs=1,
        batch_size=4,
        learning_rate=1e-5,
        update_old_after=3000,
        group_size=4,
        logging_steps=10,
        max_new_tokens=512,
        prompt_max_length=256,
        temperature=1.0,
        grpo_epsilon=0.1,
        grpo_beta=0.04,
        gradient_max_norm=0.2,
        save_steps=100,
        save_dir="./grpo_qwen-2.5-1.5b_gsm8k",
    )
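    # With batch_size=4 and group_size=4, every optimizer step scores 16 completions.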
    # Load and format the dataset
    instruction = (
        "A conversation between User and Assistant. The user asks a question, "
        "and the Assistant solves it. The assistant first thinks about the "
        "reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think> </think> and "
        "<answer> </answer> tags, respectively, i.e., <think> reasoning process here "
        "</think> <answer> answer here </answer>."
    )
    dataset = load_gsm8k()
    dataset = prompt_dataset(dataset, tokenizer, instruction)
    dataset = tokenize_dataset(
        dataset, tokenizer, training_args.prompt_max_length
    )
    # Initialize wandb
    wandb.login()
    wandb.init(
        project="GRPO-LLM-Reasoning",
        config={
            "model": model_name,
            "dataset": "openai/gsm8k",
            "lora_config": lora_config.__dict__,
            "training_args": training_args.__dict__,
        },
    )
    # Let's train!
    train(dataset["train"], model, tokenizer, training_args)
    # Save the model and finish logging
    model.save_pretrained(f"grpo_{model_name.replace('/', '_')}_gsm8k")
    wandb.finish()


if __name__ == "__main__":
    main()

jogonba2 commented on Feb 23, 2025:
"Aha moment" found after 146 steps:
Question: If one woman has a son in 9 months, how many sons have 9 women in one month?
Llama-3.2-1B-instruct: <think> In 9 months, there will be 9 * 12 = 108 women. Since each woman can have only one son, there will be 108 sons.</think><answer> The answer is 108. </answer>
GRPO-Llama-3.2-1B-instruct: <think> Since one woman has a son in 9 months, it means that in 9 months, she will have 9 sons. However, the question asks how many sons have 9 women in one month. This is a bit confusing, as it seems like the question is asking about the number of sons, but the answer is actually about the number of women. Since there are 9 women, and each woman gives birth to a son, the answer is 9 sons. But wait, the question asks about sons, not women. So, the correct answer is 9 sons. However, if we interpret the question as asking about the number of women, then the answer is 9 women. But if we interpret the question as asking about the number of sons, then the answer is 9 sons. Since the question is ambiguous, I will provide two possible answers.</think><answer>9 sons 9 women</answer>