Last active
March 7, 2025 00:13
-
-
Save justinchuby/6e4da5e82c72593053e89821f703070c to your computer and use it in GitHub Desktop.
Code for figuring out where an ONNX model is inaccurate and visualizing it with Model Explorer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Export a torch_geometric GAT model to ONNX with the dynamo exporter, verify
# the exported program against the PyTorch model (including intermediate
# values), and save the per-node verification data for model explorer.
import logging

import torch
from torch_geometric.nn import GAT

# Without a configured handler, INFO records are dropped: the root logger
# defaults to WARNING and logging's last-resort handler only emits WARNING
# and above. Configure logging so the progress messages below and the INFO
# output requested from torch.onnx are actually shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logging.getLogger("torch.onnx").setLevel(logging.INFO)

logger.info("Prepare model")
num_features = 23
num_classes = 12
torch_path = "model.txt"

model = GAT(
    in_channels=num_features,
    out_channels=num_classes,
    heads=4,
    hidden_channels=16,
    num_layers=1,
    v2=True,
    dropout=0.0,
)

# Everything below runs on CPU. Passing map_location lets a checkpoint that
# was saved from another device be loaded on a CPU-only machine as well.
device = torch.device("cpu")

# NOTE(security): weights_only=False unpickles arbitrary Python objects from
# the checkpoint file. Only load checkpoints from a trusted source; if the
# file holds a plain state dict, weights_only=True is the safer choice.
best_model_ckpt = torch.load(torch_path, map_location=device, weights_only=False)
model.load_state_dict(best_model_ckpt)
model.eval()
model = model.to(device)

logger.info("Generating dummy data for ONNX exporter")
num_segments = 30
num_edges = 58
x = torch.randn(num_segments, num_features).to(device)
# Random indices in [0, num_segments), shaped (2, num_edges).
edge_index = torch.randint(num_segments, size=(2, num_edges)).to(device)

logger.info("Running torch model on dummy data")
with torch.no_grad():
    # Reference PyTorch output; it is not consumed later in this script.
    result_torch = model(x, edge_index).numpy()

logger.info("Exporting")
# Axis 0 of `x` and axis 1 of `edge_index` are exported as dynamic dimensions.
dynamic_axes = {
    "x": {0: "dynamic_input_features"},
    "edge_index": {1: "dynamic_input_edge_connection"},
}
onnx_program = torch.onnx.export(
    model,
    (x, edge_index),
    input_names=["x", "edge_index"],
    dynamic_axes=dynamic_axes,
    dynamo=True,
    optimize=False,
)
onnx_program.save("torch_geometric.onnx")

# These imports stay after the export, as in the original script, so the
# ONNX file above is written even if one of these modules is unavailable.
# NOTE(review): verify_onnx_program is imported from a private torch module
# and may move between torch releases - confirm against the installed version.
from torch.onnx._internal.exporter._verification import verify_onnx_program

verification_infos = verify_onnx_program(onnx_program, compare_intermediates=True)

from model_explorer_onnx.torch_utils import save_node_data_from_verification_info

# Presumably writes the node-data JSON files referenced in the command below;
# confirm against the model_explorer_onnx documentation.
save_node_data_from_verification_info(
    verification_infos, onnx_program.model, model_name="torch_geometric"
)

# To visualize the results:
# onnxvis torch_geometric.onnx --node_data_paths=torch_geometric_max_rel_diff.json,torch_geometric_max_abs_diff.json
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment