# Owner(s): ["module: onnx"]
"""LLM unit tests for the ONNX dynamo exporter."""

from __future__ import annotations

import logging
from typing import Any

import torch
import transformers

logger = logging.getLogger(__name__)
logging.getLogger("torch.onnx").setLevel(logging.INFO)
def _prepare_llm_model_gptj_to_test() -> tuple[
    torch.nn.Module,
    dict[str, Any],
    dict[str, dict[int, str]],
    list[str],
    list[str],
]:
    model = transformers.GPTJForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gptj"
    )
    batch_size = 2
    input_seq_len = 16
    past_seq_len = 16
    mask_seq_len = past_seq_len + input_seq_len  # mask covers past + current tokens
    active_prob = 0.5
    vocab_size = 1000
    num_layers = 5
    num_heads = 4
    head_dim = 8
    # Generate random input_ids with values between 100 and vocab_size - 1
    input_ids = torch.randint(100, vocab_size, (batch_size, input_seq_len))
    # Generate a random attention_mask with values 0 or 1, where 1 marks an active token
    attention_mask = torch.bernoulli(
        torch.full((batch_size, mask_seq_len), active_prob)
    ).int()
    # Arbitrary (but in-range) position ids; only the shape matters for this test
    position_ids = torch.tensor(
        [
            [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        ]
    )
    # One (key, value) pair per layer, shaped (batch, heads, past_seq_len, head_dim)
    past_key_values = [
        (
            torch.randn(batch_size, num_heads, past_seq_len, head_dim),
            torch.randn(batch_size, num_heads, past_seq_len, head_dim),
        )
        for _ in range(num_layers)
    ]
    kwargs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    # The per-layer KV-cache axes are identical, so build them in a loop
    # instead of spelling out all twenty entries by hand.
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {
            0: "batch_size",
            1: "past_sequence_length + sequence_length",
        },
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    }
    for i in range(num_layers):
        for kv in ("key", "value"):
            dynamic_axes[f"past_key_values.{i}.{kv}"] = {
                0: "batch_size",
                2: "past_sequence_length",
            }
            dynamic_axes[f"present.{i}.{kv}"] = {
                0: "batch_size",
                2: "past_sequence_length + sequence_length",
            }
    # Input order must match the flattened model inputs: input_ids, the
    # per-layer KV cache, then attention_mask and position_ids.
    input_names = (
        ["input_ids"]
        + [
            f"past_key_values.{i}.{kv}"
            for i in range(num_layers)
            for kv in ("key", "value")
        ]
        + ["attention_mask", "position_ids"]
    )
    output_names = ["logits"] + [
        f"present.{i}.{kv}" for i in range(num_layers) for kv in ("key", "value")
    ]
    return model, kwargs, dynamic_axes, input_names, output_names


model, kwargs, dynamic_axes, input_names, output_names = (
    _prepare_llm_model_gptj_to_test()
)
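# Optional eager-mode sanity check (a sketch, not part of the original flow):
# run the model once to confirm the prepared inputs are mutually consistent
# before exporting. Recent transformers releases may emit a deprecation
# warning for the legacy tuple format of past_key_values.
with torch.no_grad():
    eager_outputs = model(**kwargs)
print(eager_outputs.logits.shape)  # (batch_size, input_seq_len, model vocab size)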
# Export with the dynamo-based exporter; with dynamo=True the call returns an
# ONNXProgram, which can be saved to disk separately.
onnx_program = torch.onnx.export(
    model,
    kwargs=kwargs,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    dynamo=True,
)
onnx_program.save("llm_model_gptj.onnx")
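# A minimal sketch of running the saved model with ONNX Runtime (assumes the
# onnxruntime package is installed). The feed names mirror input_names above;
# the per-layer KV cache is flattened into individual named inputs.
import onnxruntime as ort

session = ort.InferenceSession(
    "llm_model_gptj.onnx", providers=["CPUExecutionProvider"]
)
feeds = {
    "input_ids": kwargs["input_ids"].numpy(),
    "attention_mask": kwargs["attention_mask"].numpy(),
    "position_ids": kwargs["position_ids"].numpy(),
}
for i, (key, value) in enumerate(kwargs["past_key_values"]):
    feeds[f"past_key_values.{i}.key"] = key.numpy()
    feeds[f"past_key_values.{i}.value"] = value.numpy()
ort_outputs = session.run(None, feeds)
print("logits:", ort_outputs[0].shape)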
# Verify the ONNX model against the PyTorch model, including intermediate values.
from torch.onnx._internal.exporter._verification import verify_onnx_program

verification_infos = verify_onnx_program(onnx_program, compare_intermediates=True)

# Save per-node verification data for visualization with model-explorer-onnx.
from model_explorer_onnx.torch_utils import save_node_data_from_verification_info

save_node_data_from_verification_info(
    verification_infos, onnx_program.model, model_name="llm_model_gptj"
)
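# A sketch for summarizing the verification results. The attribute names below
# (name, max_abs_diff, max_rel_diff) are assumptions about the VerificationInfo
# dataclass and may differ across PyTorch versions; print(info) works regardless.
for info in verification_infos:
    print(f"{info.name}: max_abs_diff={info.max_abs_diff}, max_rel_diff={info.max_rel_diff}")
# The saved node data can then be loaded in Model Explorer alongside
# llm_model_gptj.onnx to inspect per-node discrepancies.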