@ejmejm
Last active January 27, 2025 22:44
Example quick experiment format
device: cuda
wandb:
  enabled: false
  project: minigrid_gvfs
optimizer:
  type: sgd
  meta_lr: 0.01
  # ...
# ...
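# Illustrative usage note (the override syntax is standard hydra; the script
# name run_experiment.py is taken from the sweep file further down): any value
# in this config can be overridden per run from the command line, e.g.
#   python run_experiment.py optimizer.type=adam device=cpu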
### Google style, ordered imports ###
import copy
import math
from typing import Callable, Iterator, List, Optional, Tuple

import gymnasium as gym  # assumed; `import gym` also matches the gym.Env annotation below
import hydra
import omegaconf
from omegaconf import DictConfig
import torch
from torch import optim
from torch.optim import Optimizer
import wandb

### Local imports if relevant (e.g. the IDBD, GVF, and TrainingAlgorithm classes used below)... ###
### GLOBAL VARIABLES ###
TRAJECTORY_CACHE = {}
### Classes ###
### Functions ###
def setup_optimizer(params: Iterator[torch.Tensor], config: DictConfig) -> Optimizer:
    """Set up optimizers for each GVF.

    Args:
        params: Iterator over all parameters of the model being optimized.
        config: Top level config.

    Returns:
        Optimizer of type dependent on config values.
    """
    opt_type = config.optimizer.type.lower()
    if opt_type == 'idbd':
        return IDBD(
            params,
            meta_lr=config.optimizer.meta_lr,
            init_lr=config.optimizer.init_lr,
            tau=config.optimizer.get('tau', 1e4),
            weight_decay=config.optimizer.get('weight_decay', 0.0),
            autostep=False,
        )
    elif opt_type == 'autostep':
        return IDBD(
            params,
            meta_lr=config.optimizer.meta_lr,
            init_lr=config.optimizer.init_lr,
            tau=config.optimizer.get('tau', 1e4),
            weight_decay=config.optimizer.get('weight_decay', 0.0),
            autostep=True,
        )
    elif opt_type == 'adam':
        return optim.Adam(params, lr=config.optimizer.init_lr)
    elif opt_type == 'sgd':
        return optim.SGD(params, lr=config.optimizer.init_lr)
    else:
        raise ValueError(f'Unknown optimizer type: {config.optimizer.type}')
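
# Example usage (illustrative only; `model` is a stand-in name, not defined in
# this template):
#   optimizer = setup_optimizer(model.parameters(), config)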


def setup_networks(
    obs_dim: int,
    target_policy_names: List[str],
    config: DictConfig,
) -> List[GVF]:
    # ...
    pass


def train_loop(
    env: gym.Env,
    behavior_policy: Callable,
    gvfs: List[GVF],
    training_algo: TrainingAlgorithm,
    config: DictConfig,
):
    # ...
    pass


def setup_wandb(config: DictConfig) -> DictConfig:
    """Initialize Weights & Biases logging."""
    wandb_mode = 'online' if config.wandb.enabled else 'disabled'
    wandb.init(project=config.wandb.project, mode=wandb_mode)
    wandb.config.update(omegaconf.OmegaConf.to_container(
        config, resolve=True, throw_on_missing=True))
    return config
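
# Illustrative note: once wandb is initialized, per-step metrics inside
# train_loop would typically be reported with wandb.log, e.g.
#   wandb.log({'loss': loss.item()})
# The 'loss' key matches the sweep metric defined further down, but the exact
# logging keys are an assumption, not part of this template.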


### Main setup and experiment launch code ###

@hydra.main(config_path='conf', config_name='defaults')
def main(config: DictConfig) -> None:
    setup_wandb(config)

    # Setup environment, data, models, etc...

    # Run training
    train_loop(
        env=env,
        behavior_policy=behavior_policy,
        gvfs=gvfs,
        training_algo=training_alg,
        config=config,
    )

    if config.wandb.enabled:
        wandb.finish()


if __name__ == '__main__':
    main()

# wandb sweep configuration
program: run_experiment.py
method: grid
project: project-name
name: example-sweep-name

command:
  - ${env}
  - python
  - ${program}
  - ${args_no_hyphens}

metric:
  name: loss
  goal: minimize

parameters:
  wandb.enabled:
    value: True
  device:
    value: cuda
  optimizer.type:
    values: ['sgd', 'adam', 'rmsprop']
  seed:
    values: [0, 1, 2, 3, 4]
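
# A sweep defined like this is typically registered and launched with the wandb
# CLI (standard wandb workflow; the file name sweep.yaml is an assumption):
#   wandb sweep sweep.yaml                      # prints a sweep ID
#   wandb agent <entity>/<project>/<sweep-id>   # runs sweep trials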