def f_linearized(x):
    _, f_lin = jax.lax.stop_gradient(jax.linearize(f, x))
    return f_lin(x)
This example differentiates with respect to intermediate variables (s_next) inside the scan. To do this, we pass in a sequence of perturbation variables (zero_states), one per step, and add each one to the state (s + z). The perturbations are set to 0 when the function is called, so the computation is unchanged, but differentiating with respect to zero_states gives the derivative with respect to the intermediate states (s_next) at each step.
def step_function(carry, a_z):
    a, z = a_z
    ctx, s = carry
    ctx_next, s_next, o_next, reward = f_step(ctx, s + z, a)
    return (ctx_next, s_next), (reward, V(o_next))

(ctx_final, s_final), (rewards, values_rest) = lax.scan(
    step_function, (ctx, s), (actions, zero_states))
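The snippet above relies on f_step, V, ctx and actions, which are defined elsewhere. As a rough, self-contained sketch of the same perturbation trick, the version below swaps in a toy f_step (hypothetical, standing in for the real dynamics) and drops ctx and V. The summed-reward objective total_reward and the input values are likewise assumptions made only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy dynamics standing in for the real f_step:
# the state decays, is pushed by the action, and is penalised by its norm.
def f_step(s, a):
    s_next = 0.9 * s + a
    reward = -jnp.sum(s_next ** 2)
    return s_next, reward

# Assumed objective: the sum of rewards over the rollout.
def total_reward(zero_states, s0, actions):
    def step_function(s, a_z):
        a, z = a_z
        # z is zero at call time, so s + z == s and the rollout is unchanged.
        s_next, reward = f_step(s + z, a)
        return s_next, reward

    _, rewards = lax.scan(step_function, s0, (actions, zero_states))
    return jnp.sum(rewards)

s0 = jnp.array([1.0, -2.0])
actions = jnp.zeros((3, 2))
zero_states = jnp.zeros((3, 2))  # one perturbation per scan step

# Row t is the gradient of the total reward with respect to the state
# that enters step t of the scan.
state_grads = jax.grad(total_reward)(zero_states, s0, actions)
```

Row 0 of state_grads should agree with the ordinary gradient with respect to s0, since perturbing the first state is the same as perturbing the initial carry. The later rows are the gradients with respect to states that exist only inside the scan.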