
@jkbjh
Created June 24, 2025 10:41
jax cookbook

One-time differentiable function (locally linearized function):

import jax

def f_linearized(x):
    # Linearize f (assumed defined in the enclosing scope) at x and stop
    # gradients through the linearization point: f_lin is linear in its
    # input, so the second derivative of f_linearized is zero.
    _, f_lin = jax.lax.stop_gradient(jax.linearize(f, x))
    return f_lin(x)
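A minimal runnable sketch of the trick, using a toy f(x) = x**2 (an assumption, not from the gist). jax.linearize returns the JVP as a tree_util.Partial, so stop_gradient on its result detaches the residuals captured at the linearization point; differentiating the wrapper then treats the local Jacobian as a constant:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2  # toy function (assumption)

def f_linearized(x):
    # f_lin is t -> 2*x*t with the factor 2*x detached from autodiff.
    _, f_lin = jax.lax.stop_gradient(jax.linearize(f, x))
    return f_lin(x)

# At x = 3.0: f_lin(3.0) = 2*3*3 = 18, and the gradient treats the
# Jacobian 2*x as constant, giving 6.0 (rather than d(2x^2)/dx = 12).
print(f_linearized(3.0))
print(jax.grad(f_linearized)(3.0))
```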

Differentiate with respect to an internal variable ("perturbation trick"):

This example differentiates with respect to intermediate variables (s_next) inside a scan. To do that, we pass in a sequence of perturbation variables (zero_states) that are added to the state (s + z). They are set to 0 when calling the function, so the forward computation is unchanged, but differentiating with respect to zero_states yields the gradients with respect to the intermediate states s_next.

    from jax import lax

    def step_function(carry, a_z):
        a, z = a_z
        ctx, s = carry
        # z comes from zero_states, so s + z == s and the forward pass is
        # unchanged; z only exposes the intermediate state to autodiff.
        ctx_next, s_next, o_next, reward = f_step(ctx, s + z, a)
        return (ctx_next, s_next), (reward, V(o_next))

    (ctx_final, s_final), (rewards, values_rest) = lax.scan(
        step_function, (ctx, s), (actions, zero_states))
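A self-contained sketch of the scan above, with toy stand-ins for f_step, V, the context, and the initial state (all assumptions, chosen only to make it runnable): grad with respect to zero_states returns one gradient per step, each taken with respect to the state entering that step.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f_step(ctx, s, a):
    # Toy dynamics (assumption): linear state update, quadratic reward,
    # observation equal to the state.
    s_next = 0.5 * s + a
    return ctx, s_next, s_next, s_next ** 2  # ctx, s_next, o_next, reward

def V(o):
    return 3.0 * o  # toy value function (assumption)

def total_reward(zero_states, s0, actions):
    def step_function(carry, a_z):
        a, z = a_z
        ctx, s = carry
        ctx_next, s_next, o_next, reward = f_step(ctx, s + z, a)
        return (ctx_next, s_next), (reward, V(o_next))

    (_, s_final), (rewards, values_rest) = lax.scan(
        step_function, (jnp.zeros(()), s0), (actions, zero_states))
    return jnp.sum(rewards)

actions = jnp.array([1.0, 2.0])
zero_states = jnp.zeros_like(actions)  # perturbations, held at zero

# With s0 = 0: s1 = 1, s2 = 2.5, total reward = 1 + 6.25 = 7.25.
# Hand-computed state gradients: dL/dz0 = 0.5*(2*1 + 0.5*2*2.5) = 2.25,
# dL/dz1 = 0.5*2*2.5 = 2.5.
grads = jax.grad(total_reward)(zero_states, jnp.zeros(()), actions)
print(grads)
```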
        