Created January 16, 2025 23:28
-
-
Save justinchuby/9bc72685b8622e5e62ade5c41ed3f62e to your computer and use it in GitHub Desktop.
Rename dynamic axes in ONNX model with IR
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import re

from onnxscript import ir
def _all_values(model: ir.Model):
    """Yield every value in the model's graph and its subgraphs, once each.

    Covers:
      * the main graph's inputs and initializers,
      * the outputs of every node visited by
        ``ir.traversal.RecursiveGraphIterator`` (which also descends into
        subgraphs), and
      * the inputs and initializers of any subgraph attached to a node
        attribute (e.g. ``If``/``Loop`` bodies). These are not produced by any
        node, so node outputs alone would miss them.

    A value reachable through more than one route (for example an initializer
    that is also listed as a graph input) is yielded only the first time, so
    callers that transform values in place never process one value twice.
    """
    seen: set[int] = set()

    def _unseen(values):
        """Yield the values from *values* that have not been yielded yet."""
        for value in values:
            if value is None or id(value) in seen:
                continue
            seen.add(id(value))
            yield value

    def _subgraphs(node):
        """Yield the graphs stored in *node*'s attributes (single or list)."""
        for attr in node.attributes.values():
            # Reference attributes may carry no concrete value; skip them.
            payload = getattr(attr, "value", None)
            candidates = payload if isinstance(payload, (list, tuple)) else (payload,)
            for candidate in candidates:
                if isinstance(candidate, ir.Graph):
                    yield candidate

    yield from _unseen(model.graph.inputs)
    yield from _unseen(model.graph.initializers.values())
    for node in ir.traversal.RecursiveGraphIterator(model.graph):
        yield from _unseen(node.outputs)
        for subgraph in _subgraphs(node):
            yield from _unseen(subgraph.inputs)
            yield from _unseen(subgraph.initializers.values())
def _create_rename_mapping(
    input_outputs, dynamic_axes: dict[str, dict[int, str]]
) -> dict[str, str]:
    """Create a mapping from old symbolic dimension names to new names.

    Args:
        input_outputs: Values (e.g. graph inputs and outputs) whose shapes
            carry the symbolic dimensions to rename. Values are looked up by
            their ``.name``.
        dynamic_axes: ``{value_name: {axis_index: new_dim_name}}``. Entries
            that name a value not present in *input_outputs* are ignored.

    Returns:
        A dict mapping each existing symbolic dimension name to its requested
        new name. Axes are skipped in these cases:
          * the owning value has no shape (warning logged),
          * the axis is out of range (warning logged),
          * the dimension is static (warning logged),
          * the symbolic dimension is unnamed.
        If two requests rename the same old name differently, a warning is
        logged and the later request wins.
    """
    named_ios = {v.name: v for v in input_outputs}
    rename_mapping: dict[str, str] = {}
    for io_name, axes in dynamic_axes.items():
        if io_name not in named_ios:
            continue
        value = named_ios[io_name]
        shape = value.shape
        if shape is None:
            logging.warning(
                "Value '%s' has no shape; cannot rename its dynamic axes", io_name
            )
            continue
        rank = len(shape)
        for axis, new_name in axes.items():
            # Accept Python-style negative axes, but reject anything outside
            # the shape instead of raising IndexError.
            if not -rank <= axis < rank:
                logging.warning(
                    "Dimension %s of input '%s' is out of range for rank %s",
                    axis,
                    io_name,
                    rank,
                )
                continue
            dim = shape[axis]
            if not isinstance(dim, ir.SymbolicDim):
                logging.warning(
                    "Dimension %s of input '%s' is not dynamic: %s",
                    axis,
                    io_name,
                    dim,
                )
                continue
            old_name = dim.value
            if old_name is None:
                # Anonymous symbolic dims have no name to rewrite elsewhere.
                continue
            previous = rename_mapping.get(old_name)
            if previous is not None and previous != new_name:
                logging.warning(
                    "Dimension name '%s' is renamed to both '%s' and '%s'; using '%s'",
                    old_name,
                    previous,
                    new_name,
                    new_name,
                )
            rename_mapping[old_name] = new_name
    return rename_mapping
def _replace_names(shape_expr: str, rename_mapping: dict[str, str]) -> str:
    """Replace all known names in a shape expression with new names.

    Args:
        shape_expr: A symbolic dimension string. It may be a bare name
            (``"batch"``) or an expression over names (``"2*s0 + 1"``).
        rename_mapping: Old dimension name -> new dimension name.

    Returns:
        *shape_expr* with every whole-name occurrence of an old name replaced
        by its new name.

    Substitution happens in a single pass. A mapping such as
    ``{"a": "b", "b": "a"}`` therefore swaps the names instead of collapsing
    them. Both old and new names are treated literally, so regex
    metacharacters and backslashes carry no special meaning.
    """
    names = [name for name in rename_mapping if name]
    if not names or not shape_expr:
        return shape_expr
    # Longest names first: when one old name is a prefix of another
    # (e.g. "a" and "a.b"), the longer, more specific name must win.
    names.sort(key=len, reverse=True)
    alternatives = "|".join(re.escape(name) for name in names)
    # Use lookarounds rather than \b. \b only fires next to word characters,
    # so it never matches names that begin or end with punctuation.
    pattern = rf"(?<!\w)(?:{alternatives})(?!\w)"
    # A function replacement inserts new names verbatim; a string replacement
    # would interpret backslashes in them as group references.
    return re.sub(pattern, lambda match: rename_mapping[match.group(0)], shape_expr)
def rename_dynamic_axes(
    model: ir.Model, dynamic_axes: dict[str, dict[int, str]]
) -> None:
    """Rename symbolic dimensions throughout *model*, in place.

    The old-to-new name mapping is derived from the graph's inputs and
    outputs according to *dynamic_axes* (``{io_name: {axis: new_name}}``).
    A value's shape is replaced with a rewritten ``ir.Shape`` when any of its
    symbolic dimensions changes. A dimension changes when it exactly matches
    an old name, or when it is an expression that mentions one. Values whose
    shape is unknown or unaffected are left untouched.
    """
    graph = model.graph
    mapping = _create_rename_mapping((*graph.inputs, *graph.outputs), dynamic_axes)

    def _rewrite(dim):
        """Return ``(replacement, modified)`` for a single dimension."""
        if not isinstance(dim, ir.SymbolicDim):
            return dim, False
        name = dim.value
        if name in mapping:
            return mapping[name], True
        if name is None:
            return None, False
        # Not an exact match: the name may be an expression over old names.
        rewritten = _replace_names(name, mapping)
        return rewritten, rewritten != name

    for value in _all_values(model):
        shape = value.shape
        if shape is None:
            continue
        rewrites = [_rewrite(dim) for dim in shape]
        if any(modified for _, modified in rewrites):
            value.shape = ir.Shape([dim for dim, _ in rewrites])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.