import numpy as np
import os, re
import jax
from jax.experimental import maps
from jax.experimental import pjit
import jax.numpy as jnp
from jax.experimental import mesh_utils
from absl import flags
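The imports above bring in `mesh_utils` together with the older `jax.experimental` `maps` and `pjit` modules. The following is a hedged sketch of a one-axis data-parallel mesh. It assumes a recent JAX release, where `jax.jit` combined with `jax.sharding` covers what `pjit` used to do. The axis name `"data"` and the toy function `scale_and_sum` are illustrative choices and do not come from the original.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every visible device (also works with a single CPU).
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(devices, axis_names=("data",))

# The batch size is a multiple of the device count so dim 0 divides evenly.
batch = jnp.arange(8 * n_devices, dtype=jnp.float32).reshape(8 * n_devices, 1)

# Shard dim 0 across the "data" axis; dim 1 stays unsharded.
sharding = NamedSharding(mesh, P("data"))
batch = jax.device_put(batch, sharding)


@jax.jit
def scale_and_sum(x):
  # jit follows the input sharding; XLA inserts any cross-device reduction.
  return jnp.sum(2.0 * x)


total = scale_and_sum(batch)
print(total)
```

With one device this runs unchanged. On a multi-GPU host the same code spreads the batch over all devices.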
"""Decoder-only LM scaling experiments on GPUs."""
from jax import numpy as jnp
from paxml import experiment_registry
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmdPipeline
from praxis import layers
# TODO(zhangqiaorjc): Might need to use pmap instead of pjit for smaller models.
"""Runs a simple MNIST model with fake FP8, using FP8 scaling.

The HLO can be dumped by setting the environment variable:
  XLA_FLAGS='--xla_dump_disable_metadata=true --xla_dump_to=/tmp/hlo'
"""
import tensorflow as tf
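The docstring says that HLO dumps are controlled through `XLA_FLAGS`. Here is a minimal sketch that sets the variable from inside the script. It assumes XLA reads the variable when JAX initializes its backend, so the variable is set before `import jax`. The `/tmp/hlo` directory is the one given in the docstring. Any flags that were already present are kept.

```python
import os

# Set before `import jax`: XLA parses XLA_FLAGS when the backend starts up.
os.environ["XLA_FLAGS"] = " ".join([
    os.environ.get("XLA_FLAGS", ""),     # keep any flags already set
    "--xla_dump_disable_metadata=true",  # omit source metadata from the dump
    "--xla_dump_to=/tmp/hlo",            # directory that receives the HLO files
]).strip()
```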
from absl.testing import absltest
from absl import logging
import jax
import jax.numpy as jnp


def amax(x):
  """Returns the maximum absolute value of x."""
  return jnp.max(jnp.abs(x))
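The original script's fake-FP8 code is not shown here, so the following is an illustrative assumption rather than its implementation. It uses a common per-tensor scaling recipe in which `amax` drives the scale:

1. Multiply the tensor so that its largest magnitude lands on the FP8 maximum.
2. Round the result to FP8 precision.
3. Divide by the same scale to return to the original range.

The sketch simulates this quantize-dequantize round trip in NumPy for the E4M3 format, which has a maximum finite value of 448, 3 mantissa bits, and a smallest normal exponent of -6. The helpers `round_to_e4m3` and `fake_fp8` are hypothetical names.

```python
import numpy as np

E4M3_MAX = 448.0       # largest finite float8_e4m3fn value
E4M3_MIN_EXP = -6      # exponent of the smallest normal E4M3 value
E4M3_MANTISSA_BITS = 3


def amax(x):
  # NumPy stand-in for the jnp version above.
  return np.max(np.abs(x))


def round_to_e4m3(x):
  """Rounds values to the nearest E4M3-representable number (ties to even)."""
  x = np.clip(x, -E4M3_MAX, E4M3_MAX)  # saturate instead of overflowing
  mag = np.abs(x)
  # Per-element binary exponent; use 1.0 for zeros to avoid log2(0).
  exp = np.floor(np.log2(np.where(mag > 0, mag, 1.0)))
  # Below the normal range the grid spacing stays fixed (subnormals).
  exp = np.maximum(exp, E4M3_MIN_EXP)
  step = 2.0 ** (exp - E4M3_MANTISSA_BITS)
  return np.round(x / step) * step


def fake_fp8(x):
  """Scales by amax, rounds to E4M3 precision, then unscales."""
  x = np.asarray(x, dtype=np.float64)
  a = amax(x)
  if a == 0:
    return np.zeros_like(x)
  scale = E4M3_MAX / a  # maps the largest |x| onto the FP8 maximum
  return round_to_e4m3(x * scale) / scale


print(fake_fp8([1.0, -2.0, 0.5]))  # exactly representable after scaling
print(fake_fp8([1.0, 0.3]))        # 0.3 becomes 128/448, about 0.2857
```

Scaling by amax first means the largest element always uses the full FP8 range. Smaller elements lose relative precision, as the 0.3 case shows.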
def make_hlo(f, optimize=False, metadata=False, platform=None):
  """Utility function for printing JAX-emitted HLO and XLA-compiled HLO.

  Args:
    f: JAX function to return HLO for.
    optimize: bool, whether to return platform-specific, XLA-optimized HLO.
    metadata: bool, whether to include JAX metadata information.
    platform: Optional[str], one of None, 'cpu', 'gpu', 'tpu'; the platform to
      compile for. None uses the default.
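The body of `make_hlo` is not shown here. The following is a hedged sketch of how a similar helper could use JAX's ahead-of-time lowering API, assuming a recent JAX release:

- `jax.jit(f).lower(...)` returns the pre-optimization module emitted by JAX.
- `.compile()` runs XLA and returns optimized HLO for the default backend.

The name `hlo_text` is hypothetical. The `metadata` and `platform` options from the docstring are not reproduced.

```python
import jax
import jax.numpy as jnp


def hlo_text(f, *example_args, optimize=False):
  """Returns module text for f traced at example_args (hypothetical helper)."""
  lowered = jax.jit(f).lower(*example_args)  # trace and lower; no XLA passes yet
  if optimize:
    # Run XLA's compiler passes for the default backend and return the HLO.
    return lowered.compile().as_text()
  # Pre-optimization module as emitted by JAX (StableHLO text in recent versions).
  return lowered.as_text()


def amax(x):
  return jnp.max(jnp.abs(x))


x = jnp.arange(8.0)
print(hlo_text(amax, x))
print(hlo_text(amax, x, optimize=True))
```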