This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' Script for downloading all GLUE data. | |
Note: for legal reasons, we are unable to host MRPC. | |
You can either use the version hosted by the SentEval team, which is already tokenized, | |
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. | |
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). | |
You should then rename and place specific files in a folder (see below for an example). | |
mkdir MRPC | |
cabextract MSRParaphraseCorpus.msi -d MRPC |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(): | |
def env_creator_lambda(env_config): | |
return env_creator(args.environment, | |
config, | |
args.dimension, | |
args.framestack) | |
args = parse_args() | |
config = { | |
'env': 'super_mario_bros', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def print_results(result, iteration): | |
table = [['IMPALA', | |
iteration, | |
result['timesteps_total'], | |
round(result['episode_reward_max'], 3), | |
round(result['episode_reward_min'], 3), | |
round(result['episode_reward_mean'], 3)]] | |
print(tabulate(table, | |
headers=['Agent', | |
'Iteration', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def env_creator(env_name, config, dim, framestack): | |
env = gym_super_mario_bros.make(env_name) | |
env = CustomReward(env) | |
env = JoypadSpace(env, SIMPLE_MOVEMENT) | |
env = MonitorEnv(env) | |
env = NoopResetEnv(env, noop_max=30) | |
env = EpisodicLifeEnv(env) | |
env = WarpFrame(env, dim) | |
if framestack: | |
env = FrameStack(env, framestack) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parse_args(): | |
parser = ArgumentParser(description='Train an agent to beat Super Mario ' | |
'Bros. levels.') | |
parser.add_argument('--checkpoint', help='Specify an existing checkpoint ' | |
'which can be used to restore progress from a previous' | |
' training run.') | |
parser.add_argument('--dimension', help='The image dimensions to resize to' | |
' while preprocessing the game states.', type=int, | |
default=84) | |
parser.add_argument('--environment', help='The Super Mario Bros level to ' |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# An updated version of the EpisodicLifeEnv wrapper from RLLib which is | |
# compatible with the SuperMarioBros environments. | |
class EpisodicLifeEnv(gym.Wrapper): | |
def __init__(self, env): | |
"""Make end-of-life == end-of-episode, but only reset on true game | |
over. Done by DeepMind for the DQN and co. since it helps value | |
estimation. | |
""" | |
gym.Wrapper.__init__(self, env) | |
self.lives = 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import gym_super_mario_bros | |
import ray | |
from argparse import ArgumentParser | |
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT | |
from nes_py.wrappers import JoypadSpace | |
from ray import tune | |
from ray.tune.registry import register_env | |
from ray.rllib.agents.impala import ImpalaTrainer | |
from ray.rllib.env.atari_wrappers import (MonitorEnv, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import gym_super_mario_bros | |
import ray | |
from argparse import ArgumentParser | |
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT | |
from nes_py.wrappers import JoypadSpace | |
from ray import tune | |
from ray.tune.registry import register_env | |
from ray.rllib.agents.impala import ImpalaTrainer | |
from ray.rllib.env.atari_wrappers import (MonitorEnv, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from argparse import ArgumentParser | |
from os.path import isfile | |
from sportsreference.ncaab.teams import Teams | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.model_selection import train_test_split | |
DATASET_NAME = 'dataset.pkl' | |
FIELDS_TO_DROP = ['away_points', 'home_points', 'date', 'location', | |
'losing_abbr', 'losing_name', 'winner', 'winning_abbr', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = RandomForestRegressor(bootstrap=False, | |
min_samples_leaf=3, | |
... | |
max_depth=6) |
NewerOlder