Created February 24, 2024 04:23
Beginner fine-tuning code for Gemma. Set up your Hugging Face (HF) access and other configuration as appropriate.
# Reference #1: https://note.com/npaka/n/nc55e44e407ff
# Reference #2: https://huggingface.co/blog/gemma-peft
# Licence: MIT
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
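# Note: r=8 keeps the adapter small; a larger rank (optionally with lora_alpha)
# trades memory for adapter capacity. The target_modules list covers all of
# Gemma's attention (q/k/v/o) and MLP (gate/up/down) projections, so LoRA is
# applied throughout the network.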
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-2b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
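# Note: this is the QLoRA-style recipe: the frozen base weights are stored in
# 4-bit NF4 while matmuls run in bfloat16. bfloat16 assumes an Ampere-or-newer
# GPU; torch.float16 is the usual substitute on older cards.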
import os

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
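# device_map={"": 0} loads the entire model onto GPU 0. HF_TOKEN must be a
# Hugging Face token for an account that has accepted the Gemma terms, since
# google/gemma-2b-it is a gated repository.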
from datasets import load_dataset

# Load the dataset and keep only the open_qa category
dataset = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozarinnemon", split="train")
dataset = dataset.filter(lambda example: example["category"] == "open_qa")
# Build one training example in Gemma's chat format
def generate_prompt(example):
    return """<bos><start_of_turn>user
{}<end_of_turn>
<start_of_turn>model
{}<eos>""".format(example["instruction"], example["output"])
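# Caveat: the Gemma tokenizer already prepends <bos>, so hard-coding it in the
# template can produce a doubled BOS token once the trainer tokenizes "text".
# An alternative sketch (untested here) lets the tokenizer render the turns:
#   tokenizer.apply_chat_template(
#       [{"role": "user", "content": example["instruction"]},
#        {"role": "assistant", "content": example["output"]}],
#       tokenize=False,
#   )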
# Add a "text" column holding the formatted prompt
def add_text(example):
    example["text"] = generate_prompt(example)
    return example

dataset = dataset.map(add_text)
dataset = dataset.remove_columns(["input", "category", "output", "index", "instruction"])
# Split the dataset into train/eval sets (90/10)
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]
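# train_test_split also accepts a seed for a reproducible split, e.g.
# dataset.train_test_split(test_size=0.1, seed=42).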
import transformers
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,  # effective batch size of 4
        warmup_steps=10,
        max_steps=1000,
        learning_rate=2e-4,
        fp16=True,  # note: bf16=True would match bnb_4bit_compute_dtype on Ampere+ GPUs
        logging_steps=50,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    dataset_text_field="text"
)
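# Because peft_config is passed, SFTTrainer wraps the model in a PeftModel
# before training. Optional sanity check that only the LoRA weights train:
# trainer.model.print_trainable_parameters()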
trainer.train()
# trainer.save_model("/path/to/model")
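# save_model() above stores only the LoRA adapter weights. A sketch for
# exporting a standalone checkpoint (assumes enough memory to reload the base
# model unquantized; the merged output path is illustrative):
# from peft import PeftModel
# base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# merged = PeftModel.from_pretrained(base, "/path/to/model").merge_and_unload()
# merged.save_pretrained("/path/to/merged_model")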
# Prepare the prompt ("Which do you like better, cats or dogs?")
prompt = """<start_of_turn>user
猫と犬、どっちが好き?<end_of_turn>
<start_of_turn>model
"""
# Run inference (the tokenizer adds <bos> itself, so the prompt omits it)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    top_p=0.95,
    temperature=0.2,
    repetition_penalty=1.1,
)
print(tokenizer.decode(outputs[0]))
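# To hide <bos>/<eos> (and Gemma's turn markers, which are registered as
# special tokens) in the printed result:
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))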
# Example output: the model answers in the dataset's "gozaru" samurai style
# (roughly: "I, Rin'emon, think: both cats and dogs are wonderful animals,
# gozaru..."); note the degenerate repetition near the end of the sample.
# <bos><start_of_turn>user
# 猫と犬、どっちが好き?<end_of_turn>
# <start_of_turn>model
# 我、りんえもんは思う。 猫と犬はどちらも素晴らしい動物でござる。猫は、犬よりもより静かな動物で、犬よりもより行動的で、猫よりもより小さな動物でござる。犬は、猫よりもより行動的で、猫よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより |