Last active
August 13, 2019 22:41
-
-
Save pierrelux/059bfef496354cd8fa0ff4557db3b58b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax import jvp, grad | |
def f(x, y):
    """Example scalar function of two arguments: ``x + y**2``.

    Used as the demo target for the mixed-derivative helper below.
    """
    y_squared = y ** 2
    return x + y_squared
def freeze(f, argnum, val):
    """Bind one argument of a two-argument function to a fixed value.

    Args:
        f: callable taking exactly two positional arguments.
        argnum: which slot receives ``val``. When it equals 0, ``val`` is
            passed first and the returned callable's input second; for any
            other value the input is passed first and ``val`` second.
        val: the value to hold fixed.

    Returns:
        A one-argument callable ``_f(arg)`` that forwards to ``f``.
    """
    def _f(arg):
        # argnum is checked at call time, exactly as the list-building form did.
        if argnum == 0:
            return f(val, arg)
        return f(arg, val)

    return _f
def mixed_jvp(f, order, primals, tangents):
    """Push a tangent through the gradient of a two-argument function.

    Takes ``grad(f, order[0])``, holds argument ``order[0]`` fixed at
    ``primals[order[0]]``, and computes a JVP with respect to the remaining
    argument ``order[1]`` at ``primals[order[1]]``. With a unit tangent the
    tangent output is the mixed second partial
    d/d(arg order[1]) [ d f / d(arg order[0]) ] evaluated at ``primals``.

    Args:
        f: function of two positional (float) arguments.
        order: pair ``(i, j)``, either ``(0, 1)`` or ``(1, 0)``; ``f`` is
            differentiated first w.r.t. argument ``i``, then ``tangents`` is
            pushed through argument ``j``.
        primals: the two argument values at which to evaluate.
        tangents: one-element tuple holding the tangent for argument ``j``.

    Returns:
        ``(primal_out, tangent_out)`` as returned by ``jax.jvp``.

    Raises:
        ValueError: if ``order`` does not name two distinct arguments 0 and 1.
    """
    if {order[0], order[1]} != {0, 1}:
        # freeze() only handles two-argument functions, and i == j would
        # silently compute the wrong derivative, so reject it explicitly.
        raise ValueError(
            "order must be (0, 1) or (1, 0) for a two-argument function"
        )
    grad_fn = grad(f, order[0])
    # Hold the argument we already differentiated with respect to at its own
    # primal value; the free input of the frozen function is then argument
    # order[1]. Passing argnum=order[1] here (with val=primals[order[0]])
    # would put the frozen value in the wrong slot and evaluate the gradient
    # at swapped coordinates, differentiating w.r.t. the wrong argument.
    frozen_func = freeze(grad_fn, argnum=order[0], val=primals[order[0]])
    return jvp(frozen_func, (primals[order[1]],), tangents)
# Example invocation: differentiate f with respect to argument 0, then push a
# unit tangent through argument 1, with primal values (2.0, 3.0). Because
# df/dx == 1 for this f, the tangent output is expected to be 0.
# NOTE(review): the result is discarded; presumably this line was evaluated
# interactively so the value is echoed — TODO confirm.
mixed_jvp(
    f,
    (0, 1),
    (2.0, 3.0),
    (1.0,),
)
gehring
commented
Aug 13, 2019
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment