Created
February 25, 2022 06:36
-
-
Save zhangqiaorjc/c08ec7b28dabeee37a01cd85c7da6352 to your computer and use it in GitHub Desktop.
make_hlo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_hlo(f, optimize=False, metadata=False, platform=None):
    """Build a function that prints JAX-emitted or XLA-compiled HLO for `f`.

    Calling ``make_hlo(f)`` does not trace ``f`` by itself. It returns a
    wrapper that, when called with example arguments, traces ``f`` with those
    arguments and returns the resulting HLO module as text.

    Args:
      f: JAX function to return HLO for.
      optimize: bool: if True, compile the traced computation with XLA for the
        selected backend and return the optimized, platform-specific HLO;
        otherwise return the HLO emitted by JAX before XLA compilation.
      metadata: bool: whether to include JAX metadata information in the
        printed HLO.
      platform: Optional[str]: None, 'cpu', 'gpu' or 'tpu' - platform to
        compile for; None uses the default backend.

    Returns:
      Callable[..., str]: a function taking the same arguments as ``f`` whose
      result is the HLO, in text format, for ``f`` applied to those arguments.
    """
    import functools

    # Resolve the backend and print options once, when the wrapper is built,
    # rather than on every call of the returned function.
    client = jax.lib.xla_bridge.get_backend(platform)
    # NOTE(review): HloPrintOptions is reached through a private module
    # (xla_client._xla); it may move between JAX/jaxlib releases.
    print_opts = jax.lib.xla_client._xla.HloPrintOptions.short_parsable()
    print_opts.print_metadata = metadata

    @functools.wraps(f)
    def wrapped_fn(*args, **kwargs):
        # Trace f with the concrete example arguments.
        c = jax.xla_computation(f)(*args, **kwargs)
        if optimize:
            # Compile with XLA and print the first optimized HLO module.
            return client.compile(c).hlo_modules()[0].to_string(print_opts)
        # Unoptimized HLO exactly as emitted by JAX.
        return c.as_hlo_module().to_string(print_opts)

    return wrapped_fn
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment