from datasets import load_dataset
from trl import SFTTrainer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    HfArgumentParser,
)
from peft import LoraConfig
import torch


def make_formatting_func(template_tokenizer_name):
    """Build a formatting callable backed by a tokenizer's chat template.

    The tokenizer named by ``template_tokenizer_name`` is loaded once, here,
    with ``AutoTokenizer.from_pretrained`` and captured by the returned
    closure, so it is not reloaded on every call.

    Args:
        template_tokenizer_name: Identifier passed straight to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A one-argument callable that takes an example, reads its
        ``"messages"`` entry and returns the result of
        ``apply_chat_template(..., tokenize=False)`` for it.
    """
    template_tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_name)

    def render_chat(example):
        conversation = example["messages"]
        # tokenize=False: hand back the rendered template, not token ids.
        return template_tokenizer.apply_chat_template(conversation, tokenize=False)

    return render_chat


def main(
    template_tokenizer_name="teknium/OpenHermes-2.5-Mistral-7B",
    model_name="mistralai/Mistral-7B-v0.1",
    dataset_name="ericflo/unnaturalhermes-reflections-100k",
    context_length=32768,
):
    """Run LoRA supervised fine-tuning of a causal LM on a filtered dataset.

    ``TrainingArguments`` are parsed from the command line with
    ``HfArgumentParser``; the keyword parameters of this function are NOT
    exposed on the command line and keep their defaults unless a caller
    passes them explicitly.

    Steps: load the ``train`` split of ``dataset_name``, keep only rows whose
    metadata has ``prompt_version == 3`` and a model name containing
    ``"ixtral"``, hold out 500 rows for evaluation, load ``model_name`` in
    bfloat16, and train LoRA adapters with ``SFTTrainer`` using packed
    sequences of ``context_length`` tokens. The result is saved to the
    directory ``"final"``.

    Args:
        template_tokenizer_name: Tokenizer whose chat template is used to
            render each example's ``"messages"`` into training text.
        model_name: Base model passed to
            ``AutoModelForCausalLM.from_pretrained``.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        context_length: Value passed to ``SFTTrainer`` as ``max_seq_length``.
    """
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]

    full_dataset = load_dataset(dataset_name, split="train")
    # Substring match on "ixtral" rather than the full name -- presumably so
    # the first letter's capitalization does not matter; confirm against the
    # dataset's metadata values.
    filtered_dataset = full_dataset.filter(
        lambda row: row["metadata"]["prompt_version"] == 3
        and "ixtral" in row["metadata"]["model"]
    )
    # Seed the split from the training arguments. Without a seed every launch
    # (including a relaunch to resume from a checkpoint) draws a different
    # 500-row eval set, so eval metrics are not comparable between runs and
    # rows evaluated in one run can end up in the training split of the next.
    dataset = filtered_dataset.train_test_split(
        test_size=500,
        seed=training_args.seed,
    ).with_format("torch")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    formatting_func = make_formatting_func(template_tokenizer_name)
    # Adapters are attached only to the gate/down/up projection modules.
    peft_config = LoraConfig(
        r=64,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=["gate_proj", "down_proj", "up_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    trainer = SFTTrainer(
        model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        formatting_func=formatting_func,
        max_seq_length=context_length,
        peft_config=peft_config,
        packing=True,
    )
    trainer.train()
    # NOTE: "final" is a literal path, independent of training_args.output_dir.
    trainer.save_model("final")


if __name__ == "__main__":
    # Script entry point. The bare string below is a no-op expression kept
    # purely as an example invocation: main() parses TrainingArguments from
    # the command line via HfArgumentParser, so flags like these are how a
    # run is configured. main()'s own keyword parameters are not settable
    # from the command line.
    """
python train.py \
    --output_dir mistral-7b-reflect \
    --report_to wandb \
    --bf16 True \
    --gradient_checkpointing True \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --logging_steps 1 \
    --do_eval True \
    --evaluation_strategy steps \
    --eval_steps 20
    """
    main()