Last active
January 27, 2025 22:44
-
-
Save ejmejm/1eba885ac48eca45d234bc19fd9f86b1 to your computer and use it in GitHub Desktop.
Example quick experiment format
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
device: cuda
wandb:
  enabled: false
  project: minigrid_gvfs
optimizer:
  type: sgd
  meta_lr: 0.01
  #...
#...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Google style, ordered imports ### | |
import copy
import math
from typing import Callable, Iterator, List, Optional, Tuple

import hydra
import omegaconf
import torch
import wandb
### Local imports if relevant... ### | |
### GLOBAL VARIABLES ### | |
# Module-level cache shared across calls; starts out empty.
# NOTE(review): nothing in the visible code reads or writes it -- confirm it is
# still needed before relying on it.
TRAJECTORY_CACHE: dict = {}
### Classes ### | |
### Functions ### | |
def setup_optimizer(
    params: Iterator[torch.Tensor],
    config: 'omegaconf.DictConfig',
) -> torch.optim.Optimizer:
    """Build the optimizer selected by `config.optimizer.type`.

    Args:
        params: iterator over all parameters of the model being optimized.
        config: top level config. Reads `optimizer.type` (case-insensitive)
            and `optimizer.init_lr`. The IDBD variants additionally read
            `optimizer.meta_lr` plus the optional `optimizer.tau`
            (default 1e4) and `optimizer.weight_decay` (default 0.0).

    Returns:
        Optimizer of type dependent on config values.

    Raises:
        ValueError: if `config.optimizer.type` is not a supported name.
    """
    opt_cfg = config.optimizer
    opt_type = opt_cfg.type.lower()

    if opt_type in ('idbd', 'autostep'):
        # Both names map onto the same IDBD class; only the `autostep` flag
        # differs. NOTE(review): `IDBD` is not defined or imported in the
        # visible file -- presumably one of the elided local imports.
        return IDBD(
            params,
            meta_lr=opt_cfg.meta_lr,
            init_lr=opt_cfg.init_lr,
            tau=opt_cfg.get('tau', 1e4),
            weight_decay=opt_cfg.get('weight_decay', 0.0),
            autostep=(opt_type == 'autostep'),
        )

    # Stock torch optimizers that only need the initial learning rate.
    torch_optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
        'rmsprop': torch.optim.RMSprop,
    }
    if opt_type in torch_optimizers:
        return torch_optimizers[opt_type](params, lr=opt_cfg.init_lr)

    raise ValueError(f'Unknown optimizer type: {config.optimizer.type}')
def setup_networks(
    obs_dim: int,
    target_policy_names: List[str],
    config: 'omegaconf.DictConfig',
) -> List[GVF]:
    """Create the GVF networks (placeholder; body is elided in this template).

    As written the function has no implementation and returns None.

    Args:
        obs_dim: presumably the size of an observation -- TODO confirm.
        target_policy_names: presumably one name per GVF target policy --
            TODO confirm.
        config: top level config.

    Returns:
        Annotated as a list of GVFs; nothing is returned until the body is
        filled in.
    """
    # ...
def train_loop(
    env: gym.Env,
    behavior_policy: Callable,
    gvfs: List[GVF],
    training_algo: TrainingAlgorithm,
    config: 'omegaconf.DictConfig',
):
    """Run training (placeholder; body is elided in this template).

    As written the function has no implementation and returns None.

    Args:
        env: environment to interact with.
        behavior_policy: callable used to act in `env`; its exact signature
            is not shown here -- TODO confirm.
        gvfs: the GVFs to train.
        training_algo: algorithm object used to update the GVFs.
        config: top level config.
    """
    # ...
def setup_wandb(config: 'omegaconf.DictConfig') -> 'omegaconf.DictConfig':
    """Initialize Weights & Biases logging.

    Starts a run in project `config.wandb.project`, in 'online' mode when
    `config.wandb.enabled` is truthy and 'disabled' mode otherwise, then
    records the config on the run.

    Args:
        config: top level config.

    Returns:
        The same `config` object that was passed in, unchanged.
    """
    wandb_mode = 'online' if config.wandb.enabled else 'disabled'
    wandb.init(project=config.wandb.project, mode=wandb_mode)

    # Log plain containers: interpolations are resolved first, and a
    # still-missing mandatory value raises instead of being logged.
    plain_config = omegaconf.OmegaConf.to_container(
        config, resolve=True, throw_on_missing=True)
    wandb.config.update(plain_config)

    return config
### Main setup and experiment launch code ### | |
@hydra.main(config_path='conf', config_name='defaults')
def main(config: 'omegaconf.DictConfig') -> None:
    """Experiment entry point: set up logging, then run training.

    Args:
        config: top level config composed by Hydra from `conf/defaults`.
    """
    setup_wandb(config)

    # Setup environment, data, models, etc...
    # NOTE(review): `env`, `behavior_policy`, `gvfs` and `training_alg` are not
    # defined in the visible code; the elided setup step above must create
    # them before the call below can run.

    # Run training
    train_loop(
        env=env,
        behavior_policy=behavior_policy,
        gvfs=gvfs,
        training_algo=training_alg,
        config=config,
    )

    if config.wandb.enabled:
        wandb.finish()


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Weights & Biases sweep definition: grid search over optimizer type and seed.
program: run_experiment.py
method: grid
project: project-name
name: example-sweep-name
command:
  - ${env}
  - python
  - ${program}
  # Passes parameters as `key=value` (no leading hyphens), i.e. in the form
  # Hydra expects for command-line overrides.
  - ${args_no_hyphens}
metric:
  name: loss
  goal: minimize
parameters:
  wandb.enabled:
    value: True
  device:
    value: cuda
  # Override the nested `optimizer.type` key that the experiment code reads;
  # a bare `optimizer` key would replace the whole optimizer config node.
  optimizer.type:
    values: ['sgd', 'adam', 'rmsprop']
  seed:
    values: [0, 1, 2, 3, 4]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment