Last active
April 24, 2021 06:59
-
-
Save pierrelux/5edf0fcb845e7d8213888a925c0d58e7 to your computer and use it in GitHub Desktop.
Four ways to compute discounted returns
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax
import jax.numpy as jnp
import numpy as onp
from scipy.signal import lfilter
def discount_lfilter(rewards, discount):
    """Return discounted returns using a first-order IIR filter.

    The return G_t = r_t + discount * G_{t+1} is the recurrence
    y[n] = x[n] + discount * y[n-1] (numerator [1], denominator
    [1, -discount]) applied to the time-reversed rewards. The filtered
    sequence is reversed back into chronological order.
    """
    backwards = rewards[::-1]
    numerator = [1]
    denominator = [1, -discount]
    accumulated = lfilter(numerator, denominator, backwards)
    return accumulated[::-1]
def discount_correlate(rewards, discount):
    """Return discounted returns by correlating rewards with discount powers.

    Lag k of the correlation between ``rewards`` and ``discount ** j`` is
    sum_j rewards[j + k] * discount ** j, which is exactly the return G_k.
    """
    horizon = rewards.shape[0]
    kernel = onp.power(discount, onp.arange(horizon))
    lags = onp.correlate(rewards, kernel, mode='full')
    # In 'full' mode lag k sits at index k + horizon - 1, so the tail that
    # starts at index horizon - 1 holds lags 0..horizon-1.
    return lags[horizon - 1:]
def discount_convolve(rewards, discount):
    """Return discounted returns via convolution of the reversed rewards.

    Convolving the time-reversed rewards with the kernel ``discount ** k``
    yields, at output index m, the return of step ``len(rewards) - 1 - m``.
    Reading the first ``len(rewards)`` outputs backwards restores time order.
    """
    horizon = rewards.shape[0]
    weights = onp.power(discount, onp.arange(horizon))
    backward_returns = onp.convolve(rewards[::-1], weights, 'full')
    # Step from index horizon-1 down to 0; same as [:horizon][::-1].
    return backward_returns[horizon - 1::-1]
def convolve1D(x, y):
    """Return the full 1-D linear convolution of ``x`` and ``y`` using JAX.

    Matches ``numpy.convolve(x, y, 'full')``: the result has
    ``len(x) + len(y) - 1`` entries with ``out[i] = sum_k x[k] * y[i - k]``.
    Based on https://github.com/google/jax/issues/1561.

    Args:
        x: 1-D array-like signal.
        y: 1-D array-like kernel.

    Returns:
        A 1-D ``jax.numpy`` array of length ``len(x) + len(y) - 1``.

    Raises:
        ValueError: if either input is empty.
    """
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # jax.lax.conv_general_dilated requires both operands to share a dtype,
    # so e.g. integer signals and float kernels are promoted together.
    dtype = jnp.result_type(x, y)
    x = x.astype(dtype)
    y = y.astype(dtype)
    if x.size == 0 or y.size == 0:
        raise ValueError("convolve1D: inputs cannot be empty")
    # lax convolutions take (batch, channels, length) operands.
    lhs = jnp.reshape(x, (1, 1, x.shape[0]))
    # lax computes a cross-correlation; flipping the kernel along the
    # length axis turns it into a true convolution.
    rhs = jnp.flip(jnp.reshape(y, (1, 1, y.shape[0])), 2)
    # Padding by len(y) - 1 on each side gives output length
    # len(x) + len(y) - 1 ("full" mode) for any pair of input lengths.
    pad = y.shape[0] - 1
    out = jax.lax.conv_general_dilated(lhs, rhs, [1], [(pad, pad)])
    return jnp.ravel(out)
def discount_convolve_jax(rewards, discount):
    """Return discounted returns via a full convolution evaluated with JAX.

    Uses the same construction as the NumPy convolution approach. The
    time-reversed rewards are convolved with ``discount ** k``, and the first
    ``len(rewards)`` outputs are read backwards. The convolution is computed
    by ``convolve1D``.

    Args:
        rewards: 1-D array-like of per-step rewards (any numeric dtype).
        discount: scalar discount factor.

    Returns:
        A 1-D ``jax.numpy`` array whose entry t is
        ``sum_k discount ** k * rewards[t + k]``.
    """
    rewards = jnp.asarray(rewards)
    nsamples = rewards.shape[0]
    discount_sequence = jnp.power(discount, jnp.arange(nsamples))
    # jax.lax convolutions reject mismatched dtypes. Integer rewards would
    # otherwise meet a floating discount sequence, so cast both to a
    # common dtype first.
    dtype = jnp.result_type(rewards, discount_sequence)
    rewards = rewards.astype(dtype)
    discount_sequence = discount_sequence.astype(dtype)
    filtered_rewards = convolve1D(rewards[::-1], discount_sequence)
    return filtered_rewards[:nsamples][::-1]
if __name__ == "__main__":
    # Run every implementation on one short reward sequence. Each printed
    # line should show the same returns, up to floating-point precision.
    discount = 0.9
    rewards = onp.array([1, 2, 3, 4], dtype=float)
    implementations = (
        discount_lfilter,
        discount_correlate,
        discount_convolve,
        discount_convolve_jax,
    )
    for compute_returns in implementations:
        print(compute_returns(rewards, discount))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment