Created
December 20, 2020 23:02
-
-
Save roclark/5a0761973797086f0a1ce83491cb8d28 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# An updated version of the EpisodicLifeEnv wrapper from RLLib which is
# compatible with the SuperMarioBros environments.
class EpisodicLifeEnv(gym.Wrapper):
    """Treat the loss of a life as the end of an episode.

    The wrapped environment is only truly reset once it reports ``done``
    itself; after a mere life loss, ``reset`` continues the running game.
    The remaining lives are read from ``env.unwrapped._life``.
    """

    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game
        over. Done by DeepMind for the DQN and co. since it helps value
        estimation.
        """
        gym.Wrapper.__init__(self, env)
        # Life count observed after the most recent step/reset.
        self.lives = 0
        # Whether the wrapped env itself reported ``done`` on the last step.
        # Starts as True so the very first ``reset`` is a real one.
        self.was_real_done = True

    def step(self, action):
        """Step the wrapped env and flag a lost life as ``done``.

        Returns the wrapped env's ``(obs, reward, done, info)`` tuple, with
        ``done`` forced to True when the life count dropped but is still
        positive.
        """
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped._life
        if self.lives > lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few
            # frames, so it's important to keep lives > 0, so that we only
            # reset once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.

        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, done, _ = self.env.step(0)
            if done:
                # The no-op step itself ended the game; returning its
                # observation would hand the caller a finished environment,
                # so perform a real reset instead.
                obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped._life
        return obs
class CustomReward(gym.Wrapper):
    """Reshape the reward using the in-game score and the end-of-episode result.

    Each step adds the score gained since the previous step (divided by 40)
    to the wrapped env's reward. When the step reports ``done``, a bonus of
    350 is added if ``info['flag_get']`` is truthy, otherwise 50 is
    subtracted. The total is divided by 10 before being returned.
    """

    def __init__(self, env):
        """Wrap ``env`` and start tracking the score from zero."""
        super(CustomReward, self).__init__(env)
        # Last value of info['score'] seen by ``step``.
        self._current_score = 0

    def step(self, action):
        """Step the wrapped env and return ``(state, shaped_reward, done, info)``.

        Expects ``info`` to contain the keys ``'score'`` and ``'flag_get'``.
        """
        state, reward, done, info = self.env.step(action)
        score = info['score']
        score_gain = score - self._current_score
        if score_gain < 0:
            # The score went down, so the counter was restarted (presumably
            # the underlying game was reset). Count only what was earned
            # since the restart instead of penalising the agent for the
            # previous game's total.
            score_gain = score
        reward += score_gain / 40.0
        self._current_score = score
        if done:
            if info['flag_get']:
                reward += 350.0
            else:
                reward -= 50.0
        return state, reward / 10.0, done, info
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment