Last active
May 23, 2025 04:35
-
-
Save justinchuby/c8d84a3c21b2651d72b7824740b6f2f2 to your computer and use it in GitHub Desktop.
StableHLO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from ai_edge_torch.odml_torch.export import exported_program_to_mlir
class PowModel(torch.nn.Module):
    """Minimal module computing ``x ** 0.5`` element-wise.

    Used as a tiny test case for inspecting how ``aten.pow`` is exported and
    lowered to MLIR. For an integer input such as ``torch.tensor(2)``, PyTorch
    type promotion makes the result floating point; in the lowered program the
    corresponding call is typed ``(tensor<i64>) -> tensor<f32>``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` raised element-wise to the power 0.5."""
        return x ** 0.5
# Eager run: an int64 scalar raised to 0.5 yields a floating-point tensor.
# PowModel holds no state, so one instance and one example input are reused
# for both the eager call and the export below.
eager_model = PowModel()
example_input = torch.tensor(2)
print(eager_model(example_input))

# Export to an ExportedProgram, then lower it to an MLIR module.
exported = torch.export.export(eager_model, (example_input,))
lowered = exported_program_to_mlir(exported)

# Entry block of the first top-level operation in the lowered module. Per the
# recorded transcript below, this block calls @_aten_pow_* and returns its
# result.
# NOTE(review): the trailing `.region.blocks[0]` hop appears to walk back to
# the same block; kept verbatim from the original exploration — confirm and
# drop it if redundant.
entry_block = lowered.module.body.operations[0].body.blocks[0].region.blocks[0]
print(entry_block.arguments[0])

# Recorded REPL exploration of the same block (kept for reference):
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0])
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[1])
# func.return %0 : tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].get_asm())
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0])
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].result)
# Value(%0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].results)
# <jaxlib.mlir._mlir_libs._mlir.ir.OpResultList object at 0x73bb1c2934b0>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].results[0])
# Value(%0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes)
# <jaxlib.mlir._mlir_libs._mlir.ir.OpAttributeMap object at 0x73bd183de580>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes[0])
# NamedAttribute(callee=@_aten_pow_fd1ba3a3)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes[1])
# (no output captured in these notes)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
ONNX-to-Torch-dialect lowering of `com.microsoft`-domain ops in torch-mlir (relevance here still unclear): https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp