Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Last active January 23, 2025 19:05
Show Gist options
  • Save justinchuby/822981ba3d48659d3d9bcb447d512aa1 to your computer and use it in GitHub Desktop.
Fold Transpose nodes into their constant initializers with the ONNX IR
from onnxscript import ir
def fold_transpose_initializers(model: ir.Model):
    """Fold ``Transpose`` nodes that consume an initializer into that initializer.

    For every initializer of ``model.graph`` whose only consumer is a
    ``Transpose`` node:

    - the initializer's data is transposed ahead of time, using the node's
      ``perm`` attribute, or reversing all axes when ``perm`` is absent;
    - every user of the Transpose output is rewired to a new initializer
      (same name) holding the transposed data;
    - the Transpose node is removed.

    The model is modified in place.

    The following are skipped and left untouched:

    - initializers without constant data;
    - initializers that are also graph inputs (their value may be overridden
      at runtime) or graph outputs;
    - Transpose nodes whose output is a graph output.

    Args:
        model: The model to rewrite in place.
    """
    # Values exposed through the main graph's interface must keep their
    # identity, so they are never folded.
    graph_io = {*model.graph.inputs, *model.graph.outputs}

    # Iterate over a snapshot: the loop replaces entries of this mapping.
    for name, initializer in list(model.graph.initializers.items()):
        if initializer in graph_io:
            continue
        user_nodes = initializer.consumers()
        if len(user_nodes) != 1 or user_nodes[0].op_type != "Transpose":
            continue
        transpose_node = user_nodes[0]

        const_value = initializer.const_value
        if const_value is None:
            # Nothing to transpose ahead of time.
            continue

        transpose_output = transpose_node.outputs[0]
        owning_graph = transpose_node.graph
        if owning_graph is None or transpose_output in owning_graph.outputs:
            # replace_all_uses_with only rewires node inputs, not graph
            # outputs, so the node could not be safely removed afterwards.
            continue

        array = const_value.numpy()
        perm = transpose_node.attributes.get("perm")
        if perm is None:
            # ONNX Transpose defaults to reversing the dimensions, which is
            # also numpy's default for transpose().
            transposed_array = array.transpose()
        else:
            transposed_array = array.transpose(perm.as_ints())
        transposed_tensor = ir.tensor(transposed_array)

        new_initializer = ir.Value(
            name=initializer.name,
            shape=transposed_tensor.shape,
            type=ir.TensorType(transposed_tensor.dtype),
            const_value=transposed_tensor,
        )
        ir.convenience.replace_all_uses_with(transpose_output, new_initializer)
        model.graph.initializers[name] = new_initializer
        owning_graph.remove(transpose_node, safe=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment