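# train.py -- LoRA supervised fine-tune of mistralai/Mistral-7B-v0.1 on the
# ericflo/unnaturalhermes-reflections-100k dataset, formatted with the
# OpenHermes chat template and trained with TRL's SFTTrainer.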
from datasets import load_dataset
from trl import SFTTrainer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    HfArgumentParser,
)
from peft import LoraConfig
import torch
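

# Build a formatting function that renders each example's "messages" list
# with the named tokenizer's chat template (ChatML for OpenHermes). Only the
# template string comes from this tokenizer; SFTTrainer tokenizes with the
# base model's own tokenizer.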
def make_formatting_func(template_tokenizer_name):
    tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_name)

    def inner(example):
        return tokenizer.apply_chat_template(example["messages"], tokenize=False)

    return inner
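

# Training hyperparameters (output dir, batch size, eval cadence, ...) are
# parsed from the command line into TrainingArguments; see the usage example
# at the bottom of the file.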
def main(
    template_tokenizer_name="teknium/OpenHermes-2.5-Mistral-7B",
    model_name="mistralai/Mistral-7B-v0.1",
    dataset_name="ericflo/unnaturalhermes-reflections-100k",
    context_length=32768,
):
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]
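
    # Load the full training split, keep only rows generated with
    # prompt_version 3 by an "ixtral" model (presumably Mixtral), and hold
    # out 500 examples for evaluation.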
    full_dataset = load_dataset(dataset_name, split="train")
    filtered_dataset = full_dataset.filter(
        lambda row: row["metadata"]["prompt_version"] == 3
        and "ixtral" in row["metadata"]["model"]
    )
    dataset = filtered_dataset.train_test_split(test_size=500).with_format("torch")
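
    # Load the base model in bf16 with FlashAttention-2, letting
    # device_map="auto" place shards across the available GPUs.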
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    formatting_func = make_formatting_func(template_tokenizer_name)
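
    # LoRA applied to the MLP projections only (gate/up/down), leaving the
    # attention modules untouched; alpha/r = 256/64 gives a 4x update scale.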
    peft_config = LoraConfig(
        r=64,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=["gate_proj", "down_proj", "up_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
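
    # SFTTrainer renders each row with formatting_func and, with packing=True,
    # concatenates the rendered conversations into fixed-length sequences of
    # up to context_length (32k) tokens.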
    trainer = SFTTrainer(
        model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        formatting_func=formatting_func,
        max_seq_length=context_length,
        peft_config=peft_config,
        packing=True,
    )
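
    # Train, then write the resulting checkpoint (the LoRA adapter, since a
    # peft_config was given) to ./final.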
    trainer.train()
    trainer.save_model("final")


if __name__ == "__main__":
    """
    python train.py \
        --output_dir mistral-7b-reflect \
        --report_to wandb \
        --bf16 True \
        --gradient_checkpointing True \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --logging_steps 1 \
        --do_eval True \
        --evaluation_strategy steps \
        --eval_steps 20
    """
    main()