Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Last active March 7, 2025 00:13
Show Gist options
  • Save justinchuby/6e4da5e82c72593053e89821f703070c to your computer and use it in GitHub Desktop.
Save justinchuby/6e4da5e82c72593053e89821f703070c to your computer and use it in GitHub Desktop.
Code for figuring out where an ONNX model is inaccurate and visualizing the results with Model Explorer
import logging
import torch
from torch_geometric.nn import GAT
logger = logging.getLogger(__name__)
# Without any configured handler, records below WARNING fall through to
# logging's "last resort" handler and are discarded, so none of the INFO
# messages below would appear. basicConfig() installs a stderr handler on the
# root logger while leaving the root level at its default (WARNING); only this
# script's logger and torch.onnx are raised to INFO, so other libraries stay quiet.
logging.basicConfig()
logger.setLevel(logging.INFO)
logging.getLogger("torch.onnx").setLevel(logging.INFO)

# --- Build the GAT model and load trained weights ---------------------------
logger.info("Prepare model")
num_features = 23
num_classes = 12
torch_path = "model.txt"
device = torch.device("cpu")
model = GAT(
    in_channels=num_features,
    out_channels=num_classes,
    heads=4,
    hidden_channels=16,
    num_layers=1,
    v2=True,
    dropout=0.0,
)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file -- only load checkpoints you trust. The
# result is passed straight to load_state_dict, so weights_only=True is
# probably sufficient; confirm against how the checkpoint was saved.
best_model_ckpt = torch.load(torch_path, map_location=device, weights_only=False)
model.load_state_dict(best_model_ckpt)
model.eval()
model = model.to(device)

# --- Dummy graph inputs used to trace the model for export ------------------
logger.info("Generating dummy data for ONNX exporter")
num_segments = 30  # number of nodes, i.e. rows of x
num_edges = 58
x = torch.randn(num_segments, num_features, device=device)
# A (2, num_edges) integer tensor of node indices, each in [0, num_segments).
edge_index = torch.randint(num_segments, size=(2, num_edges), device=device)

# --- Smoke-test the eager model on the dummy inputs --------------------------
logger.info("Running torch model on dummy data")
with torch.no_grad():
    result_torch = model(x, edge_index).numpy()
logger.info("Torch output shape: %s", result_torch.shape)

# --- Export to ONNX ----------------------------------------------------------
logger.info("Exporting")
# Axis 0 of x is the node count (num_segments) and axis 1 of edge_index is the
# edge count; both are requested as dynamic. The dim name for x's axis 0 says
# "features" but it labels the node axis -- kept as-is because the name is
# written into the exported model and downstream tooling may rely on it.
dynamic_axes = {
    "x": {0: "dynamic_input_features"},
    "edge_index": {1: "dynamic_input_edge_connection"},
}
# NOTE(review): optimize=False keeps the graph un-optimized; presumably so the
# ONNX nodes stay close to the traced graph for the per-node comparison below.
onnx_program = torch.onnx.export(
    model,
    (x, edge_index),
    input_names=["x", "edge_index"],
    dynamic_axes=dynamic_axes,
    dynamo=True,
    optimize=False,
)
onnx_program.save("torch_geometric.onnx")

# --- Verify numerics and dump per-node error data for Model Explorer --------
# Both imports are done here, before any verification work, so a missing
# dependency fails fast. verify_onnx_program is imported from a private torch
# module (_internal/_verification) and may move between torch releases.
from torch.onnx._internal.exporter._verification import verify_onnx_program
from model_explorer_onnx.torch_utils import save_node_data_from_verification_info

# compare_intermediates=True asks the verifier to compare intermediate values as
# well as final outputs, so inaccuracies can be localized to individual nodes.
verification_infos = verify_onnx_program(onnx_program, compare_intermediates=True)
save_node_data_from_verification_info(
    verification_infos, onnx_program.model, model_name="torch_geometric"
)
# Afterwards, inspect the per-node errors in Model Explorer from a shell:
#   onnxvis torch_geometric.onnx --node_data_paths=torch_geometric_max_rel_diff.json,torch_geometric_max_abs_diff.json
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment