@justinchuby
Created April 22, 2025 23:39
Export HF models with torch.onnx
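The script patches the Hugging Face attention modules so that each layer's forward emits a com.microsoft::GroupQueryAttention node through torch.onnx.ops.symbolic, builds dummy inputs and dynamic shapes (including a DynamicCache for the past key/values), and exports the model with the dynamo-based torch.onnx.export.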
import torch
from onnx_diagnostic import torch_export_patches
from onnxscript.ir.passes.common import clear_metadata_and_docstring
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache


# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if use_past_kv:
        # Shape: (batch_size, 1)
        position_ids = position_ids[:, -1].unsqueeze(-1)
    # Shape: (batch_size, sequence_length)
    return position_ids
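
# Illustrative example (not in the original gist): for a left-padded mask
# [[0, 1, 1, 1]], get_position_ids(mask, use_past_kv=False) returns
# [[1, 0, 1, 2]]; with use_past_kv=True only the last position, [[2]],
# is kept for the decode step.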


def get_cos_sin_cache():
    # Precompute the rotary cos/sin caches for every position up to
    # max_position_embeddings; the empty tensor x is only used by the HF
    # rotary embedding module for its dtype/device.
    position_ids = torch.arange(model.config.max_position_embeddings).unsqueeze(0)
    x = torch.tensor([], dtype=torch_dtype, device="cpu")
    cos, sin = model.model.rotary_emb(x, position_ids)
    return cos, sin


def make_dynamic_cache(key_value_pairs) -> DynamicCache:
    """
    Creates an instance of :class:`DynamicCache` prefilled with the given tensors.

    This construction is valid for recent versions of ``transformers``.

    :param key_value_pairs: list of ``(key, value)`` tensor pairs, one per layer
    :return: :class:`transformers.cache_utils.DynamicCache`
    """
    cache = DynamicCache()
    for layer_idx, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, layer_idx)
    return cache


def get_input_cache(
    num_hidden_layers: int,
    batch_size: int,
    num_key_value_heads: int,
    sequence_length: int,
    head_dim: int,
    device: str,
) -> DynamicCache:
    # Build a dummy KV cache with one (key, value) pair per layer, each of
    # shape (batch_size, num_key_value_heads, sequence_length, head_dim).
    cache = make_dynamic_cache(
        [
            (
                torch.zeros(
                    (batch_size, num_key_value_heads, sequence_length, head_dim)
                ).to(device),
                torch.zeros(
                    (batch_size, num_key_value_heads, sequence_length, head_dim)
                ).to(device),
            )
            for _ in range(num_hidden_layers)
        ]
    )
    return cache


def attention_forward(
    self,
    hidden_states,
    position_embeddings,
    attention_mask,
    past_key_value,
    cache_position,
    **kwargs,
):
    # Replacement forward for the HF attention modules: instead of computing
    # attention in PyTorch, emit a com.microsoft::GroupQueryAttention node
    # directly into the exported graph via torch.onnx.ops.symbolic.
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)
    # The per-call position_embeddings are ignored; the globally precomputed
    # cos/sin caches are baked into the graph instead.
    # cos, sin = position_embeddings
    attn_output = torch.onnx.ops.symbolic(
        "com.microsoft::GroupQueryAttention",
        (
            query,
            key,
            value,
            past_key_value[self.layer_idx][0],
            past_key_value[self.layer_idx][1],
            None,  # seqlens_k is left empty here
            None,  # total_sequence_length is left empty here
            cos,
            sin,
        ),
        attrs={"scaling": self.scaling},
        dtype=query.dtype,
        shape=query.shape,
        version=1,
    )
    attn_output = self.o_proj(attn_output)
    # HF expects (attn_output, attn_weights); return a dummy scalar for the weights.
    return attn_output, torch.tensor(0, dtype=query.dtype)


def get_dummy_inputs_and_shapes():
    inputs = tokenizer(
        ["Hello, how are you today?", "What color is the sky?"],
        padding=True,
        return_tensors="pt",
    )
    dummy_inputs = (
        inputs["input_ids"],
        inputs["attention_mask"],
        get_position_ids(inputs["attention_mask"], use_past_kv=False),
    )
    # dynamic_shapes
    shapes = {}
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    shapes.update(
        {
            "input_ids": {0: batch, 1: seq_length},
            "attention_mask": {
                0: batch,
                1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
            },
            "position_ids": {0: batch, 1: seq_length},
        }
    )
    batch_size, sequence_length = inputs["input_ids"].shape
    cache = get_input_cache(
        model.config.num_hidden_layers,
        batch_size,
        model.config.num_key_value_heads,
        sequence_length,
        getattr(
            model.config,
            "head_dim",
            model.config.hidden_size // model.config.num_attention_heads,
        ),
        "cpu",
    )
    n = len(cache.key_cache)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
    shapes.update(
        {
            "past_key_values": [
                [{0: batch, 2: cache_length} for _ in range(n)],
                [{0: batch, 2: cache_length} for _ in range(n)],
            ],
        }
    )
    dummy_inputs = (*dummy_inputs, cache)
    return dummy_inputs, shapes


if __name__ == "__main__":
    cache_dir = "/workspace/cache_dir/"
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    torch_dtype = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name, attn_implementation="eager", cache_dir=cache_dir
    ).to(torch_dtype)

    # Precompute the rotary cos/sin caches used by attention_forward, then
    # monkey-patch every attention module's forward with the GQA version.
    cos, sin = get_cos_sin_cache()
    for module in model.modules():
        if "attention" in module.__class__.__name__.lower():
            module.__class__.forward = attention_forward

    dummy_inputs, shapes = get_dummy_inputs_and_shapes()
    with (
        torch.no_grad(),
        torch_export_patches.bypass_export_some_errors(patch_transformers=True),
    ):
        onnx_program = torch.onnx.export(
            model,
            args=dummy_inputs,
            dynamic_shapes=shapes,
            dynamo=True,
            optimize=True,
        )
    ir_model = onnx_program.model
    clear_metadata_and_docstring.ClearMetadataAndDocStringPass()(ir_model)
    onnx_program.save("model.onnx")
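
    # Optional sanity check (a sketch, not part of the original gist): reload
    # the saved graph without weights and confirm that the patched attention
    # layers show up as com.microsoft::GroupQueryAttention nodes.
    import onnx

    exported = onnx.load("model.onnx", load_external_data=False)
    gqa_nodes = [
        node
        for node in exported.graph.node
        if node.domain == "com.microsoft" and node.op_type == "GroupQueryAttention"
    ]
    print(f"Found {len(gqa_nodes)} GroupQueryAttention node(s)")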