Group Relative Policy Optimization from DeepSeek-R1-Zero, just for fun
import gc
import re
from copy import deepcopy
from dataclasses import dataclass

import torch
from datasets import Dataset, DatasetDict, load_dataset
from math_verify import parse, verify
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from torch import FloatTensor, LongTensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    PreTrainedModel,
    PreTrainedTokenizer,
)

import wandb
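# Dependency note (added for convenience, not part of the original script): judging by the
# imports above, this assumes a CUDA GPU plus roughly the following packages; names are a
# best guess from the imports and versions are not pinned here.
#
#   pip install torch transformers datasets peft bitsandbytes math-verify wandb tqdm
#   pip install flash-attn  # needed for attn_implementation="flash_attention_2" below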
@dataclass
class TrainingArguments:
    epochs: int
    batch_size: int
    learning_rate: float
    update_old_after: int
    group_size: int
    logging_steps: int
    max_new_tokens: int
    prompt_max_length: int
    temperature: float
    grpo_epsilon: float
    grpo_beta: float
    gradient_max_norm: float
    save_steps: int
    save_dir: str


@dataclass
class BatchRewards:
    format_rewards: FloatTensor
    accuracy_rewards: FloatTensor
    total_rewards: FloatTensor


@dataclass
class GRPOOutput:
    loss: FloatTensor
    format_reward: FloatTensor
    accuracy_reward: FloatTensor
    total_reward: FloatTensor
    kl: FloatTensor
def load_model(model_name: str) -> PreTrainedModel:
    """
    Loads a pre-trained language model with quantization and attention configs.

    Args:
        model_name (str): The name or path of the pre-trained model to load.

    Returns:
        PreTrainedModel: The loaded pre-trained model.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    model.gradient_checkpointing_enable()
    return model


def get_lora_model(
    model: PreTrainedModel, lora_config: LoraConfig
) -> PeftModel:
    """
    Prepares and returns a LoRA model for training.

    Args:
        model (PreTrainedModel): The pre-trained model to be adapted with LoRA.
        lora_config (LoraConfig): The configuration for the LoRA adaptation.

    Returns:
        PeftModel: The model adapted with LoRA, ready for training.
    """
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def load_tokenizer(model_name: str) -> PreTrainedTokenizer:
    """
    Loads the tokenizer of a pre-trained model.

    Args:
        model_name (str): The name or path of the pre-trained model.

    Returns:
        PreTrainedTokenizer: The tokenizer associated with the specified model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def normalize_number(x: str) -> str:
    """
    Normalize a numeric string by removing commas and periods,
    and stripping any leading or trailing whitespace.
    Dirty trick to simplify the task for math_verify.

    Args:
        x (str): The numeric string to normalize.

    Returns:
        str: The normalized numeric string.
    """
    return x.strip().replace(",", "").replace(".", "")
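# Illustrative sketch (not in the original script): normalize_number strips thousands
# separators, which is what GSM8K answers need, but it would also mangle decimals,
# e.g. "3.5" -> "35". That is fine here because GSM8K final answers are integers.
#
#   >>> normalize_number(" 1,234 ")
#   '1234'
#   >>> normalize_number("3.5")  # caveat: only safe for integer answers
#   '35'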
def load_gsm8k() -> DatasetDict:
    """
    Loads the GSM8K dataset and preprocesses the answers.

    Returns:
        DatasetDict: A dictionary containing the processed GSM8K dataset.
    """
    dataset = load_dataset("openai/gsm8k", "main")
    return dataset.map(
        lambda answers: {
            "answer": [
                int(normalize_number(answer.split("####")[-1]))
                for answer in answers
            ]
        },
        input_columns=["answer"],
        batched=True,
    )
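# Worked example (assuming the usual GSM8K formatting): each raw answer ends with a
# line like "#### 72", so the map above keeps only the final number as an int:
#
#   >>> int(normalize_number("Natalia sold 48 + 24 = 72 clips. #### 72".split("####")[-1]))
#   72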
def prompt_dataset(
    dataset: DatasetDict, tokenizer: PreTrainedTokenizer, instruction: str
) -> DatasetDict:
    """
    Applies a chat template to each question in the dataset
    using the provided tokenizer and instruction.

    Args:
        dataset (DatasetDict): The dataset containing questions to be processed.
        tokenizer (PreTrainedTokenizer): The tokenizer used to apply the chat template.
        instruction (str): The instruction to be included in the system role of the chat template.

    Returns:
        DatasetDict: The dataset with the processed prompts.
    """
    return dataset.map(
        lambda questions: {
            "prompt": tokenizer.apply_chat_template(
                [
                    [
                        {"role": "system", "content": instruction},
                        {"role": "user", "content": question},
                    ]
                    for question in questions
                ],
                tokenize=False,
                add_generation_prompt=True,
            )
        },
        input_columns=["question"],
        batched=True,
        remove_columns=["question"],
    )
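# For reference (the exact string depends on the tokenizer's chat template):
# Qwen2.5-Instruct uses a ChatML-style template, so each prompt produced above
# should look roughly like
#
#   <|im_start|>system
#   {instruction}<|im_end|>
#   <|im_start|>user
#   {question}<|im_end|>
#   <|im_start|>assistant
#
# add_generation_prompt=True appends the final assistant header so generation
# starts inside the assistant turn.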
def tokenize_dataset(
    dataset: DatasetDict, tokenizer: PreTrainedTokenizer, max_prompt_length: int
):
    """
    Tokenizes the prompts in the given dataset using the specified tokenizer.

    Args:
        dataset (DatasetDict): The dataset containing the prompts to be tokenized.
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the prompts.
        max_prompt_length (int): The maximum length of the tokenized prompts.

    Returns:
        DatasetDict: The dataset with tokenized prompts.
    """
    return dataset.map(
        lambda prompts: tokenizer(
            prompts,
            truncation=True,
            padding=False,
            max_length=max_prompt_length,
        ),
        batched=True,
        input_columns=["prompt"],
        remove_columns=["prompt"],
    )


def compute_format_reward(completion: str) -> float:
    """
    Computes a reward based on the format of the given completion string.

    Args:
        completion (str): The completion string to be evaluated.

    Returns:
        float: The reward value, 1.0 if the pattern is matched, otherwise 0.0.
    """
    pattern = r"^<think>(.*?)</think>\s*<answer>.*?</answer>$"
    matched = re.findall(pattern, completion)
    return 1.0 if matched else 0.0
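# Quick sanity check (illustrative; note the pattern is applied without re.DOTALL,
# so "." does not cross newlines and multi-line <think> blocks get no format reward):
#
#   >>> compute_format_reward("<think>2 + 2 = 4</think> <answer>4</answer>")
#   1.0
#   >>> compute_format_reward("The answer is 4")
#   0.0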
def compute_accuracy_reward(completion: str, truth: str) -> float:
    """
    Computes the accuracy reward based on the provided completion and truth strings.

    Args:
        completion (str): The completion string containing the answer within <answer> tags.
        truth (str): The ground truth string to verify the answer against.

    Returns:
        float: The accuracy reward, which is 1.0 if the answer matches the truth, otherwise 0.0.
    """
    pattern = r"<answer>(.*?)</answer>"
    matched = re.findall(pattern, completion)
    reward = 0.0
    if matched:
        parsed_truth = parse(str(truth))
        parsed_answer = parse(normalize_number(matched[0]))
        if verify(parsed_truth, parsed_answer):
            reward = 1.0
    return reward
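# Illustrative sketch (assuming math_verify parses plain integers as expected):
#
#   >>> compute_accuracy_reward("<think>...</think> <answer>72</answer>", 72)
#   1.0
#   >>> compute_accuracy_reward("<think>...</think> <answer>71</answer>", 72)
#   0.0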
def compute_rewards(
    token_ids: LongTensor, truths: LongTensor, tokenizer: PreTrainedTokenizer
) -> BatchRewards:
    """
    Compute rewards for a batch of tokenized completions.

    Args:
        token_ids (LongTensor): Tensor containing token IDs for each completion.
        truths (LongTensor): Tensor containing the ground truth values.
        tokenizer (PreTrainedTokenizer): Tokenizer used to decode token IDs into strings.

    Returns:
        BatchRewards: A dataclass containing format rewards, accuracy rewards,
        and total rewards for each completion.
    """
    completions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    format_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    accuracy_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    total_rewards = torch.zeros(token_ids.shape[0], device=token_ids.device)
    for idx, (completion, truth) in enumerate(zip(completions, truths)):
        format_reward = compute_format_reward(completion)
        accuracy_reward = compute_accuracy_reward(completion, truth.item())
        format_rewards[idx] = format_reward
        accuracy_rewards[idx] = accuracy_reward
        total_rewards[idx] = format_reward + accuracy_reward
    return BatchRewards(format_rewards, accuracy_rewards, total_rewards)
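# Note (added for clarity): per completion the total reward is simply format + accuracy,
# so it lives in {0.0, 1.0, 2.0}. A well-formatted but wrong completion gets 1.0; a
# correct answer inside proper tags gets 2.0.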
def get_mask_after_eos(ids: LongTensor, tokenizer: PreTrainedTokenizer):
    """
    Generates a mask for a sequence of token IDs, masking all tokens
    that appear after the end-of-sequence (eos) token.

    Args:
        ids (LongTensor): A tensor of token IDs with shape (batch_size, sequence_length).
        tokenizer (PreTrainedTokenizer): A tokenizer object that provides the eos_token_id attribute.

    Returns:
        LongTensor: A tensor of the same shape as `ids` with 1s for tokens up to and including
        the eos token, and 0s for tokens after the eos token.
    """
    is_eos = ids == tokenizer.eos_token_id
    eos_idx = torch.full(
        (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=ids.device
    )
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=ids.device).expand(
        is_eos.size(0), -1
    )
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    return completion_mask
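# Worked example (hypothetical token IDs, assuming eos_token_id == 2): the mask keeps
# everything up to and including the first EOS and zeroes the rest; rows without an
# EOS stay fully unmasked.
#
#   ids  = [[5, 2, 7, 7],
#           [5, 5, 5, 5]]
#   mask = [[1, 1, 0, 0],
#           [1, 1, 1, 1]]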
def selective_log_softmax(
    logits: FloatTensor, index: LongTensor
) -> FloatTensor:
    """
    Computes the log softmax of the input logits selectively based on the provided indices.
    This function performs the same operation as applying `log_softmax` on the logits tensor
    along the last dimension and then gathering the results based on the provided indices.
    However, it processes the logits row by row to save memory by leveraging PyTorch internals.
    Taken from https://www.tylerromero.com/posts/2025-02-selective-log-softmax/

    Args:
        logits (FloatTensor): A tensor of shape (batch_size, sequence_length, vocab_size)
            containing the raw logits for each token position.
        index (LongTensor): A tensor of shape (batch_size, sequence_length) containing the
            token IDs for which to gather the log softmax values.

    Returns:
        FloatTensor: A tensor of shape (batch_size, sequence_length) containing the log
        softmax values of the specified tokens.
    """
    token_logprobs = []
    for logits_row, index_row in zip(logits, index):
        logprobs_row = logits_row.log_softmax(dim=-1)
        token_logprobs_row = torch.gather(
            logprobs_row, dim=-1, index=index_row.unsqueeze(-1)
        ).squeeze(-1)
        token_logprobs.append(token_logprobs_row)
    return torch.stack(token_logprobs)
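# Equivalent (but more memory-hungry) one-liner, for intuition only:
#
#   logits.log_softmax(dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
#
# The loop above materializes the (sequence_length, vocab_size) log-softmax one batch
# row at a time instead of for the whole batch at once.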
def gather_token_scores(
    logits: FloatTensor, generated_ids: LongTensor
) -> FloatTensor:
    """
    Gathers token scores from logits based on generated token IDs.

    Args:
        logits (FloatTensor): The logits output from the model. It can be a tuple of tensors or a single tensor.
        generated_ids (LongTensor): The IDs of the generated tokens.

    Returns:
        FloatTensor: The token scores after applying a selective log softmax on the logits.
    """
    if isinstance(logits, tuple):
        # Stack the logits (batch_size*group_size, output_length, vocab)
        logits = torch.stack(logits, dim=0).permute((1, 0, 2))
    # Logsoftmax the logits
    token_scores = selective_log_softmax(logits, generated_ids)
    return token_scores


def compute_token_scores(
    model: PreTrainedModel,
    all_ids: LongTensor,
    all_attention_mask: LongTensor,
    generated_ids: LongTensor,
    input_length: int,
    batch_size: int,
    group_size: int,
) -> FloatTensor:
    """
    Compute token scores for a given model and input data.

    Args:
        model (PreTrainedModel): The pre-trained model to use for generating logits.
        all_ids (LongTensor): Tensor containing input token IDs.
        all_attention_mask (LongTensor): Tensor containing attention masks for the input IDs.
        generated_ids (LongTensor): Tensor containing generated token IDs.
        input_length (int): The length of the input sequence.
        batch_size (int): The size of the batch.
        group_size (int): The size of the group.

    Returns:
        FloatTensor: A tensor containing the computed token scores, reshaped to (batch_size, group_size, -1).
    """
    logits = model(input_ids=all_ids, attention_mask=all_attention_mask).logits
    logits = logits[:, input_length - 1 : -1]
    scores = gather_token_scores(logits, generated_ids)
    scores = scores.view(batch_size, group_size, -1)
    del logits
    torch.cuda.empty_cache()
    return scores
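# Shape walkthrough with the defaults used in main() (batch_size=4, group_size=4,
# prompt_max_length=256, max_new_tokens=512), as an illustration:
#
#   all_ids:       (16, input_length + output_length), at most (16, 768)
#   sliced logits: positions input_length-1 .. -2, i.e. exactly the positions
#                  whose next-token predictions are the generated tokens
#   scores:        (16, output_length), reshaped to (4, 4, output_length)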
def grpo(
    generated_ids: LongTensor,
    old_scores: FloatTensor,
    current_scores: FloatTensor,
    reference_scores: FloatTensor,
    truths: LongTensor,
    tokenizer: PreTrainedTokenizer,
    epsilon: float,
    beta: float,
) -> GRPOOutput:
    """
    Compute the loss of Group Relative Policy Optimization (GRPO) on the given inputs.

    Args:
        generated_ids (LongTensor): Tensor of generated token IDs.
        old_scores (FloatTensor): Tensor of old policy scores.
        current_scores (FloatTensor): Tensor of current policy scores.
        reference_scores (FloatTensor): Tensor of reference policy scores.
        truths (LongTensor): Tensor of ground truth token IDs.
        tokenizer (PreTrainedTokenizer): Tokenizer used for encoding/decoding.
        epsilon (float): Clipping parameter for policy ratios.
        beta (float): Weighting factor for the Kullback-Leibler divergence term.

    Returns:
        GRPOOutput: A dataclass containing the mean loss, format reward, accuracy reward, total reward, and KL divergence.
    """
    losses = torch.zeros(generated_ids.shape[0])
    kls = torch.zeros(generated_ids.shape[0])
    format_rewards = torch.zeros(generated_ids.shape[0])
    accuracy_rewards = torch.zeros(generated_ids.shape[0])
    total_rewards = torch.zeros(generated_ids.shape[0])
    for idx, (
        group_ids,
        group_truths,
        group_old_scores,
        group_current_scores,
        group_reference_scores,
    ) in enumerate(
        zip(generated_ids, truths, old_scores, current_scores, reference_scores)
    ):
        # Compute advantages
        group_rewards = compute_rewards(group_ids, group_truths, tokenizer)
        mean = group_rewards.total_rewards.mean()
        centered = group_rewards.total_rewards - mean
        std = group_rewards.total_rewards.std()
        if std < 1e-8:
            advantages = torch.zeros_like(centered)
        else:
            advantages = centered / (std + 1e-8)
        # Store the mean of each reward for the group
        format_rewards[idx] = group_rewards.format_rewards.mean()
        accuracy_rewards[idx] = group_rewards.accuracy_rewards.mean()
        total_rewards[idx] = group_rewards.total_rewards.mean()
        # Compute the ratios
        ratios = torch.exp(group_current_scores - group_old_scores)
        # Compute the clipped ratios
        clipped_ratios = torch.clamp(
            ratios, min=1.0 - epsilon, max=1.0 + epsilon
        )
        # Compute the Kullback-Leibler divergence between the reference and current policies
        kl = (
            torch.exp(group_reference_scores - group_current_scores)
            - (group_reference_scores - group_current_scores)
            - 1
        )
        kls[idx] = kl.mean()
        # Compute the mean loss of the group
        completion_mask = get_mask_after_eos(group_ids, tokenizer)
        loss = (
            torch.min(
                ratios * advantages.unsqueeze(-1),
                clipped_ratios * advantages.unsqueeze(-1),
            )
            - beta * kl
        )
        loss = -(loss * completion_mask).sum() / completion_mask.sum()
        losses[idx] = loss
    return GRPOOutput(
        loss=losses.mean(),
        format_reward=format_rewards.mean(),
        accuracy_reward=accuracy_rewards.mean(),
        total_reward=total_rewards.mean(),
        kl=kls.mean(),
    )
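# The per-token objective implemented above, written out (following the GRPO /
# DeepSeekMath formulation; the notation is mine):
#
#   A_i     = (r_i - mean(r_group)) / (std(r_group) + 1e-8)
#             (set to 0 when all rewards in the group are equal)
#   ratio_t = exp(logpi_current(o_t) - logpi_old(o_t))
#   kl_t    = exp(logpi_ref(o_t) - logpi_current(o_t))
#             - (logpi_ref(o_t) - logpi_current(o_t)) - 1
#             (the low-variance KL estimator used in the GRPO paper)
#   loss    = -mean over valid tokens of
#             ( min(ratio_t * A_i, clip(ratio_t, 1 - eps, 1 + eps) * A_i) - beta * kl_t )
#
# where "valid tokens" are those not masked out after the first EOS.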
def train(
    dataset: Dataset,
    model: PeftModel,
    tokenizer: PreTrainedTokenizer,
    training_args: TrainingArguments,
) -> None:
    """
    Train a language model using the GRPO (Group Relative Policy Optimization) objective.

    Args:
        dataset (Dataset): The dataset to be used for training.
        model (PeftModel): The model to be trained.
        tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
        training_args (TrainingArguments): The training arguments containing hyperparameters and configurations.
    """
    # Prepare the dataloader
    train_dataloader = DataLoader(
        dataset,
        collate_fn=DataCollatorWithPadding(tokenizer),
        batch_size=training_args.batch_size,
    )
    # Prepare the old policy
    old_model = deepcopy(model)
    old_model.eval()
    # Prepare the metrics
    running_metrics = {
        "loss": 0.0,
        "format_reward": 0.0,
        "accuracy_reward": 0.0,
        "total_reward": 0.0,
        "completion_length": 0.0,
        "kl": 0.0,
    }
    # Prepare optimizer and lr scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=training_args.learning_rate,
    )
    scheduler = LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.1,
        total_iters=training_args.epochs * len(train_dataloader),
    )
    # Let's train!
    training_step = 0
    for _ in range(training_args.epochs):
        # Update the old policy
        old_model.load_state_dict(model.state_dict(), strict=False)
        for batch in tqdm(train_dataloader, total=len(train_dataloader)):
            # Prepare the batch data
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            truths = batch["answer"].to(model.device)
            effective_batch_size, input_length = (
                input_ids.shape[0],
                input_ids.shape[1],
            )
            # Generate with the old policy
            all_ids = old_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=training_args.max_new_tokens,
                do_sample=True,
                num_return_sequences=training_args.group_size,
                temperature=training_args.temperature,
            )
            # Pick only the completion ids (batch_size*group_size, output_length)
            generated_ids = all_ids[:, input_length:]
            # Prepare the attention mask for computing the current
            # and reference logits on the generated ids
            # (batch_size * group_size, input_length + output_length)
            completion_mask = get_mask_after_eos(generated_ids, tokenizer)
            all_attention_mask = torch.hstack(
                (
                    attention_mask.repeat_interleave(
                        training_args.group_size, dim=0
                    ),
                    completion_mask,
                )
            )
            # Compute the token scores of the old policy
            with torch.inference_mode(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                old_scores = compute_token_scores(
                    old_model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
            # Compute the token scores of the current policy
            with torch.autocast("cuda", dtype=torch.bfloat16):
                model.eval()
                current_scores = compute_token_scores(
                    model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
                model.train()
            # Compute the token scores of the reference model.
            # Note that the reference model is the current policy
            # with LoRA disabled, since the only trainable parameters
            # are those from LoRA.
            with torch.inference_mode(), model.disable_adapter(), torch.autocast(
                "cuda", dtype=torch.bfloat16
            ):
                model.eval()
                reference_scores = compute_token_scores(
                    model,
                    all_ids,
                    all_attention_mask,
                    generated_ids,
                    input_length,
                    effective_batch_size,
                    training_args.group_size,
                )
                model.train()
            # Group the generated ids (batch_size, group_size, output_length)
            generated_ids = generated_ids.view(
                effective_batch_size, training_args.group_size, -1
            )
            # Repeat the truths and group (batch_size, group_size)
            truths = truths.repeat_interleave(training_args.group_size).view(
                effective_batch_size, training_args.group_size
            )
            # Compute the GRPO objective
            with torch.autocast("cuda", dtype=torch.bfloat16):
                grpo_output = grpo(
                    generated_ids,
                    old_scores,
                    current_scores,
                    reference_scores,
                    truths,
                    tokenizer,
                    training_args.grpo_epsilon,
                    training_args.grpo_beta,
                )
            # Update the current policy
            grpo_output.loss.backward()
            clip_grad_norm_(
                model.parameters(),
                training_args.gradient_max_norm,
            )
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            # Update the old policy periodically
            if (training_step + 1) % training_args.update_old_after == 0:
                old_model.load_state_dict(model.state_dict(), strict=False)
                torch.cuda.empty_cache()
            # Update the log metrics
            batch_metrics = {
                "loss": grpo_output.loss.item(),
                "format_reward": grpo_output.format_reward.item(),
                "accuracy_reward": grpo_output.accuracy_reward.item(),
                "total_reward": grpo_output.total_reward.item(),
                "kl": grpo_output.kl.item(),
                "completion_length": completion_mask.sum(-1)
                .float()
                .mean()
                .item(),
            }
            running_metrics = {
                key: running_metrics[key] + batch_metrics.get(key, 0)
                for key in running_metrics
            }
            # And report them periodically
            if (training_step + 1) % training_args.logging_steps == 0:
                wandb.log(
                    {
                        **{
                            key: val / (training_step + 1)
                            for key, val in running_metrics.items()
                        },
                        **{"lr": scheduler.get_last_lr()[0]},
                    }
                )
            # Save the model periodically
            if (training_step + 1) % training_args.save_steps == 0:
                model.save_pretrained(
                    f"{training_args.save_dir}_step-{training_step}"
                )
            # Free GPU memory at the end
            del (
                all_ids,
                all_attention_mask,
                generated_ids,
                old_scores,
                current_scores,
                reference_scores,
                grpo_output,
                truths,
            )
            torch.cuda.empty_cache()
            gc.collect()
            training_step += 1
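# Note on the old-policy refresh (an observation, not from the original script): with
# batch_size=4 and GSM8K's roughly 7.5k training problems, one epoch is about 1.9k steps,
# so with update_old_after=3000 the periodic refresh inside the loop never fires and the
# old policy is only synced at the start of each epoch. The sampling policy therefore
# grows increasingly stale relative to the trained policy as the epoch progresses.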
def main():
    """
    Main function for training a LoRA-adapted decoder model with
    GRPO to enhance reasoning skills on the GSM8K dataset.
    """
    # Instantiate current policy and reference model
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "down_proj", "up_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_lora_model(load_model(model_name), lora_config)
    tokenizer = load_tokenizer(model_name)
    # Define training arguments
    training_args = TrainingArguments(
        epochs=1,
        batch_size=4,
        learning_rate=1e-5,
        update_old_after=3000,
        group_size=4,
        logging_steps=10,
        max_new_tokens=512,
        prompt_max_length=256,
        temperature=1.0,
        grpo_epsilon=0.1,
        grpo_beta=0.04,
        gradient_max_norm=0.2,
        save_steps=100,
        save_dir="./grpo_qwen-2.5-1.5b_gsm8k",
    )
    # Load and format the dataset
    instruction = (
        "A conversation between User and Assistant. The user asks a question, "
        "and the Assistant solves it. The assistant first thinks about the "
        "reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think> </think> and "
        "<answer> </answer> tags, respectively, i.e., <think> reasoning process here "
        "</think> <answer> answer here </answer>."
    )
    dataset = load_gsm8k()
    dataset = prompt_dataset(dataset, tokenizer, instruction)
    dataset = tokenize_dataset(
        dataset, tokenizer, training_args.prompt_max_length
    )
    # Initialize wandb
    wandb.login()
    wandb.init(
        project="GRPO-LLM-Reasoning",
        config={
            "model": model_name,
            "dataset": "openai/gsm8k",
            "lora_config": lora_config.__dict__,
            "training_args": training_args.__dict__,
        },
    )
    # Let's train!
    train(dataset["train"], model, tokenizer, training_args)
    # Save the model and finish logging
    model.save_pretrained(f"grpo_{model_name.replace('/', '_')}_gsm8k")
    wandb.finish()


if __name__ == "__main__":
    main()
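# How to run (sketch; the filename is hypothetical): save this file as e.g. grpo_gsm8k.py,
# make sure a CUDA GPU that supports bfloat16 is available (the 4-bit loading and
# flash_attention_2 settings above assume one), log in to Weights & Biases, and launch:
#
#   wandb login
#   python grpo_gsm8k.py
#
# Checkpoints are written to ./grpo_qwen-2.5-1.5b_gsm8k_step-<n> every save_steps steps,
# and the final adapter to grpo_Qwen_Qwen2.5-1.5B-Instruct_gsm8k.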
"Aha moment" found after 146 steps: