Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Last active May 23, 2025 04:35
Show Gist options
  • Save justinchuby/c8d84a3c21b2651d72b7824740b6f2f2 to your computer and use it in GitHub Desktop.
Save justinchuby/c8d84a3c21b2651d72b7824740b6f2f2 to your computer and use it in GitHub Desktop.
StableHLO
from ai_edge_torch.odml_torch.export import exported_program_to_mlir
import torch
class PowModel(torch.nn.Module):
    """Minimal module whose forward pass is ``x ** 0.5``.

    Serves as a tiny test case for inspecting how a power op with a
    floating-point exponent is exported and lowered to MLIR.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` raised elementwise to the power 0.5.

        Args:
            x: Input tensor. Integer inputs are promoted to a floating
                dtype by PyTorch, e.g. an int64 scalar yields float32
                (matching the ``tensor<i64> -> tensor<f32>`` signature seen
                in the lowered MLIR).

        Returns:
            A tensor holding ``x ** 0.5``.
        """
        return x ** 0.5
# Run the module eagerly first to see the reference result for an int64 input.
torch_model = PowModel()
sample_input = torch.tensor(2)
print(torch_model(sample_input))

# Export the same instance with the same example input, then lower the
# exported program to MLIR with ai_edge_torch. ``model`` is bound only to the
# lowered result, so it never refers to two different kinds of object.
prog = torch.export.export(torch_model, (sample_input,))
model = exported_program_to_mlir(prog)

# Block reached by the same attribute path used throughout the transcript
# below. ``.region.blocks[0]`` steps back to the block's owning region and
# re-selects its first block, so this is presumably the same block as
# ``...body.blocks[0]``. The path is kept verbatim to match the transcript.
entry_block = model.module.body.operations[0].body.blocks[0].region.blocks[0]
print(entry_block.arguments[0])

# Interactive exploration of the lowered module (REPL transcript). Every
# expression below starts from the same path as ``entry_block``:
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0])
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[1])
# func.return %0 : tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].get_asm())
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0])
# %0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].result)
# Value(%0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].results)
# <jaxlib.mlir._mlir_libs._mlir.ir.OpResultList object at 0x73bb1c2934b0>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].results[0])
# Value(%0 = func.call @_aten_pow_fd1ba3a3(%arg0) : (tensor<i64>) -> tensor<f32>)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes)
# <jaxlib.mlir._mlir_libs._mlir.ir.OpAttributeMap object at 0x73bd183de580>
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes[0])
# NamedAttribute(callee=@_aten_pow_fd1ba3a3)
# >>> print(model.module.body.operations[0].body.blocks[0].region.blocks[0].operations[0].attributes[1])
# (no output recorded for this prompt)
@justinchuby
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment