Created
July 19, 2019 17:59
-
-
Save pierrelux/78b9a27aa8d604911a00b0fb00b5f7b6 to your computer and use it in GitHub Desktop.
VJP through a specific lfilter performing discounting on an array of scalar elements (rewards).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import autograd.numpy as np | |
from scipy.signal import lfilter | |
from autograd.extend import primitive, defvjp | |
@primitive
def accumulate_discounted(rewards, discount=1.):
    """Compute the discounted sum of future rewards along the last axis.

    For every position ``t`` the result is
    ``sum_{k >= t} discount ** (k - t) * rewards[..., k]``, i.e. a reversed
    cumulative sum in which each later element is discounted once more.

    Args:
        rewards (np.ndarray): Array of rewards. The last axis is treated as
            time; any leading axes are processed independently.
        discount (float): Scalar discount factor.

    Returns:
        np.ndarray: Array with the same shape as `rewards`, where each element
            is the discounted sum of rewards from that position onward.
    """
    rewards = np.asarray(rewards)
    # The recursion y[t] = r[t] + discount * y[t + 1] runs backwards in time.
    # `lfilter` only implements the causal form y[n] = x[n] + discount * y[n - 1],
    # so feed it the time-reversed sequence and flip the result back.
    # Reversing and filtering both use the last axis, which keeps 1D inputs 1D
    # instead of returning a (1, n) array.
    reversed_returns = lfilter([1], [1, -discount], rewards[..., ::-1], axis=-1)
    return reversed_returns[..., ::-1]
def accumulate_discounted_vjp(ans, rewards, discount=1.):
    r"""Build the vector-Jacobian product wrt. the `rewards` argument.

    The forward pass is linear in `rewards`, with
    $y_t = \sum_{k \geq t} \gamma^{k-t} r_k$. Pulling a cotangent
    $\bar{x}$ back through it therefore gives
    $g_n = \sum_{i=0}^{n} \bar{x}_i \gamma^{n-i}$, a causal convolution
    that `lfilter` evaluates directly (no time reversal needed).

    Args:
        ans: Output of the forward pass (unused).
        rewards: Input of the forward pass (unused).
        discount (float): Scalar discount factor used in the forward pass.

    Returns:
        callable: Function mapping the output cotangent `xbar` to the
            cotangent of `rewards`, with the same shape as `xbar`.
    """
    # The Jacobian depends on `discount` only, so neither the forward output
    # nor the forward input is needed.
    del ans, rewards

    def _vjp(xbar):
        # Filter along the last (time) axis, the same axis as the forward pass.
        return lfilter([1], [1, -discount], xbar, axis=-1)

    return _vjp
# Register the hand-written pullback so autograd can differentiate through
# `accumulate_discounted` with respect to its first argument (`rewards`).
defvjp(accumulate_discounted, accumulate_discounted_vjp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment