Created
January 24, 2017 07:06
-
-
Save denisyarats/be8856e41a1fd68cb5353f6344311237 to your computer and use it in GitHub Desktop.
q-learning with linear approximation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/local/bin/python | |
""" | |
Q-learning with value function approximation
""" | |
import argparse
import os
import pdb
from collections import defaultdict

import gym
import matplotlib
import numpy as np
from gym import wrappers
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import SGDRegressor
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import StandardScaler
# Output directory prefix for the Monitor; main() appends '_<env key>'.
EXP_NAME_PREFIX = 'exp/q_learning_vfa'
# Never commit credentials: the upload key is read from the environment.
# A key was previously hard-coded here and should be treated as leaked/revoked.
# None when the variable is unset.
API_KEY = os.environ.get('OPENAI_GYM_API_KEY')
# CLI name -> gym environment id.
ENVS = {
    # Known-good flags:
    # --env mountaincar --gamma 0.99 --eps 0.3 --goal -110 --upload --max_episodes 10000 --eps_schedule 500
    'mountaincar': 'MountainCar-v0',
}
class FeaturesMaker(object):
    """Standardize observations and expand them into random RBF features.

    The pipeline first applies a StandardScaler, then feeds the result to
    four RBFSampler banks joined by a FeatureUnion. The banks use gamma
    5.0, 2.0, 1.0 and 0.5, with 100 components each.
    """

    def __init__(self):
        scaler = StandardScaler()
        # (step name, gamma) for each RBF bank, in union order.
        bank_specs = [('rbf5', 5.0), ('rbf2', 2.0), ('rbf1', 1.0), ('rbf05', 0.5)]
        rbf_union = FeatureUnion([
            (step_name, RBFSampler(gamma=gamma, n_components=100))
            for step_name, gamma in bank_specs
        ])
        self.pipeline = Pipeline([('scale', scaler), ('rbf', rbf_union)])

    def fit(self, X):
        """Fit the scaler and every RBF bank on the observations in X.

        Returns whatever ``Pipeline.fit`` returns.
        """
        pipeline = self.pipeline
        return pipeline.fit(X)

    def transform(self, X):
        """Map the observations in X through the fitted scaler and RBF banks."""
        pipeline = self.pipeline
        return pipeline.transform(X)
class ValueFunction(object):
    """Action-value approximator with one SGDRegressor per discrete action.

    ``predict`` returns the estimate of every action for a state; ``update``
    trains only the model belonging to the chosen action.
    """

    def __init__(self, F, nA):
        """
        Args:
            F: fitted feature extractor exposing ``transform(list_of_states)``.
            nA: number of discrete actions; one regressor is created per action.
        """
        self.F = F
        # `range` instead of the Python-2-only `xrange`; identical iteration.
        self.models = [SGDRegressor(learning_rate='constant') for _ in range(nA)]

    def predict(self, s):
        """Return a numpy array holding each action's predicted value for state s."""
        # transform() expects a batch, so wrap the single state in a list.
        f = self.F.transform([s])
        return np.array([m.predict(f)[0] for m in self.models])

    def update(self, s, a, t):
        """Incrementally fit action a's model towards target t for state s."""
        f = self.F.transform([s])
        self.models[a].partial_fit(f, [t])
def q_learning(env, V, max_episodes, gamma, eps, eps_schedule, goal):
    """Train V with epsilon-greedy Q-learning until the running reward beats goal.

    Args:
        env: gym-style environment with a discrete ``action_space`` (``.n``),
            an ``observation_space`` with ``sample()``, and ``reset()`` /
            ``step(a) -> (obs, reward, done, info)``.
        V: value approximator exposing ``predict(s)`` (one value per action)
            and ``update(s, a, target)``.
        max_episodes: maximum number of episodes to run.
        gamma: discount applied to the bootstrapped next-state value.
        eps: initial exploration probability; halved every ``eps_schedule``
            episodes.
        eps_schedule: episodes between halvings of ``eps``; a value <= 0
            disables the decay (previously 0 raised ZeroDivisionError).
        goal: training stops once the mean return over the last 100 episodes
            exceeds this value.

    Returns:
        The 0-based index of the episode at which the goal was exceeded, or
        ``max_episodes`` if it never was.
    """
    nA = env.action_space.n
    # Fit each per-action model once on an arbitrary observation so that
    # predict() works before any real transition has been observed.
    for a in range(nA):
        V.update(env.observation_space.sample(), a, 0)
    P = np.zeros(nA, np.float32)
    # Ring buffer of the last 100 episode returns; slots not yet played hold
    # -1000 and pull the average down.
    tR = np.ones(100, np.float32) * (-1000)
    for e in range(max_episodes):
        if eps_schedule > 0 and e > 0 and e % eps_schedule == 0:
            eps /= 2
        s = env.reset()
        done = False
        tR[e % tR.size] = 0.
        nit = 0
        while not done:
            nit += 1
            # Epsilon-greedy: eps spread uniformly, the remainder on the greedy action.
            P.fill(eps / nA)
            P[np.argmax(V.predict(s))] += 1 - eps
            a = np.random.choice(nA, p=P)
            ns, r, done, _ = env.step(a)
            # NOTE(review): the target always bootstraps from ns, even when
            # done is True; a true terminal state is not distinguished from an
            # episode cut off by a time limit. Confirm this is intended.
            t = r + gamma * np.max(V.predict(ns))
            V.update(s, a, t)
            s = ns
            tR[e % tR.size] += r
        print('episode %d, iterations %d, average reward: %.3f' % (e, nit, np.mean(tR)))
        if np.mean(tR) > goal:
            return e
    return max_episodes
def main():
    """Parse CLI flags, train the Q-learning agent and print the result.

    Steps:
    - Build the selected gym environment and seed it and numpy with 0.
    - With ``--upload``, wrap the environment in a Monitor writing under
      ``EXP_NAME_PREFIX + '_' + <env>``.
    - Fit the RBF feature map on 10000 observations sampled from the
      observation space, then run q_learning().
    - With ``--upload``, upload the recorded directory via gym.upload().
    """
    parser = argparse.ArgumentParser(description='Q-learning with VF approximation')
    # Required: without it args.env is None and ENVS[args.env] raises KeyError.
    parser.add_argument('--env', choices=sorted(ENVS), required=True)
    parser.add_argument('--max_episodes', type=int, default=10000)
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--eps', type=float, default=0.0)
    parser.add_argument('--eps_schedule', type=int, default=10000)
    parser.add_argument('--goal', type=float, default=1.0)
    parser.add_argument('--upload', action='store_true', default=False)
    args = parser.parse_args()

    exp_name = '%s_%s' % (EXP_NAME_PREFIX, args.env)
    env = gym.make(ENVS[args.env])
    env.seed(0)
    np.random.seed(0)
    if args.upload:
        env = wrappers.Monitor(env, exp_name, force=True)
    try:
        F = FeaturesMaker()
        # Fit the scaler/RBF features on observations drawn from the observation space.
        X = np.array([env.observation_space.sample() for _ in range(10000)])
        F.fit(X)
        V = ValueFunction(F, env.action_space.n)
        res = q_learning(env, V, args.max_episodes, args.gamma,
                         args.eps, args.eps_schedule, args.goal)
        print('result -> %d' % res)
    finally:
        # Always release the environment (and Monitor, if wrapped), even if
        # training raises.
        env.close()
    if args.upload:
        gym.upload(exp_name, api_key=API_KEY)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment