@twiecki
Last active April 13, 2021 20:27
stochastic_volatility.ipynb

twiecki commented Apr 12, 2021 via email


brandonwillard commented Apr 12, 2021

Doing something like the following will optimize the FunctionGraph in roughly the same way that aesara.function does:

import aesara
from aesara.compile.mode import FAST_RUN

# Build the log-probability graph and apply the same rewrites that aesara.function would.
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
_ = FAST_RUN.optimizer.optimize(fgraph)

Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the Distribution.logp implementations (i.e. no CSE, fusions, in-place operations, etc.).
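For illustration (not from the original comment), here is a minimal self-contained sketch of what that optimization pass does to a toy graph: the duplicated exp node is merged and the elemwise ops are fused, the same kind of CSE/fusion rewriting described above.

import aesara.tensor as at
from aesara.compile.mode import FAST_RUN
from aesara.graph.fg import FunctionGraph
from aesara.printing import debugprint

x = at.vector("x")
y = at.exp(x) + at.exp(x)  # deliberately redundant subexpression

fgraph = FunctionGraph([x], [y])
debugprint(fgraph)  # unoptimized: two separate Exp nodes

_ = FAST_RUN.optimizer.optimize(fgraph)
debugprint(fgraph)  # optimized: the duplicate Exp is merged and the elemwise ops fused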


twiecki commented Apr 13, 2021

I suppose pm.sample() already does this?

@brandonwillard

This looks like something we need to update in PyMC3, as well.

Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook):

# Optimized graph: taken from the aesara.function-compiled logp
fgraph = model.logp.f.maker.fgraph
...
%timeit fn2(*init_state).block_until_ready()
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# Unoptimized graph: built directly from the model, no rewrites applied
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
%timeit fn2(*init_state).block_until_ready()
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
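For context (not part of the original comment), the elided lines presumably convert each FunctionGraph into a JAX callable the way pm.sampling_jax did at the time. A rough sketch, assuming a PyMC3 model and an init_state list of starting values for model.free_RVs are in scope, and assuming jax_funcify can be imported from aesara.link.jax.jax_dispatch (both the import path and the return convention differ across Aesara versions):

import aesara
from aesara.compile.mode import FAST_RUN
from aesara.link.jax.jax_dispatch import jax_funcify  # assumed import path; it has moved between versions

fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
_ = FAST_RUN.optimizer.optimize(fgraph)  # leave this line out to time the unoptimized graph

fn2 = jax_funcify(fgraph)[0]  # assumes jax_funcify returns one JAX callable per output
fn2(*init_state).block_until_ready()  # time with %timeit, as above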
