from datasets import load_dataset
from trl import SFTTrainer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    HfArgumentParser,
)
from peft import LoraConfig
import torch


def make_formatting_func(template_tokenizer_name):
    # Use the reference tokenizer's chat template to render each example's
    # messages into a single training string (no tokenization here).
    tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_name)

    def inner(example):
        return tokenizer.apply_chat_template(example["messages"], tokenize=False)

    return inner


def main(
    template_tokenizer_name="teknium/OpenHermes-2.5-Mistral-7B",
    model_name="mistralai/Mistral-7B-v0.1",
    dataset_name="ericflo/unnaturalhermes-reflections-100k",
    context_length=32768,
):
    # All standard TrainingArguments come from the command line.
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]

    # Keep only rows produced with prompt version 3 whose generating model
    # name contains "ixtral" (e.g. Mixtral), then carve out a small eval split.
    full_dataset = load_dataset(dataset_name, split="train")
    filtered_dataset = full_dataset.filter(
        lambda row: row["metadata"]["prompt_version"] == 3
        and "ixtral" in row["metadata"]["model"]
    )
    dataset = filtered_dataset.train_test_split(test_size=500).with_format("torch")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )

    formatting_func = make_formatting_func(template_tokenizer_name)

    # LoRA adapters on the MLP projections only; attention weights stay frozen.
    peft_config = LoraConfig(
        r=64,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=["gate_proj", "down_proj", "up_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    trainer = SFTTrainer(
        model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        formatting_func=formatting_func,
        max_seq_length=context_length,
        peft_config=peft_config,
        packing=True,
    )

    trainer.train()
    trainer.save_model("final")


if __name__ == "__main__":
    """
    Example invocation:

    python train.py \
        --output_dir mistral-7b-reflect \
        --report_to wandb \
        --bf16 True \
        --gradient_checkpointing True \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --logging_steps 1 \
        --do_eval True \
        --evaluation_strategy steps \
        --eval_steps 20
    """
    main()