Training and position_ids with left padding
import argparse

import transformers
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

parser = argparse.ArgumentParser(description='Define experiment parameters')
parser.add_argument('--use_custom_position_ids', default='no', choices=['yes', 'no'], type=str)
parser.add_argument('--model_name', default='meta-llama/Llama-2-7b-hf', type=str)
parser.add_argument('--hf_auth_token', default=None, type=str)
parser.add_argument('--seed', default=1111, type=int)
opts = parser.parse_args()

torch.cuda.manual_seed_all(opts.seed)
torch.backends.cudnn.deterministic = True
# Models where the custom position_ids change the training loss
# (their default position_ids ignore left padding):
#   GPT-Neo
#   Llama
# Models where the loss stays the same:
#   OPT    (already derives position_ids from the attention mask)
#   BLOOM  (relative/ALiBi position embeddings, no position_ids used)
def main():
    tokenizer = AutoTokenizer.from_pretrained(opts.model_name, token=opts.hf_auth_token)
    model = AutoModelForCausalLM.from_pretrained(opts.model_name,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto",
                                                 token=opts.hf_auth_token
                                                 )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    texts = [
        dict(text="This is a sentence."),
        dict(text="This is a much longer sentence.")
    ]

    def tokenize(d):
        inputs = tokenizer(d['text'])
        inputs['labels'] = inputs['input_ids']
        return inputs

    train_dataset = Dataset.from_list(texts)
    train_dataset = train_dataset.map(tokenize)

    train_args = transformers.TrainingArguments(
        per_device_train_batch_size=2,
        num_train_epochs=1,
        output_dir='./',
        save_strategy='no',
        seed=opts.seed,
        data_seed=opts.seed
    )

    trainer = CustomTrainer(
        use_custom_position_ids=opts.use_custom_position_ids == 'yes',
        model=model,
        train_dataset=train_dataset,
        args=train_args,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, padding=True)
    )
    out = trainer.train()
    print('Training Loss:', out.metrics['train_loss'])
class CustomTrainer(transformers.Trainer):

    def __init__(self, use_custom_position_ids=False, **kwargs):
        super().__init__(**kwargs)
        self.use_custom_position_ids = use_custom_position_ids

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs['input_ids']
        print('padded input_ids:', input_ids.tolist())

        if self.use_custom_position_ids:
            # build position_ids that keep padding at 0 and count real tokens from 0
            position_ids = list()
            for _input_ids in input_ids:
                _position_ids = list()
                position_id = 0
                for token_id in _input_ids:
                    _position_ids.append(position_id)
                    if token_id != self.data_collator.tokenizer.pad_token_id:
                        position_id += 1
                position_ids.append(_position_ids)
            print('Create custom position_ids with appropriate padding:')
            print(position_ids)
            inputs['position_ids'] = torch.tensor(position_ids).to(model.device)
        else:
            assert 'position_ids' not in inputs
            print('There are no position_ids in the input.')
            print('The transformers model implementation will create position_ids based on the max_length of the batch.')
            max_length = input_ids.shape[-1]
            position_ids = [list(range(max_length)) for _ in input_ids]
            print(position_ids)
            print('They are incorrect with left padding!')

        return super().compute_loss(model, inputs, return_outputs=return_outputs)


if __name__ == '__main__':
    main()
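For intuition, here is a minimal standalone sketch of the mismatch the script demonstrates. The attention mask below is made up to mirror a left-padded two-sentence batch, and the fallback shown is a simplification of what a decoder model does when no position_ids are passed; it is not taken from the gist itself.

import torch

# made-up attention mask for a left-padded two-sentence batch
# (0 = padding, 1 = real token)
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1, 1]])

# simplified view of the fallback used when no position_ids are passed:
# every row is numbered 0..max_length-1, padding included
fallback_position_ids = torch.arange(attention_mask.shape[1]).expand_as(attention_mask)
print(fallback_position_ids.tolist())
# [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]
# -> in row 0 the first real token is assigned position 3 instead of 0,
#    which is exactly what the per-token loop in CustomTrainer corrects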
If the tokenizer returns the attention mask, the position_ids could also be generated with a cumulative sum:
attention_mask = torch.tensor([[0,0,0,0,1,1,1], [0,0,0,1,1,1,1]])
position_ids = attention_mask.cumsum(dim=1)
print(position_ids) # tensor([[0, 0, 0, 0, 1, 2, 3],[0, 0, 0, 1, 2, 3, 4]])
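Note that this cumulative sum assigns position 1 to the first real token. To match the 0-based positions built by the loop in CustomTrainer, one could subtract 1 and clamp the padded slots at 0 (a small variation, not part of the original comment):

position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0)
print(position_ids)  # tensor([[0, 0, 0, 0, 0, 1, 2], [0, 0, 0, 0, 1, 2, 3]])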
Run the script with --use_custom_position_ids no and with --use_custom_position_ids yes, and compare the training loss.
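For example (the filename is assumed here; the gist does not show it):

python position_ids_left_padding.py --use_custom_position_ids no
python position_ids_left_padding.py --use_custom_position_ids yes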