@wassname
Last active April 3, 2026 08:55
Snippet showing how to enable beartype + jaxtyping runtime type checking only during a smoke test.
```just
# Smoke test: demo on task 0 showing steered outputs at -1, 0, +1
smoke *ARGS:
    BEARTYPE=1 {{ PY }} ssteer_v3.py --quick {{ ARGS }} 2>&1 | tee /tmp/smoke.log | tail -80
```
...
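```python
import os

# Opt-in runtime type checking: enabled when BEARTYPE is non-empty (e.g. via `just smoke`).
if os.environ.get("BEARTYPE"):
    from beartype import beartype as typechecker
    from jaxtyping import install_import_hook, jaxtyped

    # Wrap every function in `model` with jaxtyped + beartype; only affects imports made after this call.
    install_import_hook(["model"], "beartype.beartype")

    def typecheck(fn):
        """Wraps train.py functions (entry point, not covered by import hook)."""
        return jaxtyped(fn, typechecker=typechecker)
else:
    def typecheck(fn):
        # No-op: normal runs pay no type-checking cost.
        return fn
```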
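The gate above reads `BEARTYPE` once, at import time, and treats any non-empty string as "on", so `BEARTYPE=0` also enables checking. A minimal stdlib-only sketch of the same on/off pattern, assuming a hypothetical `env_flag` helper that treats `""`, `"0"`, `"false"` and `"no"` as off, and a hypothetical `make_typecheck` factory that takes the checking decorator as a parameter so the switch can be exercised without beartype or jaxtyping installed:

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

F = TypeVar("F", bound=Callable)


def env_flag(name: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: like bool(os.environ.get(name)), but "0"/"false"/"no" mean off."""
    value = (os.environ if env is None else env).get(name, "")
    return value.strip().lower() not in ("", "0", "false", "no")


def make_typecheck(enabled: bool, checker: Callable[[F], F]) -> Callable[[F], F]:
    """Hypothetical factory: the checking decorator when enabled, else a no-op."""
    if enabled:
        return checker

    def identity(fn: F) -> F:
        return fn  # disabled: hand back the undecorated function, no wrapper layer

    return identity


def shout(fn):
    """Toy stand-in for a real checker, so the enabled path is visible."""
    return lambda *args: fn(*args).upper()


# BEARTYPE=0 counts as off here, unlike `if os.environ.get("BEARTYPE")`.
typecheck = make_typecheck(env_flag("BEARTYPE", {"BEARTYPE": "0"}), shout)


@typecheck
def greet(name: str) -> str:
    return f"hi {name}"


print(greet("smoke"))
```

Returning `fn` itself rather than a pass-through wrapper keeps the disabled path free of any per-call overhead, the same choice the original `else` branch makes.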
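```python
import jax
from jaxtyping import Float

# Under jaxtyped, `p` and `r` must have the same size across all arguments of one call.
@typecheck
def compute_shift_loss(
    s_ref_cho: Float[jax.Array, 'p r'],  # p = n_pairs = b/2
    s_ref_rej: Float[jax.Array, 'p r'],
    s_pos_cho: Float[jax.Array, 'p r'],
    ...
```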
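What the `'p r'` annotations buy during the smoke test: under `jaxtyped`, each axis name is bound per call, so whatever size `p` takes in `s_ref_cho` must also hold in `s_ref_rej` and `s_pos_cho`, otherwise the call fails with a type-check error. A toy, stdlib-only illustration of that binding rule, assuming a hypothetical `check_dims` decorator and a hypothetical `FakeArray` stand-in for `jax.Array`; this is not jaxtyping's implementation, only the rule it enforces:

```python
import functools
import inspect
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass
class FakeArray:
    """Hypothetical stand-in for jax.Array; only `.shape` matters here."""
    shape: Tuple[int, ...]


def check_dims(**specs: str) -> Callable:
    """Toy version of per-call named-dimension binding (hypothetical helper).

    `specs` maps a parameter name to an axis spec such as "p r". The first
    argument that uses an axis name binds its size; later uses must match.
    """
    def decorator(fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            sizes: Dict[str, int] = {}  # fresh for every call, like a jaxtyped scope
            for param, spec in specs.items():
                dims = spec.split()
                shape = bound.arguments[param].shape
                if len(shape) != len(dims):
                    raise TypeError(f"{param}: expected axes {spec!r}, got shape {shape}")
                for dim, size in zip(dims, shape):
                    expected = sizes.setdefault(dim, size)
                    if size != expected:
                        raise TypeError(f"{param}: axis {dim!r} has size {size}, expected {expected}")
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@check_dims(s_ref_cho="p r", s_ref_rej="p r", s_pos_cho="p r")
def shift_loss_shapes(s_ref_cho, s_ref_rej, s_pos_cho):
    # Hypothetical body: just report the number of pairs p.
    return s_ref_cho.shape[0]


print(shift_loss_shapes(FakeArray((4, 8)), FakeArray((4, 8)), FakeArray((4, 8))))  # 4
try:
    shift_loss_shapes(FakeArray((4, 8)), FakeArray((3, 8)), FakeArray((4, 8)))
except TypeError as err:
    print(err)  # s_ref_rej: axis 'p' has size 3, expected 4
```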