Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created March 7, 2025 19:58
Show Gist options
  • Save justinchuby/34f473819fffb38942c668edc5ffa6f0 to your computer and use it in GitHub Desktop.
Save justinchuby/34f473819fffb38942c668edc5ffa6f0 to your computer and use it in GitHub Desktop.
# Owner(s): ["module: onnx"]
"""Unit LLM tests for the onnx dynamo exporter."""
from __future__ import annotations
from typing import Any
import logging
import transformers
import torch
# Lower the threshold of the "torch.onnx" logger (and of any descendants that
# inherit their level) to INFO so exporter messages at that level are not filtered.
# NOTE(review): setLevel does not attach a handler; whether these records are
# actually printed depends on handlers configured elsewhere (e.g. by torch) -- confirm.
logging.getLogger("torch.onnx").setLevel(level=logging.INFO)
# Logger for this script's own messages.
logger = logging.getLogger(__name__)
def _prepare_llm_model_gptj_to_test() -> tuple[
    torch.nn.Module,
    dict[str, Any],
    dict[str, dict[int, str]],
    list[str],
    list[str],
]:
    """Build the tiny GPT-J model plus example inputs for an ONNX export test.

    Loads the ``hf-internal-testing/tiny-random-gptj`` checkpoint and creates
    random example inputs for a forward pass that continues from a populated
    KV cache: ``input_seq_len`` new tokens on top of ``past_seq_len`` cached
    ones, with an attention mask spanning both.

    Returns:
        A tuple ``(model, kwargs, dynamic_axes, input_names, output_names)``:

        * ``model`` -- the loaded ``GPTJForCausalLM``.
        * ``kwargs`` -- keyword inputs for ``model``: ``input_ids``,
          ``past_key_values``, ``attention_mask`` and ``position_ids``.
        * ``dynamic_axes`` -- for every input/output name, a mapping from
          axis index to symbolic dimension name.
        * ``input_names`` -- flattened input names, in positional order.
        * ``output_names`` -- flattened output names, in positional order.
    """
    model = transformers.GPTJForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gptj"
    )

    # Derive the KV-cache geometry from the checkpoint's config instead of
    # hard-coding it, so the example cache always matches the model.
    # (The previous literals were 5 layers of (2, 4, 16, 8) tensors.)
    config = model.config
    num_layers = config.n_layer
    num_heads = config.n_head
    head_dim = config.n_embd // num_heads

    batch_size = 2
    input_seq_len = 16
    mask_seq_len = 32
    # The attention mask covers the cached tokens plus the new ones.
    past_seq_len = mask_seq_len - input_seq_len
    active_prob = 0.5
    vocab_size = 1000

    # Random input_ids in [100, vocab_size - 1] (randint's high bound is exclusive).
    input_ids = torch.randint(100, vocab_size, (batch_size, input_seq_len))
    # Random 0/1 attention_mask; each entry is 1 (active token) with probability active_prob.
    attention_mask = torch.bernoulli(
        torch.full((batch_size, mask_seq_len), active_prob)
    ).int()
    # NOTE(review): hard-coded for batch_size=2 and input_seq_len=16, and the
    # values are a 0/1 pattern rather than increasing token positions -- confirm
    # this is intended for the test.
    position_ids = torch.tensor(
        [
            [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        ]
    )
    # One (key, value) pair per layer; key is drawn before value for each layer,
    # layer by layer, matching the original random-draw order.
    past_key_values = [
        (
            torch.randn(batch_size, num_heads, past_seq_len, head_dim),
            torch.randn(batch_size, num_heads, past_seq_len, head_dim),
        )
        for _ in range(num_layers)
    ]

    kwargs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }

    # Flattened per-layer names: "<prefix>.<layer>.key" then "<prefix>.<layer>.value".
    past_names = [
        f"past_key_values.{layer}.{part}"
        for layer in range(num_layers)
        for part in ("key", "value")
    ]
    present_names = [
        f"present.{layer}.{part}"
        for layer in range(num_layers)
        for part in ("key", "value")
    ]

    # Each entry gets its own fresh inner dict (the literal is re-evaluated per
    # comprehension iteration), so no axes mapping is shared between names.
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            name: {0: "batch_size", 2: "past_sequence_length"}
            for name in past_names
        },
        "attention_mask": {
            0: "batch_size",
            1: "past_sequence_length + sequence_length",
        },
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            name: {0: "batch_size", 2: "past_sequence_length + sequence_length"}
            for name in present_names
        },
    }
    input_names = ["input_ids", *past_names, "attention_mask", "position_ids"]
    output_names = ["logits", *present_names]
    return model, kwargs, dynamic_axes, input_names, output_names
# Build the tiny GPT-J model together with its example inputs and I/O naming.
(
    model,
    kwargs,
    dynamic_axes,
    input_names,
    output_names,
) = _prepare_llm_model_gptj_to_test()

# Export with the dynamo-based exporter, passing the example inputs as
# keyword arguments and naming inputs/outputs so dynamic_axes can refer to them.
onnx_program = torch.onnx.export(
    model,
    kwargs=kwargs,
    dynamo=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)
onnx_program.save("llm_model_gptj.onnx")

# Verify the exported program; compare_intermediates=True additionally
# requests comparison of intermediate values (per the flag name).
from torch.onnx._internal.exporter._verification import verify_onnx_program

verification_infos = verify_onnx_program(onnx_program, compare_intermediates=True)

# Presumably writes per-node verification data for Model Explorer under the
# given model name -- confirm against model_explorer_onnx.
from model_explorer_onnx.torch_utils import save_node_data_from_verification_info

save_node_data_from_verification_info(
    verification_infos,
    onnx_program.model,
    model_name="llm_model_gptj",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment