Last active
April 13, 2025 15:12
-
-
Save justinchuby/cf1699d05baeac281fb3e82f9d0fc473 to your computer and use it in GitHub Desktop.
Convert constants to initializers with ONNX IR
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# **************************************************************************************
# NOTE: Users can now use https://github.com/microsoft/onnxscript/blob/main/onnxscript/ir/passes/common/constant_manipulation.py
# i.e. onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass
# **************************************************************************************
from onnxscript import ir
def convert_constants_to_initizliers(model: ir.Model, size_limit: int = 1024) -> int:
    """Convert large ``Constant`` nodes in the main graph to initializers.

    Only ONNX-domain ``Constant`` nodes that carry a ``value`` attribute are
    considered; the ``value_float``/``value_ints``/... variants are left
    untouched. A constant is converted when its tensor has at least
    ``size_limit`` elements; smaller constants stay as nodes.

    Constants whose output is also a graph output are skipped: the node that
    produces a graph output cannot be removed with ``safe=True``, and
    ``replace_all_uses_with`` only rewrites node inputs, not
    ``graph.outputs``.

    Only ``model.graph`` itself is scanned; subgraphs and model-local
    functions are not visited.

    Args:
        model: The model to modify in place.
        size_limit: Minimum number of tensor elements for a constant to be
            converted into an initializer.

    Returns:
        The number of constant nodes that were converted.
    """
    graph = model.graph
    converted = 0
    # Iterate over a snapshot so removing nodes cannot disturb the traversal.
    for node in list(graph):
        if node.op_type != "Constant" or node.domain not in ("", "ai.onnx"):
            continue
        if "value" not in node.attributes:
            continue
        output = node.outputs[0]
        if output in graph.outputs:
            # Lifting would leave the graph output pointing at a removed
            # node's value; Graph.remove(safe=True) raises in this case.
            continue
        initializer_name = output.name
        if initializer_name is None:
            # Initializers must be named; leave unnamed constants in place.
            continue
        tensor = node.attributes["value"].as_tensor()
        if tensor.size < size_limit:
            continue
        # Register an initializer with the tensor value
        initializer = ir.Value(
            name=initializer_name,
            shape=tensor.shape,
            type=ir.TensorType(tensor.dtype),
            const_value=tensor,
        )
        graph.initializers[initializer_name] = initializer
        # Replace the constant node with the initializer
        ir.convenience.replace_all_uses_with(output, initializer)
        graph.remove(node, safe=True)
        converted += 1
        print(f"Converted constant node '{node.name}' to initializer '{initializer_name}'")
    return converted
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment