@peterc
Last active February 20, 2025 16:05
Python scripts to fine-tune Qwen 1.5B slightly to follow a certain requested output format

Run this on Runpod with the latest PyTorch image (2.4.0) and a GPU with more than 32GB of VRAM (e.g. an NVIDIA A100 80GB PCIe).

SSH in and run:

apt update -y
apt install -y nano screen git
pip install git+https://github.com/huggingface/trl.git accelerate transformers datasets peft wandb tqdm ninja flash-attn

(Note: some of these dependencies aren't strictly needed; they're just my personal preference.)

Then:

mkdir finetuned_model
python fine.py
python testinfer.py

The fine-tuning uses about 32GB of VRAM and the inference about 6GB.
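
To sanity-check those figures on your own hardware, one rough approach (my addition, not in the scripts below) is to ask PyTorch's allocator for its high-water mark at the end of a run:

import torch

# High-water mark of tensor memory allocated on the current GPU since the
# process started; this under-counts true VRAM use slightly (it excludes
# the CUDA context and allocator overhead)
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak VRAM allocated: {peak_gib:.1f} GiB")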

All being well, you will notice that the original Qwen doesn't follow the system prompt at all, while the fine-tuned model does! And, notably, it DOESN'T use the format in inappropriate situations like telling a joke. (Though it's not perfect; this is a very rough experiment.)
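
If you want to check format compliance programmatically rather than by eye, a simple regex over the generated text does the job. A minimal sketch; check_format is a hypothetical helper, not part of the scripts below:

import re

def check_format(text):
    # True if the response contains an <explanation> block followed by a <code> block
    pattern = r"<explanation>.*?</explanation>\s*<code>.*?</code>"
    return re.search(pattern, text, re.DOTALL) is not None

print(check_format("<explanation>\nUse reverse.\n</explanation>\n<code>\nputs 'hi'.reverse\n</code>"))  # True
print(check_format("Sure! Here's a joke..."))  # False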

fine.py:

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import shutil
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
output_dir = "./finetuned_model"
# Clear out any previous run; ignore_errors avoids a crash when the directory doesn't exist yet
shutil.rmtree(output_dir, ignore_errors=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Qwen defines no dedicated pad token
model = AutoModelForCausalLM.from_pretrained(model_name)
SYSTEM_PROMPT = """
Respond in the following format:
<explanation>
...
</explanation>
<code>
...
</code>
"""
examples = [
    ["Reverse a string in Ruby",
     "Use the reverse method on a string to get the reversed version of it.",
     "puts 'hello'.reverse # Output: 'olleh'"],
    ["Convert a string to uppercase in Ruby",
     "Use the upcase method on a string to convert it to uppercase.",
     "puts 'hello'.upcase # Output: 'HELLO'"],
    ["Reverse a string in Python",
     "Use string slicing with a step of -1 to reverse the string.",
     "text = 'hello'\nprint(text[::-1]) # Output: 'olleh'"],
    ["Convert string to uppercase in Python",
     "Use the upper() method to convert a string to uppercase.",
     "text = 'hello'\nprint(text.upper()) # Output: 'HELLO'"],
    ["Reverse a string in JavaScript",
     "Split the string into an array of characters, reverse it, and join it back together.",
     "const str = 'hello';\nconsole.log(str.split('').reverse().join('')); // Output: 'olleh'"],
    ["Convert string to uppercase in JavaScript",
     "Use the toUpperCase() method to convert a string to uppercase.",
     "const str = 'hello';\nconsole.log(str.toUpperCase()); // Output: 'HELLO'"],
    ["Find maximum value in array in Ruby",
     "Use the max method to find the highest value in an array.",
     "arr = [1, 5, 3, 8, 2]\nputs arr.max # Output: 8"],
    ["Find maximum value in array in Python",
     "Use the built-in max() function to find the highest value.",
     "arr = [1, 5, 3, 8, 2]\nprint(max(arr)) # Output: 8"],
    ["Find maximum value in array in JavaScript",
     "Use Math.max() with the spread operator to find the highest value.",
     "const arr = [1, 5, 3, 8, 2];\nconsole.log(Math.max(...arr)); // Output: 8"],
    ["Find maximum value in array in C",
     "Iterate through the array keeping track of the maximum value found.",
     "#include <stdio.h>\nint main() {\n int arr[] = {1, 5, 3, 8, 2};\n int max = arr[0];\n for(int i = 1; i < 5; i++) {\n if(arr[i] > max) max = arr[i];\n }\n printf(\"%d\\n\", max); // Output: 8\n return 0;\n}"],
    ["Check if number is prime in Ruby",
     "Create a method to check if a number is only divisible by 1 and itself.",
     "def prime?(num)\n return false if num <= 1\n (2..Math.sqrt(num)).none? { |i| num % i == 0 }\nend\n\nputs prime?(17) # Output: true"],
    ["Check if number is prime in Python",
     "Create a function to check if a number has any divisors other than 1 and itself.",
     "def is_prime(num):\n if num <= 1: return False\n return all(num % i != 0 for i in range(2, int(num ** 0.5) + 1))\n\nprint(is_prime(17)) # Output: True"],
    ["Create and use a hash/dictionary in Ruby",
     "Create a hash (dictionary) and access its values using keys.",
     "person = { 'name' => 'John', 'age' => 30 }\nputs person['name'] # Output: 'John'"],
    ["Create and use a dictionary in Python",
     "Create a dictionary and access its values using keys.",
     "person = {'name': 'John', 'age': 30}\nprint(person['name']) # Output: 'John'"],
    ["Create and use an object in JavaScript",
     "Create an object and access its properties.",
     "const person = {name: 'John', age: 30};\nconsole.log(person.name); // Output: 'John'"],
    ["Create and use a struct in C",
     "Define and use a struct to group related data.",
     "#include <stdio.h>\nstruct Person {\n char name[50];\n int age;\n};\n\nint main() {\n struct Person person = {\"John\", 30};\n printf(\"%s\\n\", person.name); // Output: John\n return 0;\n}"],
    ["Iterate over array with index in Ruby",
     "Use each_with_index to iterate over array elements with their indices.",
     "arr = ['a', 'b', 'c']\narr.each_with_index { |elem, i| puts \"#{i}: #{elem}\" }"],
    ["Iterate over array with index in Python",
     "Use enumerate to iterate over array elements with their indices.",
     "arr = ['a', 'b', 'c']\nfor i, elem in enumerate(arr):\n print(f\"{i}: {elem}\")"],
    ["Iterate over array with index in JavaScript",
     "Use forEach with arrow function to iterate over array elements.",
     "const arr = ['a', 'b', 'c'];\narr.forEach((elem, i) => console.log(`${i}: ${elem}`));"],
    ["Read file content in Ruby",
     "Use File.read to read entire file content into a string.",
     "content = File.read('example.txt')\nputs content"],
    ["Read file content in Python",
     "Use with statement and read() to safely read file content.",
     "with open('example.txt', 'r') as file:\n content = file.read()\nprint(content)"],
    ["Read file content in JavaScript (Node.js)",
     "Use fs.readFileSync to read file content synchronously.",
     "const fs = require('fs');\nconst content = fs.readFileSync('example.txt', 'utf8');\nconsole.log(content);"],
    ["Basic error handling in Ruby",
     "Use begin/rescue blocks to handle potential errors.",
     "begin\n # Some risky operation\n 1 / 0\nrescue ZeroDivisionError => e\n puts \"Error: #{e.message}\"\nend"],
    ["Basic error handling in Python",
     "Use try/except blocks to handle potential errors.",
     "try:\n # Some risky operation\n 1 / 0\nexcept ZeroDivisionError as e:\n print(f\"Error: {str(e)}\")"],
    ["Basic error handling in JavaScript",
     "Use try/catch blocks to handle potential errors.",
     "try {\n // Some risky operation\n throw new Error('Something went wrong');\n} catch (e) {\n console.error(`Error: ${e.message}`);\n}"]
]
data = [
    {
        "instruction": SYSTEM_PROMPT,
        "prompt": ex[0],
        "response": f"<explanation>\n{ex[1]}\n</explanation>\n<code>\n{ex[2]}\n</code>\n"
    }
    for ex in examples
]
def preprocess(example):
    # Flatten the conversation into one training string. As noted in the
    # comments below, Qwen's chat template would be the more principled framing.
    text = f"System: {example['instruction']}\nUser: {example['prompt']}\nAssistant: {example['response']}"
    tokenized = tokenizer(text, truncation=True, max_length=512, padding=False)
    # Causal LM fine-tuning: the labels are the input ids themselves
    tokenized["labels"] = tokenized["input_ids"][:]
    return tokenized
dataset = Dataset.from_list(data).map(preprocess)
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=8,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=20,
    weight_decay=0.01,
    logging_steps=10,
    save_steps=50,
    save_total_limit=1,
    fp16=True
)
def custom_data_collator(features):
    labels = [f.pop("labels") for f in features]
    batch = tokenizer.pad(features, return_tensors="pt")
    padded_labels = tokenizer.pad({"input_ids": labels}, return_tensors="pt")["input_ids"]
    # Mask padding positions so they don't contribute to the loss. Since
    # pad_token is eos_token and these training texts contain no explicit
    # EOS, this only affects actual padding.
    padded_labels[padded_labels == tokenizer.pad_token_id] = -100
    batch["labels"] = padded_labels
    return batch
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=custom_data_collator,
)
trainer.train()
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

testinfer.py:

from transformers import AutoModelForCausalLM, AutoTokenizer, logging
SYSTEM_PROMPT = """
Respond in the following format:
<explanation>
...
</explanation>
<code>
...
</code>
"""
def do_inference(model_name, prompt):
    logging.set_verbosity_error()
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt}
    ]
    # Use the model's chat template so system/user turns get the proper tags
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    # Strip the prompt tokens so only the newly generated text remains
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
prompts = [
"Write a Ruby program that prints Hello, World to stdout.",
"Tell me a joke.",
"Write a C program that prints poo poo to stdout."
]
models = ['Qwen/Qwen2.5-1.5B-Instruct', 'finetuned_model']
#models = ['finetuned_model']
for model in models:
    print("----------------------------------")
    print(" " + model)
    print("----------------------------------")
    for prompt in prompts:
        print("")
        print(">>> " + prompt)
        print("")
        print(do_inference(model, prompt))
        print("")

@peterc commented Feb 3, 2025

Improvements and suggestions welcomed as I am basically figuring this all out as I go along. For example, the f"System: {example['instruction']}\nUser: {example['prompt']}\nAssistant: {example['response']}" bit feels sketchy to me. It works, but it doesn't produce the tags Qwen actually expects for system messages; a plain tokenizer() call doesn't apply the chat template, so the model never sees its native ChatML markers during training.
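
For what it's worth, a sketch of what preprocess might look like using Qwen's own chat template instead of the hand-rolled framing. This assumes the tokenizer from fine.py and is untested against the rest of the script, so treat it as a starting point rather than the method used here:

def preprocess(example):
    # apply_chat_template emits the ChatML-style markers
    # (<|im_start|>system ... <|im_end|>) that Qwen was actually trained on
    messages = [
        {"role": "system", "content": example["instruction"]},
        {"role": "user", "content": example["prompt"]},
        {"role": "assistant", "content": example["response"]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    tokenized = tokenizer(text, truncation=True, max_length=512, padding=False)
    tokenized["labels"] = tokenized["input_ids"][:]
    return tokenized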

@peterc commented Feb 3, 2025

OK, going a few steps further with this, there is definitely some overfitting going on as it doesn't hold up when you move into more general programming questions. Working on it ;-)
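
One way to catch that overfitting earlier is to hold a few examples out of training and evaluate on them. A minimal sketch against the fine.py code above (not something the script currently does):

# Hold out ~20% of the examples for evaluation
split = dataset.train_test_split(test_size=0.2, seed=42)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    data_collator=custom_data_collator,
)
trainer.train()
print(trainer.evaluate())  # eval_loss climbing while train loss falls suggests overfitting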

@peterc commented Feb 3, 2025

I am hoping to be able to get there with a relatively small number of examples since it's more about the format than the actual content of the questions or answers, but if I need to generate 1000 examples, I will!
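
If it does come to generating a lot more examples, templating over task and language pairs is one cheap way to scale the prompt side of the list. A hypothetical sketch (the seed lists are illustrative, not from the scripts above; explanations and code would still need to be written or generated per pair):

import itertools

tasks = ["Reverse a string", "Convert a string to uppercase", "Find the maximum value in an array"]
languages = ["Ruby", "Python", "JavaScript", "C"]

# 3 tasks x 4 languages = 12 prompts in the same style as the examples list
prompts = [f"{task} in {lang}" for task, lang in itertools.product(tasks, languages)]
print(len(prompts))  # 12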

@peterc commented Feb 3, 2025

Upping the learning rate and epochs a bit improves things a lot, but then it starts to use the format for irrelevant prompts too, like telling a joke. I probably need to add some examples for those general areas to round it out.
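
Concretely, that could mean mixing some "negative" examples into the data list in fine.py, where the assistant answers plainly without the tags. A hypothetical sketch (these entries are mine, not from the gist):

# Prompts where the tagged format would be inappropriate; the responses
# deliberately omit the <explanation>/<code> tags
negative_examples = [
    {
        "instruction": SYSTEM_PROMPT,
        "prompt": "Tell me a joke.",
        "response": "Why do programmers prefer dark mode? Because light attracts bugs."
    },
    {
        "instruction": SYSTEM_PROMPT,
        "prompt": "What is the capital of France?",
        "response": "The capital of France is Paris."
    },
]
data = data + negative_examples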
