This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env):
    # ... (other members elided in the original excerpt)

    def step(self, action):
        """Advance the underlying environment by the agent's action.

        Renders first, then delegates to the base env's step and logs the
        transition at debug level.

        NOTE(review): this excerpt is truncated — the full method presumably
        goes on to play the opponents' turns and return the
        (observation, reward, done, info) tuple; confirm against the
        complete source.
        """
        self.render()
        observation, reward, done, _ = super(SelfPlayEnv, self).step(action)
        logger.debug(f'Action played by agent: {action}')
        logger.debug(f'Rewards: {reward}')
        logger.debug(f'Done: {done}')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env):
    # ... (other members elided in the original excerpt)

    def continue_game(self):
        """Let the scripted opponents act until it is the agent's turn again.

        Each opponent move is chosen stochastically (choose_best_action=False)
        and without masking invalid actions, then applied via the base env's
        step; each transition is logged at debug level.

        NOTE(review): this excerpt is truncated — the full method presumably
        also breaks on `done` and returns the last transition; confirm
        against the complete source.
        """
        while self.current_player_num != self.agent_player_num:
            self.render()
            action = self.current_agent.choose_action(self, choose_best_action = False, mask_invalid_actions = False)
            observation, reward, done, _ = super(SelfPlayEnv, self).step(action)
            logger.debug(f'Rewards: {reward}')
            logger.debug(f'Done: {done}')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env):
    # ... (other members elided in the original excerpt)

    def reset(self):
        """Reset the base env, rebuild the opponent pool, and fast-forward
        through any opponent turns so the agent always moves first from its
        own perspective.

        NOTE(review): the visible excerpt does not return an initial
        observation as gym's reset contract expects — confirm whether the
        full source returns one after continue_game().
        """
        super(SelfPlayEnv, self).reset()
        self.setup_opponents()
        if self.current_player_num != self.agent_player_num:
            self.continue_game()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal Stable Baselines PPO1 setup (excerpt): a PPO1 model with an MLP
# policy on the Pendulum-v0 gym environment. EvalCallback is imported for
# the separate evaluation env the original snippet builds just below this
# excerpt (truncated here).
import gym
from stable_baselines import PPO1
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.callbacks import EvalCallback

env = gym.make('Pendulum-v0')
model = PPO1(MlpPolicy, env)
# Separate evaluation env
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx

# SAMPLE DATA FORMAT
# nodes = [('tensorflow', {'count': 13}),
#          ('pytorch', {'count': 6}),
#          ('keras', {'count': 6}),
#          ('scikit', {'count': 2}),
#          ('opencv', {'count': 5}),
#          ('spark', {'count': 13}), ...]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Replay buffer of completed games used to draw training batches (excerpt;
# trailing "| |" markers and lost indentation are web-extraction artifacts,
# and the snippet is truncated mid-return).
class ReplayBuffer(object): | |
# Capacity and batch size come from the MuZeroConfig passed in.
def __init__(self, config: MuZeroConfig): | |
self.window_size = config.window_size | |
self.batch_size = config.batch_size | |
# Holds completed games; sample_game/sample_position are defined elsewhere.
self.buffer = [] | |
# Draw batch_size (game, position) pairs and build per-sample training data.
def sample_batch(self, num_unroll_steps: int, td_steps: int): | |
games = [self.sample_game() for _ in range(self.batch_size)] | |
game_pos = [(g, self.sample_position(g)) for g in games] | |
# NOTE(review): the return statement continues past this excerpt (truncated).
return [(g.make_image(i), g.history[i:i + num_unroll_steps],
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Game record for self-play (excerpt; trailing "| |" markers and lost
# indentation are web-extraction artifacts, and __init__ is truncated).
class Game(object): | |
"""A single episode of interaction with the environment.""" | |
def __init__(self, action_space_size: int, discount: float): | |
self.environment = Environment() # Game specific environment. | |
# Per-step records appended as the episode is played.
self.history = [] | |
self.rewards = [] | |
# Search statistics saved at each move (used later as training targets).
self.child_visits = [] | |
self.root_values = [] | |
self.action_space_size = action_space_size |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One training update over a batch (excerpt; trailing "| |" markers and lost
# indentation are web-extraction artifacts, and the loop body is truncated).
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch, | |
weight_decay: float): | |
# Accumulated across every sample and unroll step in the batch.
loss = 0 | |
for image, actions, targets in batch: | |
# Initial step, from the real observation. | |
value, reward, policy_logits, hidden_state = network.initial_inference( | |
image) | |
# NOTE(review): the leading 1.0 is presumably a per-step loss/gradient
# scale — confirm against the complete source.
predictions = [(1.0, value, reward, policy_logits)] | |
# Recurrent steps, from action and previous hidden state. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def select_action(config: MuZeroConfig, num_moves: int, node: Node,
                  network: Network):
    """Sample an action at the root, weighted by child visit counts.

    The visit-count distribution is softened by a temperature that the
    config schedules from the current move number and the network's
    training progress, then one action is drawn via softmax_sample.
    """
    visit_counts = [
        (child.visit_count, action) for action, child in node.children.items()
    ]
    # Temperature controls exploration: typically high early, low later.
    t = config.visit_softmax_temperature_fn(
        num_moves=num_moves, training_steps=network.training_steps())
    _, action = softmax_sample(visit_counts, t)
    return action
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
    """Fold a leaf evaluation back into every node on the search path.

    Each node accumulates the value from the perspective of `to_play`
    (sign-flipped for the opponent), bumps its visit count, reports its
    mean value to min_max_stats for normalisation, and the value is
    discounted and combined with the node's reward for the next node.

    NOTE(review): nodes are visited in the order `search_path` is given —
    confirm the caller passes the path in the direction the discounting
    assumes.
    """
    for node in search_path:
        # Value is relative to `to_play`; the opponent sees its negation.
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        # One-step bootstrap: parent's value includes this node's reward.
        value = node.reward + discount * value
Newer | Older