@husain-zaidi
Created January 29, 2025 17:48
Training a dumb simple R1-like LLM for subtraction
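A minimal R1-style reinforcement-learning fine-tune: the script below generates a synthetic subtraction dataset, defines two reward functions (one for the <think>/<answer> output format, one for answer accuracy), trains Qwen/Qwen2.5-0.5B-Instruct with TRL's GRPOTrainer and a LoRA adapter on an Apple-silicon (MPS) machine, and then samples from a saved checkpoint.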
import os

# Enable CPU fallback for ops that MPS does not support; set before torch is imported.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import re

import numpy as np
from datasets import Dataset, load_from_disk
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"

# Base model (note: GRPOTrainer below loads its own copy from model_id).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="mps",
    trust_remote_code=True,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    _attn_implementation="eager",
)
SYSTEM_PROMPT = """A conversation between User and Assistant. The user gives two numbers, and the Assistant returns the second minus the first number only. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
"""
def gen_sub_dataset():
    """Generate a synthetic dataset of number pairs and their differences."""
    num_samples = 8000
    rng = np.random.default_rng(42)  # for reproducibility

    # Random integers between 0 and 1000
    numbers1 = rng.integers(0, 1000, num_samples)
    numbers2 = rng.integers(0, 1000, num_samples)

    # Input is "n1 n2"; per the system prompt the target is the second minus
    # the first number, i.e. n2 - n1.
    input_pairs = [f"{n1} {n2}" for n1, n2 in zip(numbers1, numbers2)]
    differences = numbers2 - numbers1

    dataset_dict = {
        "input": input_pairs,
        "difference": differences,
    }

    # Create and save the dataset
    dataset = Dataset.from_dict(dataset_dict)
    dataset.save_to_disk("subtraction_dataset")
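# A saved row looks like {"input": "<n1> <n2>", "difference": n2 - n1}; for example
# (illustrative values) {"input": "123 456", "difference": 333}.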
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has the <think>/<answer> format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL so reasoning that spans multiple lines still matches
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
def accuracy_reward(completions, difference, **kwargs):
    """Reward function that checks if the extracted answer equals the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    answer_pattern = r"<answer>(.*?)</answer>"
    rewards = []
    for content, diff in zip(contents, difference):
        try:
            content = re.search(answer_pattern, content).group(1).strip()
            answer = int(content)
            reward = 1.0 if answer == diff else 0.0
        except Exception:  # missing tag or non-integer answer -> no reward
            reward = 0.0
        rewards.append(reward)
    return rewards
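# Example (illustrative): for the prompt "3 10" the ground-truth difference is 7, so the
# completion "<think>10 - 3 = 7</think><answer>7</answer>" earns 1.0 from both reward
# functions, while a bare "7" earns 0.0 from format_reward and 0.0 from accuracy_reward.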
def finetune():
    dataset = load_from_disk("subtraction_dataset")

    # Format each example into a chat-style prompt
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["input"]},
            ]
        }

    dataset = dataset.map(make_conversation)

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        learning_rate=1e-5,
        logging_steps=1,
        gradient_accumulation_steps=8,
        max_completion_length=128,
        per_device_train_batch_size=1,
        save_steps=10,
        eval_steps=100,  # only takes effect if an eval strategy of "steps" is also set
    )

    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = split_dataset["train"]
    test_dataset = split_dataset["test"]

    # Initialize the GRPO trainer with both reward functions and a LoRA adapter
    trainer = GRPOTrainer(
        model=model_id,
        reward_funcs=[format_reward, accuracy_reward],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=LoraConfig(task_type="CAUSAL_LM"),
    )

    # Train, then save the adapter checkpoint locally
    trainer.train()
    trainer.save_model("./output/tunedmodel")
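# Checkpoints are written to Qwen2-0.5B-GRPO/ every 10 steps (save_steps=10); call_llm
# below loads one of them, while trainer.save_model keeps a copy in ./output/tunedmodel.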
def call_llm():
    # Load a saved checkpoint produced during training
    tuned_model = AutoModelForCausalLM.from_pretrained(
        "Qwen2-0.5B-GRPO/checkpoint-40",
        device_map="mps",
        trust_remote_code=True,
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        _attn_implementation="eager",
    )
    input_text = "5 105"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": input_text},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(tuned_model.device)

    generated_ids = tuned_model.generate(
        **model_inputs,
        max_new_tokens=128,
    )
    # Strip the prompt tokens so only the newly generated completion is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)
def main():
    gen_sub_dataset()
    finetune()
    call_llm()


if __name__ == "__main__":
    main()
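# To run end to end (assumptions: transformers, trl, peft, datasets and numpy installed,
# and an Apple-silicon machine for the "mps" device map): save this file as, say,
# grpo_subtraction.py and run `python grpo_subtraction.py`. On other hardware, changing
# device_map to "cuda" or "cpu" should be the only edit needed. Note that call_llm
# expects Qwen2-0.5B-GRPO/checkpoint-40 to exist, i.e. training has run for at least
# 40 steps with save_steps=10.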