# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 if the extracted answer exactly matches the reference answer."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the extracted answer is an integer."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Partial credit for each XML tag, minus a small penalty for trailing text."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
Being able to log individual rewards is pretty useful for debugging, IMO.
Consolidating them into one shouldn't affect the actual training dynamics, though.
IMHO separate, additive rewards introduce a lot of repetition (e.g., parsing responses in each function) and limit creativity in reward design. For example, I may want the formatting reward to act as a gate for the others, since I may not even want to evaluate a response if the formatting is wrong.
Definitely get creative with it! Nothing wrong with using if statements + multiplication in your reward functions.
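A minimal sketch of such a gated, single reward function (the name combined_reward_func and the exact reward values are illustrative, not part of the original gist; it reuses extract_xml_answer and the strict-format pattern from the script above):

def combined_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    for r, a in zip(responses, answer):
        # Formatting acts as a gate: a malformed response earns nothing,
        # and its answer is not even evaluated.
        if not re.match(pattern, r, flags=re.DOTALL):
            rewards.append(0.0)
            continue
        correct = extract_xml_answer(r) == a
        rewards.append(0.5 + (2.0 if correct else 0.0))
    return rewards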
Does it also work on smaller models, e.g., models around 3B parameters or below?
When training on the GPU with the Qwen model, I encountered the error: "probability tensor contains either inf, nan or element < 0".
Hi @fsxbhyy, did you load the model in torch.bfloat16? I used to encounter this issue when I loaded models in torch.float16 instead of bfloat16. I guess float16 in this context leads to numerical instability, which produces NaN probabilities. Hope this helps!
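For reference, a minimal sketch of the dtype fix (same model_name as in the script above; flash-attention and other arguments omitted for brevity):

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # float16 here can overflow and yield NaN probabilities
).to("cuda")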
I got the same problem. I trained a 7B model with batch_size == 1, but it just keeps reporting OOM.
@harrywoo @Tuziking I had the same problem. I then noticed that these values are actually huge for most cases:
max_prompt_length=256,
max_completion_length=786,
Generating and processing 786 tokens per completion requires a lot of memory, especially if your group size is large. Try setting max_completion_length to 150 or 250 and see if it reduces memory usage. Hope this helps!
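As a sketch, the relevant GRPOConfig arguments could be reduced like this (the values 8 and 200 are illustrative, not tuned; the remaining arguments stay as in the script above):

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    num_generations=8,            # smaller group size lowers peak memory
    max_prompt_length=256,
    max_completion_length=200,    # fewer generated tokens per completion
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    bf16=True,
)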
@willccbb, why would we prefer separate reward functions instead of having a single unified one in GRPO?