#!/usr/bin/env python3
import math
import random
from collections import namedtuple, deque
from dataclasses import dataclass, field

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
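
# Usage sketch (illustrative only): with capacity 3, the deque drops the oldest
# transition once a fourth one is pushed, and sample() draws uniformly at random.
#   buf = ReplayMemory(3)
#   for i in range(4):
#       buf.push(i, i, i, float(i))   # placeholder values instead of real tensors
#   batch = buf.sample(2)             # -> list of 2 Transition namedtuples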


class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, n_observations * 2)
        self.layer2 = nn.Linear(n_observations * 2, n_observations * 2)
        self.layer3 = nn.Linear(n_observations * 2, n_actions)

    # Called with either a single state to pick the next action, or a batch of
    # states during optimization. Returns one Q-value per action.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
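
# Shape sketch (assumes the Breakout-ram setup used by the driver script below:
# a flat 128-byte RAM observation and 4 discrete actions):
#   net = DQN(n_observations=128, n_actions=4)
#   q = net(torch.zeros(1, 128))   # -> shape [1, 4], one Q-value per action
#   q.max(1)[1]                    # -> index of the greedy action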


# Exponential decay schedule for the exploration rate epsilon
def epsilon(steps_done, EPS_START, EPS_END, DECAY_RATE):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / DECAY_RATE)


# Hybrid exploration rate: exponential decay followed by cosine annealing with
# peaks that shrink over time (currently unused; Agent.select_action calls
# epsilon() instead).
def hybrid_exploration(steps_done, EPS_START, EPS_INTER, EPS_END, DECAY_STEPS, T_MAX, PEAK_DECAY_RATE):
    DECAY_RATE = 800000  # decay rate of the initial exponential phase
    if steps_done < DECAY_STEPS:
        # Exponential decay phase
        return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / DECAY_RATE)
    else:
        # Amplitude of the cosine annealing, shrinking as steps accumulate
        adjusted_max = EPS_INTER - (EPS_INTER - EPS_END) * ((steps_done - DECAY_STEPS) / 1000000)
        current_amplitude = max(adjusted_max - EPS_END, 0)
        # Cosine annealing phase with decreasing peaks
        return EPS_END + 0.5 * current_amplitude * (1 + np.cos(np.pi * ((steps_done - DECAY_STEPS) % T_MAX) / T_MAX))
# Generate data for steps_done
# steps = np.arange(0, 100_000) # For example, from 0 to 2000 steps to show two cycles
# epsilons = [cosine_annealing(step, EPS_START, EPS_END, T_MAX) for step in steps]
# # Plotting
# plt.figure(figsize=(10, 5))
# plt.plot(steps, epsilons, label='Epsilon')
# plt.title('Cosine Annealing for Exploration Rate')
# plt.xlabel('Steps Done')
# plt.ylabel('Epsilon')
# plt.grid(True)
# plt.legend()
# plt.show()
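
# Worked example (sketch): with EPS_START=0.99, EPS_END=0.1 and DECAY_RATE=80000,
# the values Agent.select_action passes to epsilon() below, the threshold decays as
#   epsilon(0)       ~ 0.99   (almost always explore)
#   epsilon(80_000)  ~ 0.43   (0.1 + 0.89 * e^-1)
#   epsilon(240_000) ~ 0.14   (0.1 + 0.89 * e^-3)
# so the agent explores heavily early on and mostly exploits after a few hundred
# thousand steps.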


@dataclass
class Agent:
    env: gym.Env
    reward_history: list = field(default_factory=list)
    BATCH_SIZE: int = 1000  # number of transitions sampled from the replay buffer
    GAMMA: float = 0.99  # discount factor
    EPS_START: float = 0.99  # initial value of epsilon
    EPS_END: float = 0.1  # final value of epsilon
    TAU: float = 0.005  # update rate for the target network
    LR: float = 1e-4  # learning rate of the ``AdamW`` optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # default_factory so each Agent gets its own replay buffer
    memory: ReplayMemory = field(default_factory=lambda: ReplayMemory(10000))
    steps_done: int = 0

    def __post_init__(self):
        self.policy_net: DQN = DQN(self.env.observation_space.shape[0], self.env.action_space.n).to(
            self.device
        )
        self.target_net: DQN = DQN(self.env.observation_space.shape[0], self.env.action_space.n).to(
            self.device
        )
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.LR, amsgrad=True)

    def select_action(self, state):
        sample = random.random()
        # Exploration schedule parameters
        EPS_START = self.EPS_START  # initial exploration rate
        EPS_INTER = self.EPS_START  # exploration rate at the transition point
        EPS_END = self.EPS_END  # minimum final exploration rate
        DECAY_STEPS = 25000  # number of steps in the initial slow decay phase
        T_MAX = 50000  # period of each cosine annealing cycle
        TOTAL_STEPS = 100000  # total steps to plot
        PEAK_DECAY_RATE = 0.05  # rate at which the peaks decrease
        # The goal is to explore more in the beginning and exploit more towards the end.
        # eps_threshold = hybrid_exploration(self.steps_done, EPS_START, EPS_INTER, EPS_END, DECAY_STEPS, T_MAX, PEAK_DECAY_RATE)
        eps_threshold = epsilon(self.steps_done, EPS_START, EPS_END, 80000)
        self.steps_done += 1
        if sample > eps_threshold:
            return self.use_policy(state)
        else:
            return torch.tensor([[random.randrange(self.env.action_space.n)]], device=self.device, dtype=torch.long)

    def use_policy(self, state):
        state = torch.tensor(state, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        with torch.no_grad():
            return self.policy_net(state).max(1)[1].view(1, 1)

    def optimize_model(self):
        if len(self.memory) < self.BATCH_SIZE:
            return
        transitions = self.memory.sample(self.BATCH_SIZE)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for a
        # detailed explanation). This converts a batch-array of Transitions
        # to a Transition of batch-arrays.
        batch = Transition(*zip(*transitions))
        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which the episode ended)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of the actions taken. These are the actions which would've been
        # taken for each batch state according to policy_net.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, such that we'll have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.BATCH_SIZE, device=self.device)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch
        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 100)
        self.optimizer.step()
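
    # The optimization step above fits Q_policy(s, a) toward the one-step DQN target
    #   y = r                                        if s_{t+1} is terminal
    #   y = r + GAMMA * max_a' Q_target(s_{t+1}, a') otherwise
    # which is exactly what the non-final mask and the target-network evaluation
    # implement, with a Huber (SmoothL1) loss.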

    def update_weights(self):
        # Soft update of the target network's weights:
        # θ′ ← τ θ + (1 − τ) θ′
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * self.TAU + target_net_state_dict[key] * (1 - self.TAU)
        self.target_net.load_state_dict(target_net_state_dict)
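
    # With the default TAU = 0.005, each call moves the target network 0.5% of the
    # way toward the policy network, so the target lags the policy smoothly instead
    # of being copied wholesale.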

    def feed_state(self, state):
        state = torch.tensor(state, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        action = self.select_action(state)
        observation, reward, terminated, *rest = self.env.step(action.item())
        reward = torch.tensor([reward], device=self.device)
        self.reward_history.append(reward.item())
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=self.device).view(-1).unsqueeze(0)
        # Store the transition in memory
        self.memory.push(state, action, next_state, reward)
        # Perform one step of the optimization (on the policy network)
        self.optimize_model()
        self.update_weights()
        return observation, reward, terminated, *rest

    def save(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path):
        self.policy_net.load_state_dict(torch.load(path))
        self.policy_net.eval()

Comment from @r33drichards (author): the training and evaluation driver script.

#!/usr/bin/env python3
import gymnasium as gym
from agent import Agent
from itertools import count

from tqdm import tqdm


def train(num_episodes, f=None, render_mode="human"):
    env = gym.make("Breakout-ram-v4", render_mode=render_mode, obs_type="ram")

    agent = Agent(
        env=env,
    )
    if f is not None:
        agent.load(f)
    observation, info = env.reset(seed=42)

    # Each loop iteration performs a single environment step via Agent.feed_state;
    # a checkpoint is written whenever an episode terminates.
    for e in tqdm(range(num_episodes)):
        observation, reward, terminated, *rest = agent.feed_state(observation)
        if terminated:
            agent.save(f"bo.{e}.pt")
            observation, info = env.reset()

    agent.save("breakout.pt")

    env.close()


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="train or play", choices=["train", "play"])
    parser.add_argument("--num_episodes", "-n", help="number of episodes to train", type=int, default=301)
    parser.add_argument("--file", "-f", help="file to load", type=str, default=None)
    parser.add_argument("--render_mode", "-r", help="render mode", type=str, default=None)
    return parser.parse_args()
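
# Example invocations (sketch; assumes this script is saved as main.py, which the
# gist does not state):
#   python main.py train -n 500 -r rgb_array
#   python main.py play -f breakout.pt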

def play(f="lunar.pt"):
    env = gym.make("Breakout-ram-v4", render_mode="human")

    agent = Agent(
        env=env,
    )
    agent.load(f)
    state, _ = env.reset()
    z_count = 0  # counts consecutive steps with an unchanged reward
    last_reward = None

    for t in count():
        env.render()
        action = agent.use_policy(state)
        state, reward, terminated, *info = env.step(action.item())
        print(action.item())
        if last_reward is None:
            last_reward = reward
        elif last_reward == reward:
            z_count += 1
        else:
            z_count = 0
            last_reward = reward

        # End the episode if the reward has not changed for 300 steps
        # (the agent is likely stuck).
        if z_count > 300:
            z_count = 0
            terminated = True
        if terminated:
            break

    env.close()
    print("Game over")
    print("lasted {} steps".format(t))



if __name__ == '__main__':
    args = parse_args()
    if args.mode == "train":
        num_episodes = args.num_episodes
        f = args.file
        render_mode = args.render_mode
        train(num_episodes, f, render_mode)
    elif args.mode == "play":
        # Fall back to play()'s default checkpoint when no --file is given
        if args.file:
            play(args.file)
        else:
            play()
    else:
        raise ValueError("Invalid mode")