Using padding and prefill during inference in huggingface transformers
import re
import sys
import time

import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, LlamaForCausalLM

# torch==2.0.1
# transformers==4.34.0
# tokenizers==0.14.0
# python: 3.10
# GPU: NVIDIA RTX A6000
# CUDA: 11.8
def main():
    hf_auth_token = ''
    tokens_per_batch = 20000
    max_new_tokens = 5

    for use_flash_attention_2 in [False, True]:
        model, tokenizer = get_model(hf_auth_token, use_flash_attention_2=use_flash_attention_2)
        instances = get_instances(tokenizer, tokens_per_batch, max_new_tokens)
        print('\nnumber of instances:', len(instances))

        print(f'\npadding False, prefill False, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches_equal_length(instances, tokens_per_batch)
        prefill = None
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding False, prefill True, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches_equal_length(instances, tokens_per_batch)
        prefill = get_prefill_input_ids(instances, model)
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding True, prefill False, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches(instances, tokens_per_batch)
        prefill = None
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding True, prefill True, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches(instances, tokens_per_batch)
        prefill = get_prefill_input_ids(instances, model)
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)
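# Left-pad a batch of token id lists and build the matching attention_mask and position_ids.
# If a prefill (shared prefix token ids + KV cache) is given, the shared prefix is cut off
# each sequence and its cached keys/values are tiled to the batch size instead.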
def get_padded_inputs(list_of_input_ids, tokenizer, model, prefill=None):
    if prefill is not None:
        prefill_input_ids, prefill_key_values = prefill
        num_prefill_tokens = len(prefill_input_ids)
        list_of_input_ids = [input_ids[num_prefill_tokens:] for input_ids in list_of_input_ids]

    max_length = max(len(input_ids) for input_ids in list_of_input_ids)

    padded_list_of_input_ids = list()
    attention_mask = list()
    position_ids = list()
    for input_ids in list_of_input_ids:
        num_pad_tokens = max_length - len(input_ids)
        padded_input_ids = [tokenizer.pad_token_id] * num_pad_tokens + input_ids
        if prefill is not None:
            # position ids are only needed for "non-prefill" input ids and are offset by the number of prefill tokens
            _position_ids = [i + len(prefill_input_ids) for i in range(len(input_ids))]
            padded_position_ids = [1] * num_pad_tokens + _position_ids
            # attention mask is active for prefill tokens, not active for padding tokens, and again active for new input tokens
            padded_attention_mask = [1] * len(prefill_input_ids) + [0] * num_pad_tokens + [1] * len(input_ids)
        else:
            padded_position_ids = [1] * num_pad_tokens + list(range(len(input_ids)))
            padded_attention_mask = [0] * num_pad_tokens + [1] * len(input_ids)
        assert len(padded_position_ids) == len(padded_input_ids)
        padded_list_of_input_ids.append(padded_input_ids)
        position_ids.append(padded_position_ids)
        attention_mask.append(padded_attention_mask)

    inputs = dict()
    inputs['input_ids'] = torch.as_tensor(padded_list_of_input_ids).to(model.device)
    inputs['position_ids'] = torch.as_tensor(position_ids).to(model.device)
    inputs['attention_mask'] = torch.as_tensor(attention_mask).to(model.device)
    inputs['past_key_values'] = None

    if prefill is not None:
        prefill_input_ids, prefill_key_values = prefill
        # adapt cache to current batch size
        past_key_values = get_tiled_cache(prefill_key_values, batch_size=len(list_of_input_ids))
        inputs['past_key_values'] = past_key_values
    return inputs
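# The prefill cache was computed with batch size 1; repeat its key/value tensors along the
# batch dimension so it can serve as past_key_values for the whole batch.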
def get_tiled_cache(prefill_key_values, batch_size):
    past_key_values = list()
    for layer_cache in prefill_key_values:
        key_cache = layer_cache[0]
        value_cache = layer_cache[1]
        assert key_cache.shape[0] == 1
        assert value_cache.shape[0] == 1
        key_cache = torch.tile(key_cache, (batch_size, 1, 1, 1))
        value_cache = torch.tile(value_cache, (batch_size, 1, 1, 1))
        past_key_values.append((key_cache, value_cache))
    past_key_values = tuple(past_key_values)
    return past_key_values
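# After each decoding step only the newly generated token is fed back in: input_ids becomes
# the new token, the attention mask grows by one active position, and position_ids advance by one.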
def update_generation_inputs(inputs, next_token_ids, past_key_values=None):
    inputs['input_ids'] = next_token_ids
    attention_mask_cat = torch.ones_like(inputs['input_ids'])
    inputs['attention_mask'] = torch.cat((inputs['attention_mask'], attention_mask_cat), dim=-1)
    inputs['position_ids'] = (inputs['position_ids'][:, -1] + 1).unsqueeze(1)
    inputs['past_key_values'] = past_key_values
    return inputs
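# Build evaluation instances from ASDiv and SVAMP, keeping only questions whose answer is a
# single number, and wrap each question in the Llama-2 chat prompt template.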
def get_instances(tokenizer, tokens_per_batch, max_new_tokens=5):
    prompt = """<s>[INST] <<SYS>>
Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.
<</SYS>>
{user_msg_1} [/INST] {model_answer_1_start}"""

    instances = list()
    dataset1 = load_dataset("EleutherAI/asdiv")
    dataset2 = load_dataset("ChilleD/SVAMP")
    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    for instance in concatenate_datasets([dataset1['validation'], dataset2['train'], dataset2['test']]):
        answer = str(instance.get('answer', instance['Answer']))
        numbers = number_pattern.findall(answer)
        if len(numbers) == 1:
            body = instance.get('body', instance['Body'])
            question = instance.get('question', instance['Question'])
            question = body + ' ' + question + '\nPlease respond only with the result.'
            instance['text'] = prompt.format(user_msg_1=question, model_answer_1_start='The answer is: ')
            instance['label'] = round(float(numbers[0].replace(",", "")), 2)
            instance['input_ids'] = tokenizer.encode(instance['text'], add_special_tokens=False)
            instance['num_tokens'] = len(instance['input_ids']) + max_new_tokens
            assert instance['num_tokens'] <= tokens_per_batch
            instances.append(instance)
    return instances
def get_model(hf_auth_token, model_name='meta-llama/Llama-2-7b-chat-hf', use_flash_attention_2=False):
    if len(sys.argv) > 1:
        hf_auth_token = sys.argv[1]
    if hf_auth_token == '':
        raise ValueError('Please provide your huggingface auth token in the script or as the first command line argument')
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_auth_token)
    model = LlamaForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             token=hf_auth_token,
                                             use_flash_attention_2=use_flash_attention_2
                                             )

    tokenizer.padding_side = 'left'
    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = tokenizer.bos_token_id = 1
    model.config.eos_token_id = tokenizer.eos_token_id = 2
    return model, tokenizer
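# Find the longest token prefix shared by all instances (here: the system prompt part of the
# chat template), run it through the model once, and return it together with its KV cache.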
def get_prefill_input_ids(instances, model):
    assert len(instances) > 1
    num_prefill_tokens = None
    for i, value in enumerate(instances[0]['input_ids']):
        if any(ins['input_ids'][i] != value for ins in instances):
            num_prefill_tokens = i
            break
    assert num_prefill_tokens is not None
    prefill_input_ids = instances[0]['input_ids'][:num_prefill_tokens]

    with torch.no_grad():
        input_ids = torch.tensor([prefill_input_ids])
        outputs = model(input_ids=input_ids.to(model.device))
        past_key_values = outputs.past_key_values
    return prefill_input_ids, past_key_values
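# Batch instances so that every sequence in a batch has exactly the same length
# (no padding needed), capped at tokens_per_batch tokens per batch.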
def get_batches_equal_length(instances, tokens_per_batch):
    instances = list(sorted(instances, key=lambda ins: -ins['num_tokens']))
    batches = list()
    batch = [instances[0]]
    num_tokens = instances[0]['num_tokens']
    for instance in instances[1:]:
        if instance['num_tokens'] == num_tokens and num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    batches.append(batch)
    return batches
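# Batch instances purely by the token budget: sequences of different lengths can share a
# batch, so padding is required.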
def get_batches(instances, tokens_per_batch):
    instances = list(sorted(instances, key=lambda ins: -ins['num_tokens']))
    batches = list()
    batch = [instances[0]]
    num_tokens = instances[0]['num_tokens']
    for instance in instances[1:]:
        if num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    batches.append(batch)
    return batches
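# Greedy decoding: one forward pass per new token, reusing the growing KV cache via
# update_generation_inputs; stops early once every sequence in the batch has produced EOS.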
def inference(model, tokenizer, batch, prefill=None, max_new_tokens=5):
    input_ids = [instance['input_ids'] for instance in batch]
    inputs = get_padded_inputs(input_ids, tokenizer, model, prefill)

    generated_sequences = [list() for _ in batch]
    is_finished = [False for _ in batch]
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(**inputs, use_cache=True)
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]
        next_token_ids = torch.argmax(next_token_logits, dim=-1)

        for output_id, instance in enumerate(batch):
            if is_finished[output_id]:
                continue
            next_token_id = next_token_ids[output_id].item()
            if next_token_id == tokenizer.eos_token_id:
                is_finished[output_id] = True
                continue
            generated_sequences[output_id].append(next_token_id)

        inputs = update_generation_inputs(inputs=inputs,
                                          next_token_ids=next_token_ids.unsqueeze(-1),
                                          past_key_values=past_key_values)
        if all(is_finished):
            break

    for instance, generated_sequence in zip(batch, generated_sequences):
        output_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        instance['response'] = output_text
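# Extract the first number from each generated response and compare it to the gold label;
# returns accuracy in percent.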
def evaluate(instances):
    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    correct = 0
    for instance in instances:
        response = instance['response']
        pred = number_pattern.findall(response)
        if len(pred) > 0:
            pred = round(float(pred[0].replace(",", "")), 2)
            if pred == instance['label']:
                correct += 1
    accuracy = correct / len(instances) * 100
    return round(accuracy, 1)
if __name__ == '__main__':
    main()
With the right attention mask and position_ids you can combine padding and prefilled prompt tokens during inference in huggingface transformers. This speeds up batched inference, especially if each instance has the same system prompt prepended.
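For illustration, here is a minimal toy sketch (not part of the gist, with token ids replaced by strings) of what get_padded_inputs produces for a 3-token shared prefill and two remainders of length 2 and 4. Padding positions get a dummy position id of 1 and are masked out, while the real tokens continue the prefill positions:

# toy illustration: 3-token shared prefill, two remainders of length 2 and 4
prefill_len = 3
remainders = [['a0', 'a1'], ['b0', 'b1', 'b2', 'b3']]
max_length = max(len(r) for r in remainders)
for r in remainders:
    num_pad = max_length - len(r)
    input_ids = ['PAD'] * num_pad + r
    position_ids = [1] * num_pad + [i + prefill_len for i in range(len(r))]
    attention_mask = [1] * prefill_len + [0] * num_pad + [1] * len(r)
    print(input_ids, position_ids, attention_mask)
# -> ['PAD', 'PAD', 'a0', 'a1'] [1, 1, 3, 4] [1, 1, 1, 0, 0, 1, 1]
# -> ['b0', 'b1', 'b2', 'b3']   [3, 4, 5, 6] [1, 1, 1, 1, 1, 1, 1]

The attention mask covers prefill tokens, padding, and new tokens in that order; the prefill KV cache is tiled to the batch size and passed as past_key_values.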
Run with your huggingface auth token as the first command line argument (or set hf_auth_token in the script). Output on an NVIDIA RTX A6000:
number of instances: 2089
padding False, prefill False, flash_attention_2 False
duration in seconds: 56
accuracy: 42.5
padding False, prefill True, flash_attention_2 False
duration in seconds: 35
accuracy: 42.5
padding True, prefill False, flash_attention_2 False
duration in seconds: 49
accuracy: 42.5
padding True, prefill True, flash_attention_2 False
duration in seconds: 27
accuracy: 42.5
This also works with flash attention 2, although I don't see additional speed-ups from it here.
padding False, prefill False, flash_attention_2 True
duration in seconds: 57
accuracy: 42.5
padding False, prefill True, flash_attention_2 True
duration in seconds: 35
accuracy: 42.5
padding True, prefill False, flash_attention_2 True
duration in seconds: 48
accuracy: 42.5
padding True, prefill True, flash_attention_2 True
duration in seconds: 27
accuracy: 42.5
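Note: the transformers version pinned above (4.34.0) uses the use_flash_attention_2 flag shown in get_model. On newer releases (roughly 4.36 and later) that flag is deprecated in favour of attn_implementation, so the loading call would look roughly like this:

# assumption: newer transformers API, not the version pinned in this gist
model = LlamaForCausalLM.from_pretrained(model_name,
                                         torch_dtype=torch.float16,
                                         device_map="auto",
                                         token=hf_auth_token,
                                         attn_implementation="flash_attention_2")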