PPO Buffer
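A fixed-size buffer for storing trajectories experienced by a PPO agent interacting with its environment. When a trajectory ends, the buffer computes advantage estimates using Generalized Advantage Estimation (GAE-Lambda) and rewards-to-go, the latter serving as targets for learning the value function. A usage sketch follows the code.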
import numpy as np
import scipy.signal


def discount_cumsum(x, discount):
    # Compute discounted cumulative sums of a vector by running an IIR
    # filter over the reversed input, then reversing the result back.
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
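
# Worked example of discount_cumsum (illustrative values, not from the gist):
#   discount_cumsum([x0, x1, x2], d) = [x0 + d*x1 + d^2*x2, x1 + d*x2, x2]
#   discount_cumsum(np.array([1., 1., 1.]), 0.5) -> array([1.75, 1.5, 1.])
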
def combined_shape(length, shape=None):
    # Build the shape of a buffer holding `length` items of shape `shape`.
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
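
# Examples of combined_shape (illustrative):
#   combined_shape(10)         -> (10,)
#   combined_shape(10, 3)      -> (10, 3)
#   combined_shape(10, (3, 4)) -> (10, 3, 4)
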
class PPOBuffer:
    """Stores trajectories for a PPO agent, and computes GAE-Lambda
    advantage estimates and rewards-to-go at the end of each path."""

    def __init__(self, ob_space, ac_space, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros(combined_shape(size, ob_space.shape), dtype=ob_space.dtype)
        self.act_buf = np.zeros(combined_shape(size, ac_space.shape), dtype=ac_space.dtype)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size
    def store(self, obs, act, rew, val, logp):
        # Append one timestep of agent-environment interaction to the buffer.
        assert self.ptr < self.max_size  # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1
    def finish_path(self, last_val=0):
        # Call at the end of a trajectory, or when one gets cut off by an
        # epoch ending. `last_val` should be 0 if the trajectory actually
        # ended (the agent reached a terminal state), and otherwise the
        # value estimate V(s_T) for the last state, to bootstrap the returns.
        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        # GAE-Lambda advantage calculation:
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t = sum_{l >= 0} (gamma * lam)^l * delta_{t+l}
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        # Rewards-to-go, to be targets for the value function.
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        self.path_start_idx = self.ptr
    def get(self):
        # Call at the end of an epoch to get all of the data from the
        # buffer, with advantages normalized to mean zero and std one,
        # and to reset the pointers for the next epoch.
        assert self.ptr == self.max_size  # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
        # The next three lines implement the advantage normalization trick.
        adv_mean = np.mean(self.adv_buf)
        adv_std = np.std(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        return [self.obs_buf, self.act_buf, self.adv_buf,
                self.ret_buf, self.logp_buf]
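
A minimal usage sketch, under assumptions not in the gist: a classic Gym-style environment (reset returns an observation; step returns observation, reward, done, info), and hypothetical policy_step and value_estimate helpers standing in for an actual actor-critic model.

import gym

env = gym.make("CartPole-v1")
steps_per_epoch = 4000
buf = PPOBuffer(env.observation_space, env.action_space, steps_per_epoch)

obs = env.reset()
for t in range(steps_per_epoch):
    # policy_step is hypothetical: it should return a sampled action, the
    # value estimate V(s_t), and the log-probability of the action.
    act, val, logp = policy_step(obs)
    next_obs, rew, done, _ = env.step(act)
    buf.store(obs, act, rew, val, logp)
    obs = next_obs
    if done or t == steps_per_epoch - 1:
        # Bootstrap with V(s_T) if the path was cut off by the epoch ending;
        # use 0 if the episode truly ended. value_estimate is hypothetical.
        last_val = 0 if done else value_estimate(obs)
        buf.finish_path(last_val)
        obs = env.reset()

obs_b, act_b, adv_b, ret_b, logp_b = buf.get()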