Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created January 16, 2025 23:28
Show Gist options
  • Save justinchuby/9bc72685b8622e5e62ade5c41ed3f62e to your computer and use it in GitHub Desktop.
Save justinchuby/9bc72685b8622e5e62ade5c41ed3f62e to your computer and use it in GitHub Desktop.
Rename dynamic axes in an ONNX model using the ONNX Script IR (`onnxscript.ir`)
import logging
import re

from onnxscript import ir
def _all_values(model: ir.Model):
    """Yield the values of ``model`` that may carry a shape, each at most once.

    The candidates are:
    * the main graph's inputs;
    * its initializers;
    * the outputs of every node visited by
      ``ir.traversal.RecursiveGraphIterator(model.graph)``.

    Subgraph inputs and initializers are not collected explicitly.

    The same value object can appear in more than one of these collections,
    for example an initializer that is also listed as a graph input. It is
    yielded only the first time it is seen. Without this, a caller that
    rewrites each yielded value (such as ``rename_dynamic_axes``) would
    apply its transformation twice.
    """

    def _candidates():
        # Same traversal order as before; duplicates are filtered below.
        yield from model.graph.inputs
        yield from model.graph.initializers.values()
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            yield from node.outputs

    # Identity-based: the model keeps every value alive while we iterate,
    # so id() is stable. It also avoids relying on Value's __eq__/__hash__.
    seen_ids: set[int] = set()
    for value in _candidates():
        value_id = id(value)
        if value_id in seen_ids:
            continue
        seen_ids.add(value_id)
        yield value
def _create_rename_mapping(
input_outputs, dynamic_axes: dict[str, dict[int, str]]
) -> dict[str, str]:
"""Create a mapping from old names to new names for dynamic axes."""
named_ios = {v.name: v for v in input_outputs}
rename_mapping = {}
for input_name, axes in dynamic_axes.items():
if input_name not in named_ios:
continue
input = named_ios[input_name]
for dim, new_name in axes.items():
if not isinstance(input.shape[dim], ir.SymbolicDim):
logging.warning(
ValueError(
f"Dimension {dim} of input '{input_name}' is not dynamic: {input.shape[dim]}"
)
)
continue
old_name = input.shape[dim].value
if old_name is None:
continue
rename_mapping[input.shape[dim].value] = new_name
return rename_mapping
def _replace_names(shape_expr: str, rename_mapping: dict[str, str]) -> str:
"""Replace all known names in a shape expression with new names."""
for old_name, new_name in rename_mapping.items():
shape_expr = re.sub(rf"\b{old_name}\b", new_name, shape_expr)
return shape_expr
def rename_dynamic_axes(
    model: ir.Model, dynamic_axes: dict[str, dict[int, str]]
) -> None:
    """Rename symbolic dimensions throughout ``model``, in place.

    Args:
        model: The model whose value shapes are rewritten.
        dynamic_axes: Maps a value name to ``{axis_index: new_dim_name}``.
            The names are resolved against the main graph's inputs and
            outputs to find the old dimension names to replace.

    For every value yielded by ``_all_values``, each named symbolic
    dimension is rewritten as follows:
    * a dimension named exactly after an old name takes the new name;
    * any other named dimension, such as a shape expression, has the old
      names substituted inside it via ``_replace_names``.
    Static and anonymous dimensions are kept as they are. A value's shape is
    reassigned only when at least one dimension was rewritten.
    """
    io_values = (*model.graph.inputs, *model.graph.outputs)
    mapping = _create_rename_mapping(io_values, dynamic_axes)

    for value in _all_values(model):
        shape = value.shape
        if shape is None:
            continue
        rewritten = []
        modified = False
        for dim in shape:
            if not isinstance(dim, ir.SymbolicDim):
                rewritten.append(dim)
                continue
            name = dim.value
            if name is None:
                rewritten.append(None)
                continue
            if name in mapping:
                # Exact hit: counts as a change even if the new name is equal.
                rewritten.append(mapping[name])
                modified = True
                continue
            substituted = _replace_names(name, mapping)
            modified |= substituted != name
            rewritten.append(substituted)
        if modified:
            value.shape = ir.Shape(rewritten)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment