Epsilon-greedy algorithm
import random


class Bandit(object):
    # Minimal stub so this file runs standalone; the gist assumes a
    # Bandit base class is defined elsewhere.
    pass


class EpsilonGreedyBandit(Bandit):
    """
    The best action (as far as the algorithm knows so far) is selected
    for a proportion 1 - epsilon of the trials, and another action is
    selected at random (with uniform probability) for a proportion
    epsilon.

    Parameters
    ----------
    epsilon : float in [0,1], proportion of trials in which we select a
        random action
    """
    def __init__(self, epsilon=0.1):
        # raises TypeError if conversion not supported
        self.epsilon = float(epsilon)
        if self.epsilon < 0.0 or self.epsilon > 1.0:
            raise ValueError('epsilon must be in range [0,1]')
        self.actions_played = dict()    # action -> number of plays
        self.actions_rewarded = dict()  # action -> cumulative reward

    def _expected_reward(self, action):
        if action not in self.actions_played:
            raise ValueError('action %s never played' % repr(action))
        # mean reward = cumulative reward / number of plays
        return (self.actions_rewarded.get(action, 0.0)
                / self.actions_played[action])

    def recommend(self, visitor, possible_actions):
        # only actions we have already played have a defined mean reward
        played = [a for a in possible_actions if a in self.actions_played]
        if not played or random.uniform(0.0, 1.0) < self.epsilon:
            # explore: pick uniformly at random (this branch also covers
            # the cold start, before any candidate action has been played)
            chosen_action = random.choice(possible_actions)
        else:
            # exploit: pick the action with the best mean reward so far
            chosen_action = max(played, key=self._expected_reward)
        # record the play so future reward averages count this trial
        self.actions_played[chosen_action] = (
            self.actions_played.get(chosen_action, 0.0) + 1.0)
        return chosen_action

    def update(self, visitor, action, reward):
        # raises TypeError if conversion not supported
        reward = float(reward)
        if reward < 0.0 or reward > 1.0:
            raise ValueError('reward must be in range [0,1]')
        # accumulate the observed reward for this action
        self.actions_rewarded[action] = (
            self.actions_rewarded.get(action, 0.0) + reward)
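
As a quick sanity check, here is a minimal usage sketch: a simulation against two made-up actions with fixed Bernoulli reward probabilities. The action names, probabilities, and trial count are illustrative assumptions, not part of the original gist.

import random

# Illustrative setup (names and probabilities are assumptions):
# action 'b' pays off twice as often as action 'a'.
true_reward_prob = {'a': 0.3, 'b': 0.6}

bandit = EpsilonGreedyBandit(epsilon=0.1)
for _ in range(10000):
    action = bandit.recommend(visitor=None, possible_actions=['a', 'b'])
    reward = 1.0 if random.random() < true_reward_prob[action] else 0.0
    bandit.update(visitor=None, action=action, reward=reward)

# With epsilon=0.1 we expect roughly 95% of plays on 'b': the 90%
# greedy share plus half of the 10% random share.
print(bandit.actions_played)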