Last active April 13, 2021 20:27
stochastic_volatility.ipynb
twiecki (author) commented on Apr 12, 2021, via email:
How do you mean?
On Mon, Apr 12, 2021, 21:01 Brandon T. Willard wrote:
I just noticed that this example isn't optimizing the FunctionGraph.
Doing something like the following will optimize the FunctionGraph in roughly the same way that aesara.function does:
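import aesara
from aesara.compile.mode import FAST_RUN
# `model` is the PyMC model from the notebook
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
# Apply FAST_RUN's graph rewrites to fgraph in place
_ = FAST_RUN.optimizer.optimize(fgraph)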
Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the Distribution.logp implementations, i.e. no common-subexpression elimination (CSE), no fusions, no in-place operations, etc.
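To make the "no CSE" point concrete, here is a toy, pure-Python sketch of common-subexpression elimination by hash-consing. The `Node` class and both helper functions are hypothetical, and this is not how Aesara implements its rewrites; it only shows the effect: in `exp(x) + exp(x)`, the two identical `exp(x)` nodes collapse into one shared node, so that value would be computed once.

```python
from typing import Dict, List, Set, Tuple


class Node:
    """Toy expression node: an op name applied to input nodes.
    Leaves (graph inputs) are nodes with no inputs, identified by name."""

    def __init__(self, op: str, *inputs: "Node") -> None:
        self.op = op
        self.inputs = inputs


def merge_common_subexpressions(out: Node) -> Node:
    """Rebuild the graph ending at `out` so that structurally identical
    subexpressions become one shared node (hash-consing)."""
    table: Dict[Tuple[str, Tuple[int, ...]], Node] = {}  # (op, input ids) -> node
    memo: Dict[int, Node] = {}  # id(original node) -> canonical node

    def visit(node: Node) -> Node:
        if id(node) in memo:
            return memo[id(node)]
        new_inputs = tuple(visit(i) for i in node.inputs)
        # Inputs are already canonical, so their identities form a valid key
        key = (node.op, tuple(id(i) for i in new_inputs))
        if key not in table:
            table[key] = Node(node.op, *new_inputs)
        memo[id(node)] = table[key]
        return table[key]

    return visit(out)


def count_nodes(out: Node) -> int:
    """Count the distinct node objects reachable from `out`."""
    seen: Set[int] = set()
    stack: List[Node] = [out]
    while stack:
        node = stack.pop()
        if id(node) not in seen:
            seen.add(id(node))
            stack.extend(node.inputs)
    return len(seen)


# exp(x) + exp(x), built with two separate exp nodes (no CSE)
x = Node("x")
expr = Node("add", Node("exp", x), Node("exp", x))
merged = merge_common_subexpressions(expr)
print(count_nodes(expr), count_nodes(merged))  # 4 3
```

Fusion and in-place rewrites are separate optimizations and are not modeled in this sketch.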
I suppose pm.sample() already does this?
This looks like something we need to update in PyMC3, as well.
Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook):
# With graph optimizations: reuse the already-optimized graph of the compiled logp function
fgraph = model.logp.f.maker.fgraph
...
%timeit fn2(*init_state).block_until_ready()
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# Without graph optimizations: a fresh, unoptimized FunctionGraph of logpt
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
%timeit fn2(*init_state).block_until_ready()
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
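Outside IPython, where the timeit magic is unavailable, roughly the same measurement can be sketched with the standard-library `timeit` module. The helper `time_blocking` is hypothetical. It calls `block_until_ready()` on the result when that method exists, because JAX dispatches computations asynchronously and timing without it would mostly measure dispatch. Unlike `%timeit`, it does not choose the loop count automatically.

```python
import statistics
import timeit
from typing import Any, Callable, Sequence, Tuple


def time_blocking(fn: Callable[..., Any], args: Sequence[Any],
                  number: int = 1000, repeat: int = 7) -> Tuple[float, float]:
    """Return (mean, std. dev.) of seconds per call of fn(*args) over
    `repeat` runs of `number` calls each, like IPython's %timeit summary."""

    def call() -> None:
        out = fn(*args)
        # JAX dispatches work asynchronously, so wait for the result;
        # otherwise we would mostly be timing the dispatch. This only
        # handles a single array output, as in the notebook's fn2.
        block = getattr(out, "block_until_ready", None)
        if block is not None:
            block()

    call()  # warm-up, e.g. so a jitted function is compiled before timing
    totals = timeit.repeat(call, number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    std = statistics.stdev(per_call) if len(per_call) > 1 else 0.0
    return statistics.mean(per_call), std


# Stand-in workload so the sketch runs without JAX
mean, std = time_blocking(sorted, ([3, 1, 2],), number=1000, repeat=7)
print(f"{mean * 1e6:.2f} µs ± {std * 1e6:.2f} µs per loop")
```

With the notebook's objects, the call would be `time_blocking(fn2, init_state)`; the `sorted` call above is only a placeholder workload.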