Created July 8, 2024 10:37
yarn checks
""" | |
Assumes: | |
1. transformers on this branch (https://github.com/huggingface/transformers/pull/30910) | |
2. yarn pip installed (https://github.com/jquesnelle/yarn) | |
3. HF login with read token (`huggingface-cli login`) | |
""" | |
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer

# Patched HF implementations from the PR branch
from transformers.models.llama.modeling_llama import (
    LlamaYarnScalingRotaryEmbedding,
    LlamaDynamicYarnScalingRotaryEmbedding,
)

# Reference implementations from the original yarn repo
from scaled_rope.LlamaYaRNScaledRotaryEmbedding import LlamaYaRNScaledRotaryEmbedding
from scaled_rope.LlamaDynamicYaRNScaledRotaryEmbedding import LlamaDynamicYaRNScaledRotaryEmbedding
model_id="meta-llama/Meta-Llama-3-8B" | |
filenames = ["config.json", "generation_config.json", "model-00001-of-00004.safetensors", "model-00002-of-00004.safetensors", | |
"model-00003-of-00004.safetensors", "model-00004-of-00004.safetensors", "model.safetensors.index.json", | |
"special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] | |
for filename in filenames: | |
downloaded_model_path = hf_hub_download(repo_id=model_id, filename=filename) | |
print(downloaded_model_path) | |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") | |
model_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B") | |

def generate_hf_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Build cos/sin tables with the patched transformers YaRN rotary embeddings."""
    inputs = tokenizer(input_text, return_tensors="pt")
    position_ids = torch.ones_like(inputs.input_ids).cumsum(dim=1) - 1
    # the embeddings only need the right dtype and device from `x`, the input
    dummy_input = torch.ones_like(inputs.input_ids, dtype=torch.float32)
    if method == "yarn":
        embedding = LlamaYarnScalingRotaryEmbedding(dim=dim)
    elif method == "dynamic_yarn":
        embedding = LlamaDynamicYarnScalingRotaryEmbedding(dim=dim)
    else:
        raise ValueError("Invalid method specified")
    return embedding(dummy_input, position_ids)

def generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Build cos/sin tables with the reference scaled_rope YaRN rotary embeddings."""
    inputs = tokenizer(input_text, return_tensors="pt")
    # the embeddings only need the right dtype and device from `x`, the input
    dummy_input = torch.ones_like(inputs.input_ids, dtype=torch.float32)
    if method == "yarn":
        embedding = LlamaYaRNScaledRotaryEmbedding(dim=dim)
    elif method == "dynamic_yarn":
        embedding = LlamaDynamicYaRNScaledRotaryEmbedding(dim=dim)
    else:
        raise ValueError("Invalid method specified")
    seq_len = inputs.input_ids.size(1)
    return embedding(dummy_input, seq_len=seq_len)
input_text = "This is a large test input. " * 1200 # sequence length > 8k | |
dim = model_config.hidden_size // model_config.num_attention_heads | |
hf_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="yarn") | |
hf_yarn_embeddings_cos = hf_yarn_embeddings[0][0] | |
hf_yarn_embeddings_sin = hf_yarn_embeddings[1][0] | |
hf_dynamic_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="dynamic_yarn") | |
hf_dynamic_yarn_embeddings_cos = hf_dynamic_yarn_embeddings[0][0] | |
hf_dynamic_yarn_embeddings_sin = hf_dynamic_yarn_embeddings[1][0] | |
yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn") | |
yarn_embeddings_cos = yarn_embeddings[0][0, 0] | |
yarn_embeddings_sin = yarn_embeddings[1][0, 0] | |
dynamic_yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="dynamic_yarn") | |
dynamic_yarn_embeddings_cos = dynamic_yarn_embeddings[0][0, 0] | |
dynamic_yarn_embeddings_sin = dynamic_yarn_embeddings[1][0, 0] | |

assert torch.allclose(hf_yarn_embeddings_cos, yarn_embeddings_cos), "YaRN cos embeddings do not match!"
assert torch.allclose(hf_yarn_embeddings_sin, yarn_embeddings_sin), "YaRN sin embeddings do not match!"
assert torch.allclose(hf_dynamic_yarn_embeddings_cos, dynamic_yarn_embeddings_cos), "Dynamic YaRN cos embeddings do not match!"
assert torch.allclose(hf_dynamic_yarn_embeddings_sin, dynamic_yarn_embeddings_sin), "Dynamic YaRN sin embeddings do not match!"
print("Embeddings match successfully!")