Training a dumb simple R1-like LLM for subtraction
# Set the MPS fallback flag before anything imports torch, so ops that are
# missing on Apple Silicon transparently fall back to the CPU.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import re

import numpy as np
from datasets import Dataset, load_from_disk
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"

# Base model, loaded on Apple Silicon (MPS) with eager attention; this same
# object is handed to the GRPO trainer below.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="mps",
    trust_remote_code=True,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    _attn_implementation="eager",
)

SYSTEM_PROMPT = """A conversation between User and Assistant. The user gives two numbers, and the Assistant returns the second minus the first number only. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
"""
def gen_sub_dataset():
    """Generate random number pairs and their differences, and save them to disk."""
    num_samples = 8000
    rng = np.random.default_rng(42)  # for reproducibility

    # Random integers between 0 and 1000
    numbers1 = rng.integers(0, 1000, num_samples)
    numbers2 = rng.integers(0, 1000, num_samples)

    # The system prompt asks for "the second minus the first number", so the
    # input is "n1 n2" and the label is n2 - n1.
    input_pairs = [f"{n1} {n2}" for n1, n2 in zip(numbers1, numbers2)]
    differences = numbers2 - numbers1

    # Create and save the dataset
    dataset_dict = {
        "input": input_pairs,
        "difference": differences,
    }
    dataset = Dataset.from_dict(dataset_dict)
    dataset.save_to_disk("subtraction_dataset")

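# A small sanity check, not part of the original gist: inspect the saved dataset
# before training to confirm the "input"/"difference" schema looks right.
def preview_dataset(path="subtraction_dataset"):
    ds = load_from_disk(path)
    print(ds)     # features: input (string), difference (int64)
    print(ds[0])  # e.g. {"input": "312 847", "difference": 535}
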
def format_reward(completions, **kwargs):
    """Reward 1.0 if the completion follows the <think>...</think><answer>...</answer> format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets the format match even when the reasoning spans several lines.
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

def accuracy_reward(completions, difference, **kwargs):
    """Reward 1.0 if the answer extracted from the completion equals the ground-truth difference."""
    contents = [completion[0]["content"] for completion in completions]
    answer_pattern = r"<answer>(.*?)</answer>"
    rewards = []
    for content, diff in zip(contents, difference):
        try:
            content = re.search(answer_pattern, content).group(1).strip()
            answer = int(content)
            reward = 1.0 if answer == diff else 0.0
        except Exception:  # if parsing fails for any reason, give 0.0
            reward = 0.0
        rewards.append(reward)
    return rewards

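# Quick hand-rolled check of both reward functions on a fake completion
# (a sketch added for illustration; it is not called by main()).
def check_rewards():
    fake = [[{"content": "<think>105 - 5 = 100</think><answer>100</answer>"}]]
    print(format_reward(completions=fake))                      # [1.0]
    print(accuracy_reward(completions=fake, difference=[100]))  # [1.0]
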
def finetune():
    dataset = load_from_disk("subtraction_dataset")

    # Format each example into a chat-style prompt
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["input"]},
            ]
        }

    dataset = dataset.map(make_conversation)

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        learning_rate=1e-5,
        logging_steps=1,
        gradient_accumulation_steps=8,
        max_completion_length=128,
        per_device_train_batch_size=1,
        save_steps=10,
        eval_strategy="steps",  # needed so eval_steps actually triggers evaluation
        eval_steps=100,
    )

    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = split_dataset["train"]
    test_dataset = split_dataset["test"]

    # Initialize the GRPO trainer, reusing the MPS-loaded model from above and
    # wrapping it with a LoRA adapter
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[format_reward, accuracy_reward],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=LoraConfig(task_type="CAUSAL_LM"),
    )

    # Train, then save the final adapter locally
    trainer.train()
    trainer.save_model("./output/tunedmodel")

def call_llm():
    # Load an intermediate GRPO checkpoint (save_steps=10 writes one every 10
    # steps); point this at whichever checkpoint exists, or at ./output/tunedmodel.
    tuned_model = AutoModelForCausalLM.from_pretrained(
        "Qwen2-0.5B-GRPO/checkpoint-40",
        device_map="mps",
        trust_remote_code=True,
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        _attn_implementation="eager",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    input_text = "5 105"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": input_text},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(tuned_model.device)

    generated_ids = tuned_model.generate(
        **model_inputs,
        max_new_tokens=128,
    )
    # Strip the prompt tokens so only the newly generated completion is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)

def main():
    gen_sub_dataset()
    finetune()
    call_llm()


if __name__ == "__main__":
    main()
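
# With the system prompt above, the goal is a reply like
# "<think>105 - 5 = 100</think><answer>100</answer>" for the input "5 105";
# actual output quality depends on how far training ran before the chosen checkpoint.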