Last active: December 1, 2019 22:21
-
-
Save davidADSP/277a4e4d4da7e6d8953bc055159999fd to your computer and use it in GitHub Desktop.
make_target (https://arxiv.org/src/1911.08265v1/anc/pseudocode.py)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Game(object):
    """A single episode of interaction with the environment."""

    def __init__(self, action_space_size: int, discount: float):
        self.environment = Environment()  # Game specific environment.
        self.history = []
        self.rewards = []
        self.child_visits = []
        self.root_values = []
        self.action_space_size = action_space_size
        self.discount = discount

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int,
                    to_play: "Player"):
        """Build (value, reward, policy) targets for a run of consecutive positions.

        One target is produced for every ``current_index`` in
        ``state_index .. state_index + num_unroll_steps`` (inclusive):

        * value: the n-step return. This is the discounted sum of
          ``self.rewards[current_index:current_index + td_steps]``, plus
          ``self.root_values[current_index + td_steps] * discount**td_steps``
          when that bootstrap index is still inside the stored root values.
        * reward: the reward of the transition *into* ``current_index``, i.e.
          ``self.rewards[current_index - 1``]. It is 0 when
          ``current_index == 0`` or when no such reward is stored.
        * policy: ``self.child_visits[current_index]``.

        Positions at or past ``len(self.root_values)`` are treated as absorbing
        states. They get value 0 and an empty policy ``[]``, but still carry the
        reward of the transition into them if one is stored.

        Args:
            state_index: index of the first position to build a target for.
            num_unroll_steps: number of additional consecutive positions.
            td_steps: n of the n-step return (bootstrap horizon).
            to_play: not used by this implementation. NOTE(review): for
                two-player games the value target presumably needs to be taken
                from ``to_play``'s perspective -- confirm.

        Returns:
            A list of ``num_unroll_steps + 1`` ``(value, reward, policy)`` tuples.
        """
        # The value target is the discounted root value of the search tree N steps
        # into the future, plus the discounted sum of all rewards until then.
        targets = []
        for current_index in range(state_index, state_index + num_unroll_steps + 1):
            bootstrap_index = current_index + td_steps
            if bootstrap_index < len(self.root_values):
                value = self.root_values[bootstrap_index] * self.discount**td_steps
            else:
                value = 0

            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += reward * self.discount**i  # pytype: disable=unsupported-operands

            # The sum above treats self.rewards[current_index] as the first reward
            # received *after* current_index (it is weighted by discount**0). The
            # reward received on arriving at current_index is therefore the
            # previous entry. Using self.rewards[current_index] here would be off
            # by one. Presumably this target is matched against the reward
            # predicted after unrolling to current_index -- confirm against the
            # training loop.
            if 0 < current_index <= len(self.rewards):
                last_reward = self.rewards[current_index - 1]
            else:
                last_reward = 0

            if current_index < len(self.root_values):
                targets.append((value, last_reward, self.child_visits[current_index]))
            else:
                # States past the end of games are treated as absorbing states.
                # The first of them still receives the final transition's reward.
                # NOTE(review): the empty policy [] must be handled (masked) by
                # the consumer of these targets.
                targets.append((0, last_reward, []))
        return targets
... |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment