Export HF models with torch.onnx
import torch
from onnx_diagnostic import torch_export_patches
from onnxscript.ir.passes.common import clear_metadata_and_docstring
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if use_past_kv:
        # Shape: (batch_size, 1)
        position_ids = position_ids[:, -1].unsqueeze(-1)
    # Shape: (batch_size, sequence_length)
    return position_ids

def get_cos_sin_cache():
    # Precompute the rotary embedding cos/sin caches over the full position range.
    # Relies on the module-level `model` and `torch_dtype` defined under __main__.
    position_ids = torch.arange(model.config.max_position_embeddings).unsqueeze(0)
    x = torch.tensor([], dtype=torch_dtype, device="cpu")
    cos, sin = model.model.rotary_emb(x, position_ids)
    return cos, sin

def make_dynamic_cache(key_value_pairs) -> DynamicCache:
    """
    Creates an instance of :class:`DynamicCache`.

    This version is valid for ``transformers``.

    :param key_value_pairs: list of pairs of (key, values)
    :return: :class:`transformers.cache_utils.DynamicCache`
    """
    cache = DynamicCache()
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, i)
    return cache

def get_input_cache(
    num_hidden_layers: int,
    batch_size: int,
    num_key_value_heads: int,
    sequence_length: int,
    head_dim: int,
    device: str,
) -> DynamicCache:
    cache = make_dynamic_cache(
        [
            (
                torch.zeros(
                    (batch_size, num_key_value_heads, sequence_length, head_dim)
                ).to(device),
                torch.zeros(
                    (batch_size, num_key_value_heads, sequence_length, head_dim)
                ).to(device),
            )
            for _ in range(num_hidden_layers)
        ]
    )
    return cache

def attention_forward(
    self,
    hidden_states,
    position_embeddings,
    attention_mask,
    past_key_value,
    cache_position,
    **kwargs,
):
    # Replacement forward for the attention modules: instead of computing attention
    # in PyTorch, emit a com.microsoft::GroupQueryAttention node symbolically so it
    # appears as-is in the exported graph.
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)
    # cos, sin = position_embeddings
    attn_output = torch.onnx.ops.symbolic(
        "com.microsoft::GroupQueryAttention",
        (
            query,
            key,
            value,
            past_key_value[self.layer_idx][0],
            past_key_value[self.layer_idx][1],
            None,  # seqlens_k
            None,  # total_sequence_length
            cos,  # precomputed rotary caches from get_cos_sin_cache()
            sin,
        ),
        attrs={"scaling": self.scaling},
        dtype=query.dtype,
        shape=query.shape,
        version=1,
    )
    attn_output = self.o_proj(attn_output)
    # The second return value stands in for the attention weights, which are not used.
    return attn_output, torch.tensor(0, dtype=query.dtype)

def get_dummy_inputs_and_shapes():
    inputs = tokenizer(
        ["Hello, how are you today?", "What color is the sky?"],
        padding=True,
        return_tensors="pt",
    )
    dummy_inputs = (
        inputs["input_ids"],
        inputs["attention_mask"],
        get_position_ids(inputs["attention_mask"], use_past_kv=False),
    )
    # Dynamic shapes keyed by the model's forward() parameter names
    shapes = {}
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    shapes.update(
        {
            "input_ids": {0: batch, 1: seq_length},
            "attention_mask": {
                0: batch,
                1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
            },
            "position_ids": {0: batch, 1: seq_length},
        }
    )
    batch_size, sequence_length = inputs["input_ids"].shape
    cache = get_input_cache(
        model.config.num_hidden_layers,
        batch_size,
        model.config.num_key_value_heads,
        sequence_length,
        getattr(
            model.config,
            "head_dim",
            model.config.hidden_size // model.config.num_attention_heads,
        ),
        "cpu",
    )
    n = len(cache.key_cache)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
    # The DynamicCache flattens into [key_cache, value_cache], hence the nested
    # [keys, values] structure of the past_key_values shapes.
    shapes.update(
        {
            "past_key_values": [
                [{0: batch, 2: cache_length} for _ in range(n)],
                [{0: batch, 2: cache_length} for _ in range(n)],
            ],
        }
    )
    dummy_inputs = (*dummy_inputs, cache)
    return dummy_inputs, shapes

if __name__ == "__main__":
    cache_dir = "/workspace/cache_dir/"
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    torch_dtype = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name, attn_implementation="eager", cache_dir=cache_dir
    ).to(torch_dtype)

    # Precompute the rotary cos/sin caches used by the patched attention forward.
    cos, sin = get_cos_sin_cache()

    # Monkeypatch every attention module class so export emits GroupQueryAttention nodes.
    for module in model.modules():
        if "attention" in module.__class__.__name__.lower():
            module.__class__.forward = attention_forward

    dummy_inputs, shapes = get_dummy_inputs_and_shapes()

    with (
        torch.no_grad(),
        torch_export_patches.bypass_export_some_errors(patch_transformers=True),
    ):
        onnx_program = torch.onnx.export(
            model,
            args=dummy_inputs,
            dynamic_shapes=shapes,
            dynamo=True,
            optimize=True,
        )

    # Strip metadata and docstrings from the exported graph, then save.
    model = onnx_program.model
    clear_metadata_and_docstring.ClearMetadataAndDocStringPass()(model)
    onnx_program.save("model.onnx")
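
After the export finishes, a quick structural check can confirm that the dynamic axes and the injected GroupQueryAttention nodes made it into model.onnx. The sketch below is not part of the script; it uses the onnxscript IR the script already depends on. The printed dimension names are whatever the exporter assigned, and the number of GroupQueryAttention nodes should match the model's num_hidden_layers.

from onnxscript import ir

exported = ir.load("model.onnx")

# Graph inputs/outputs: names, element types, and (symbolic) shapes.
for value in exported.graph.inputs:
    print("input ", value.name, value.dtype, value.shape)
for value in exported.graph.outputs:
    print("output", value.name, value.dtype, value.shape)

# Count the GroupQueryAttention nodes inserted by the patched attention forward.
gqa_nodes = [node for node in exported.graph if node.op_type == "GroupQueryAttention"]
print(len(gqa_nodes), "GroupQueryAttention nodes")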