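"""REINFORCE (vanilla policy gradient) with an LSTM policy on OpenAI Gym's
algorithmic tasks (Copy-v0, RepeatCopy-v0, DuplicatedInput-v0).

Written against the old (pre-0.4) PyTorch Variable API.

Example invocation (assuming the script is saved as, e.g., reinforce.py):
    python reinforce.py --env copy
"""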
import argparse
import pdb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym
from gym import wrappers
EXP_NAME_PREFIX = 'exp/'
API_KEY = '***'

ENVS = {
    'copy': 'Copy-v0',
    'repeated_copy': 'RepeatCopy-v0',
    'duplicated_input': 'DuplicatedInput-v0',
}
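

# Recurrent policy: token embedding -> single-layer LSTM -> linear head that
# outputs one logit per (flattened) action.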
class Policy(nn.Module):
    def __init__(self, ninput, nembed, nhid, noutput):
        super(Policy, self).__init__()
        self.embed = nn.Embedding(ninput, nembed)
        self.lstm = nn.LSTM(nembed, nhid)
        self.fc = nn.Linear(nhid, noutput)
        self.init_weights()

    def init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0.0)

    def forward(self, x, hx, cx):
        x = self.embed(x)
        x, (hx, cx) = self.lstm(x, (hx, cx))
        x = self.fc(x.view(-1, x.size(2)))
        return x, hx, cx
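

# REINFORCE: roll out one episode with the current policy, then take a gradient
# step on the loss -sum_t log pi(a_t | s_t) * R_t, where R_t is the discounted
# return from step t onward.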
def policy_gradient(E, args):
    nS = int(E.observation_space.n)
    dA = [d.n for d in E.action_space.spaces]
    nA = int(np.prod(dA))

    model = Policy(nS, args.nembed, args.nhid, nA)
    opt = optim.Adam(model.parameters(), lr=args.lr)
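
    # The algorithmic envs have a Tuple action space; the policy outputs a single
    # categorical distribution over all prod(dA) combinations, and decode() maps a
    # flat action index back to the tuple of sub-actions (mixed-radix decoding).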
    def decode(a):
        tu = []
        for d in reversed(dA):
            tu.append(a % d)
            a //= d
        return tu[::-1]
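
    # tr holds the returns of the last 100 episodes; training stops (and the
    # episode count is returned) once their mean exceeds args.goal.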
    tr = np.zeros(100, np.float32)
    for e in range(args.max_episodes):
        if e % 50 == 0 and e > 0:
            print('episode %d, average reward: %.3f' % (e, np.mean(tr)))
            if np.mean(tr) > args.goal:
                return e

        s = E.reset()
        hx = Variable(torch.zeros(1, 1, args.nhid))
        cx = Variable(torch.zeros(1, 1, args.nhid))
        done = False
        tr[e % tr.size] = 0.0
        lps = []
        rs = []
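
        # Roll out one episode: feed the current observation token through the
        # policy, sample a flat action from the softmax, and record its
        # log-probability together with the reward.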
        while not done:
            logit, hx, cx = model(Variable(torch.LongTensor(1, 1).fill_(s)), hx, cx)
            p = F.softmax(logit)
            lp = F.log_softmax(logit)
            a = np.random.choice(range(nA), p=p.squeeze(0).data.numpy())
            ns, r, done, _ = E.step(decode(a))
            if args.render:
                E.render()
            tr[e % tr.size] += r
            s = ns
            lps.append(lp.gather(1, Variable(torch.LongTensor(1, 1).fill_(int(a)))))
            rs.append(r)
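
        # Walk the episode backwards to compute discounted returns and accumulate
        # the REINFORCE loss, then take a clipped gradient step.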
        R = Variable(torch.zeros(1, 1))
        loss = 0
        for lp, r in zip(reversed(lps), reversed(rs)):
            R = args.gamma * R + r
            loss -= lp * R

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), args.clip)
        opt.step()

    return args.max_episodes
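

# Parse arguments, build and seed the environment, optionally wrap it in a
# Monitor for recording, train, and (with --upload) upload the recorded run
# via gym.upload.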
def main():
    parser = argparse.ArgumentParser(description='REINFORCE')
    parser.add_argument('--env', choices=ENVS.keys(), required=True)
    parser.add_argument('--max_episodes', type=int, default=100000)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--nembed', type=int, default=64)
    parser.add_argument('--nhid', type=int, default=128)
    parser.add_argument('--goal', type=float, default=1.0)
    parser.add_argument('--clip', type=float, default=20.0)
    parser.add_argument('--upload', action='store_true', default=False)
    parser.add_argument('--render', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--cuda', action='store_true', default=False)
    args = parser.parse_args()

    E = gym.make(ENVS[args.env])
    exp_name = EXP_NAME_PREFIX + args.env

    E.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.upload:
        E = wrappers.Monitor(E, exp_name, force=True)

    policy_gradient(E, args)
    E.close()

    if args.upload:
        gym.upload(exp_name, api_key=API_KEY)


if __name__ == '__main__':
    main()