Skip to content

Instantly share code, notes, and snippets.

@roclark
Created December 20, 2020 23:13
Show Gist options
  • Save roclark/708e9d9aa5dba9a5a1d2a41d225c333b to your computer and use it in GitHub Desktop.
Save roclark/708e9d9aa5dba9a5a1d2a41d225c333b to your computer and use it in GitHub Desktop.
def main():
    """Train an IMPALA agent on the 'super_mario_bros' environment.

    Parses command-line options with ``parse_args()``. These options supply
    workers, gpus, environment, dimension, framestack, checkpoint and
    iterations. The function then:

    - registers an environment factory under the name 'super_mario_bros';
    - builds an ``ImpalaTrainer`` using the torch framework;
    - optionally restores the trainer from ``args.checkpoint``;
    - runs ``args.iterations`` training iterations, printing each result.

    A checkpoint is saved after every ``checkpoint_interval`` completed
    iterations and always after the final iteration, so the end-of-run
    weights are never lost. Ray is shut down on exit, even if training
    raises.
    """
    # Save after this many *completed* iterations (the original saved after
    # iteration 0 and could skip the final iterations entirely).
    checkpoint_interval = 50

    args = parse_args()
    config = {
        'env': 'super_mario_bros',
        'framework': 'torch',
        'rollout_fragment_length': 50,
        'train_batch_size': 500,
        'num_workers': args.workers,
        'num_envs_per_worker': 1,
        'num_gpus': args.gpus
    }

    def env_creator_lambda(env_config):
        # The registered creator is called with an ``env_config`` argument;
        # it is unused here because the environment is fully described by
        # the CLI args and the trainer ``config`` captured from main().
        return env_creator(args.environment,
                           config,
                           args.dimension,
                           args.framestack)

    ray.init()
    try:
        register_env('super_mario_bros', env_creator_lambda)
        trainer = ImpalaTrainer(config=config)
        if args.checkpoint:
            trainer.restore(args.checkpoint)
        for iteration in range(args.iterations):
            result = trainer.train()
            print_results(result, iteration)
            completed = iteration + 1
            # Checkpoint periodically and unconditionally after the last
            # iteration so trailing training progress is persisted.
            if (completed % checkpoint_interval == 0
                    or completed == args.iterations):
                checkpoint = trainer.save()
                print('Checkpoint saved at', checkpoint)
    finally:
        # Release the Ray runtime started by ray.init() above.
        ray.shutdown()
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment