#!/usr/bin/env python3
import math
import random
from collections import namedtuple, deque
from dataclasses import dataclass, field
from itertools import count

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from env import TicTacToeEnv

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))
class ReplayMemory(object):
    """Fixed-size buffer of past transitions for experience replay."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
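
# Illustrative usage, as a sketch of how the Agent below feeds the buffer:
# states are tensors of shape (1, n_observations), actions are (1, 1) long
# tensors, and rewards are (1,) tensors.
#   memory = ReplayMemory(10000)
#   memory.push(state, action, next_state, reward)
#   batch = memory.sample(128)  # list of Transition namedtuples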
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, n_observations * 2)
        self.layer2 = nn.Linear(n_observations * 2, n_observations * 2)
        self.layer3 = nn.Linear(n_observations * 2, n_actions)

    # Called with either a single state to determine the next action, or with a
    # batch during optimization. Returns a tensor of Q-values, one per action.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
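
# Illustrative shape check (assumes a flattened 3x3 board, i.e. 9 observations
# and 9 actions; the real sizes come from TicTacToeEnv's spaces):
#   net = DQN(9, 9)
#   q_values = net(torch.zeros(1, 9))  # shape (1, 9): one Q-value per cell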
# Exploration-rate schedules.
def epsilon(steps_done, EPS_START, EPS_END, DECAY_RATE):
    """Exponentially decay epsilon from EPS_START toward EPS_END."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / DECAY_RATE)


def hybrid_exploration(steps_done, EPS_START, EPS_INTER, EPS_END, DECAY_STEPS, T_MAX, PEAK_DECAY_RATE):
    """Exponential decay followed by cosine annealing with shrinking peaks.

    PEAK_DECAY_RATE is accepted but not used; the peak shrinkage is
    controlled by the hard-coded 1,000,000-step horizon below.
    """
    DECAY_RATE = 800000  # decay rate for the initial exponential phase
    if steps_done < DECAY_STEPS:
        # Exponential decay phase
        return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / DECAY_RATE)
    else:
        # Amplitude for cosine annealing that decreases over steps
        adjusted_max = EPS_INTER - (EPS_INTER - EPS_END) * ((steps_done - DECAY_STEPS) / 1000000)
        current_amplitude = max(adjusted_max - EPS_END, 0)
        # Cosine annealing phase with decreasing peaks (one cycle every T_MAX steps)
        return EPS_END + 0.5 * current_amplitude * (1 + np.cos(np.pi * ((steps_done - DECAY_STEPS) % T_MAX) / T_MAX))


# Plotting the schedule (uses the parameter values defined in Agent.select_action):
# steps = np.arange(0, 100_000)  # e.g. 0 to 100,000 steps
# epsilons = [hybrid_exploration(step, EPS_START, EPS_INTER, EPS_END, DECAY_STEPS, T_MAX, PEAK_DECAY_RATE) for step in steps]
# plt.figure(figsize=(10, 5))
# plt.plot(steps, epsilons, label='Epsilon')
# plt.title('Cosine Annealing for Exploration Rate')
# plt.xlabel('Steps Done')
# plt.ylabel('Epsilon')
# plt.grid(True)
# plt.legend()
# plt.show()
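
# Quick sanity check of the exponential schedule actually used in
# Agent.select_action below (DECAY_RATE = 80000); values are approximate:
#   epsilon(0, 0.99, 0.1, 80000)       ~= 0.99
#   epsilon(80000, 0.99, 0.1, 80000)   ~= 0.43   # EPS_END + (EPS_START - EPS_END) / e
#   epsilon(240000, 0.99, 0.1, 80000)  ~= 0.14   # EPS_END + (EPS_START - EPS_END) / e**3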
@dataclass
class Agent:
    env: gym.Env
    reward_history: list = field(default_factory=list)
    BATCH_SIZE: int = 1000  # number of transitions sampled from the replay buffer
    GAMMA: float = 0.99  # discount factor
    EPS_START: float = 0.99  # initial value of epsilon
    EPS_END: float = 0.1  # final value of epsilon
    TAU: float = 0.005  # update rate of the target network
    LR: float = 1e-4  # learning rate of the ``AdamW`` optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # default_factory so that each Agent gets its own replay buffer
    memory: ReplayMemory = field(default_factory=lambda: ReplayMemory(10000))
    steps_done: int = 0

    def __post_init__(self):
        self.policy_net: DQN = DQN(self.env.observation_space.shape[0], self.env.action_space.n).to(
            self.device
        )
        self.target_net: DQN = DQN(self.env.observation_space.shape[0], self.env.action_space.n).to(
            self.device
        )
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.LR, amsgrad=True)

    def select_action(self, state):
        sample = random.random()
        # Exploration-schedule parameters
        EPS_START = self.EPS_START  # initial exploration rate
        EPS_INTER = self.EPS_START  # exploration rate at the transition point
        EPS_END = self.EPS_END  # minimum final exploration rate
        DECAY_STEPS = 25000  # number of steps in the initial slow decay phase
        T_MAX = 50000  # period of each cosine annealing cycle
        TOTAL_STEPS = 100000  # total steps to plot
        PEAK_DECAY_RATE = 0.05  # rate at which the peaks decrease
        # Explore more in the beginning and exploit more towards the end.
        # eps_threshold = hybrid_exploration(self.steps_done, EPS_START, EPS_INTER, EPS_END, DECAY_STEPS, T_MAX, PEAK_DECAY_RATE)
        eps_threshold = epsilon(self.steps_done, EPS_START, EPS_END, 80000)
        self.steps_done += 1
        if sample > eps_threshold:
            return self.use_policy(state)
        else:
            return torch.tensor([[random.randrange(self.env.action_space.n)]], device=self.device, dtype=torch.long)

    def use_policy(self, state):
        # as_tensor accepts either a raw observation or an already-converted tensor
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        with torch.no_grad():
            return self.policy_net(state).max(1)[1].view(1, 1)
    def optimize_model(self):
        if len(self.memory) < self.BATCH_SIZE:
            return
        transitions = self.memory.sample(self.BATCH_SIZE)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts a batch-array of Transitions
        # to a Transition of batch-arrays.
        batch = Transition(*zip(*transitions))
        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which the episode ended)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of the actions taken. These are the actions which would've
        # been taken for each batch state according to policy_net.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, so we have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.BATCH_SIZE, device=self.device)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch
        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 100)
        self.optimizer.step()
    def update_weights(self):
        # Soft update of the target network's weights:
        # θ′ ← τ θ + (1 − τ) θ′
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * self.TAU + target_net_state_dict[key] * (1 - self.TAU)
        self.target_net.load_state_dict(target_net_state_dict)

    def feed_state(self, state):
        state = torch.tensor(state, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        action = self.select_action(state)
        # Under the gymnasium API, *rest captures (truncated, info)
        observation, reward, terminated, *rest = self.env.step(action.item())
        reward = torch.tensor([reward], device=self.device)
        self.reward_history.append(reward.item())
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        # Store the transition in memory
        self.memory.push(state, action, next_state, reward)
        # Perform one step of the optimization (on the policy network)
        self.optimize_model()
        # Soft-update the target network
        self.update_weights()
        return observation, reward, terminated, *rest

    def save(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path):
        self.policy_net.load_state_dict(torch.load(path))
        self.policy_net.eval()
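
# --- Illustrative training loop: a minimal sketch assuming TicTacToeEnv
# follows the standard gymnasium API, i.e. reset() returns (observation, info)
# and step() returns (observation, reward, terminated, truncated, info).
# The episode count and save path are arbitrary choices.
if __name__ == "__main__":
    env = TicTacToeEnv()
    agent = Agent(env)
    for episode in range(1000):
        observation, info = env.reset()
        for _ in count():
            observation, reward, terminated, truncated, info = agent.feed_state(observation)
            if terminated or truncated:
                break
    agent.save("policy_net.pt")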