@nilsleh
Created September 17, 2024 12:17
Hyperparameter tuning with Optuna
experiment:
  experiment_name: "eurosat"
  exp_dir: "/mnt/SSD2/nils/projects/hyper_param/results"
wandb:
  project: eurosat_hparam
  entity: nleh
  mode: offline
model:
  _target_: lightning_uq_box.uq_methods.DeterministicClassification
  model:
    _target_: timm.create_model
    model_name: resnet18
    num_classes: 10
    drop_rate: 0.1
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: True
    lr: 0.003
    weight_decay: 0.0001
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: True
    T_max: 100
datamodule:
  _target_: torchgeo.datamodules.EuroSATDataModule
  root: /mnt/SSD2/nils/projects/hyper_param/data
  batch_size: 128
  num_workers: 4
  download: True
  bands: ["B04", "B03", "B02"]
trainer:
  _target_: lightning.Trainer
  max_epochs: 100
  accelerator: gpu
  devices: [4]
  gradient_clip_val: 1.0
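Before handing the config to the tuning script below, it can help to check that Hydra can instantiate it. The following is a minimal sanity-check sketch, not part of the gist: it assumes the YAML above is saved as "eurosat.yaml", that lightning_uq_box and timm are installed, and that the UQ wrapper's forward() simply delegates to the wrapped timm model.

# Minimal sanity check (sketch, not part of the gist): load the config above
# and run a single forward pass before launching a full study.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("eurosat.yaml")  # assumed filename for the config above
model = instantiate(cfg.model)        # DeterministicClassification wrapping a timm resnet18

# EuroSAT patches are 64x64 pixels; the config selects the three RGB bands B04/B03/B02.
x = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 10])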
"""Hyperparameter optimization using Optuna and PyTorch Lightning."""
import argparse
import os
from datetime import datetime
from typing import Any

import optuna
from hydra.utils import instantiate
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from omegaconf import OmegaConf
from optuna.integration import PyTorchLightningPruningCallback
from optuna.samplers import TPESampler
def create_experiment_dir(config: dict[str, Any]) -> dict[str, Any]:
    """Create the experiment directory.

    Args:
        config: experiment configuration

    Returns:
        config with updated save_dir and default_root_dir
    """
    os.makedirs(config["experiment"]["exp_dir"], exist_ok=True)
    exp_dir_name = (
        f"{config['experiment']['experiment_name']}"
        f"_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S-%f')}"
    )
    config["experiment"]["experiment_name"] = exp_dir_name
    exp_dir_path = os.path.join(config["experiment"]["exp_dir"], exp_dir_name)
    os.makedirs(exp_dir_path)
    config["experiment"]["save_dir"] = exp_dir_path
    config["trainer"]["default_root_dir"] = exp_dir_path
    return config
def generate_trainer(config: dict[str, Any], trial: optuna.trial.Trial) -> Trainer:
    """Generate a PyTorch Lightning trainer with loggers and callbacks."""
    loggers = [
        CSVLogger(config["experiment"]["save_dir"], name="csv_logs"),
        WandbLogger(
            name=config["experiment"]["experiment_name"],
            save_dir=config["experiment"]["save_dir"],
            project=config["wandb"]["project"],
            entity=config["wandb"]["entity"],
            resume="allow",
            mode=config["wandb"]["mode"],
        ),
    ]
    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        # Prune unpromising trials based on the validation accuracy logged by the model.
        PyTorchLightningPruningCallback(trial, monitor="valAcc"),
    ]
    return instantiate(
        config.trainer,
        default_root_dir=config["experiment"]["save_dir"],
        callbacks=callbacks,
        logger=loggers,
    )
def objective(trial: optuna.trial.Trial, config) -> float:
    """Train one model with sampled hyperparameters and return its validation accuracy."""
    # Model hyperparameters to tune with Optuna.
    config.model.model.drop_rate = trial.suggest_float("dropout", 0.1, 0.5)
    config.model.optimizer.lr = trial.suggest_float(
        "learning_rate", 1e-5, 1e-1, log=True
    )

    # Optimizer choice and its conditional hyperparameters.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"])
    if optimizer_name == "Adam":
        config.model.optimizer._target_ = "torch.optim.Adam"
    elif optimizer_name == "SGD":
        config.model.optimizer._target_ = "torch.optim.SGD"
        config.model.optimizer.weight_decay = trial.suggest_float(
            "weight_decay", 1e-5, 1e-3, log=True
        )
        config.model.optimizer.momentum = trial.suggest_float("momentum", 0.8, 0.99)
    elif optimizer_name == "AdamW":
        config.model.optimizer._target_ = "torch.optim.AdamW"
        config.model.optimizer.weight_decay = trial.suggest_float(
            "weight_decay", 1e-5, 1e-3, log=True
        )

    # Learning-rate schedule choice and its conditional hyperparameters.
    lr_schedule = trial.suggest_categorical(
        "lr_schedule", ["StepLR", "ExponentialLR", "CosineAnnealingLR", "null"]
    )
    if lr_schedule == "StepLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.StepLR",
            "_partial_": True,
            "step_size": trial.suggest_int("step_size", 1, 10),
            "gamma": trial.suggest_float("gamma", 0.1, 0.9),
        }
    elif lr_schedule == "ExponentialLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.ExponentialLR",
            "_partial_": True,
            "gamma": trial.suggest_float("gamma", 0.1, 0.9),
        }
    elif lr_schedule == "CosineAnnealingLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "_partial_": True,
            "T_max": trial.suggest_int("T_max", 10, 50),
        }
    elif lr_schedule == "null":
        config.model.lr_scheduler = None

    model = instantiate(config.model)
    model.input_key = "image"
    model.target_key = "label"

    datamodule = instantiate(config.datamodule)
    trainer = generate_trainer(config, trial)
    trainer.fit(model, datamodule=datamodule)

    return trainer.callback_metrics["valAcc"].item()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Hyperparameter tuning with Optuna and PyTorch Lightning."
    )
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    parser.add_argument(
        "--config",
        "-c",
        help="Path to the YAML experiment config.",
    )
    args = parser.parse_args()

    pruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )

    config = OmegaConf.load(args.config)
    config = create_experiment_dir(config)

    study = optuna.create_study(
        direction="maximize", pruner=pruner, sampler=TPESampler()
    )
    # Stop after 100 trials or 600 seconds, whichever comes first.
    study.optimize(lambda trial: objective(trial, config), n_trials=100, timeout=600)

    print(f"Number of finished trials: {len(study.trials)}")
    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
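One possible follow-up, not part of the gist, is to persist the study results so they can be inspected later. The sketch below assumes the `study` and `config` objects from the main block above are in scope and that pandas is installed (it is required by trials_dataframe).

# Sketch (not part of the gist): export all trials and the best hyperparameters
# found by the study into the experiment directory.
out_dir = config["experiment"]["save_dir"]
study.trials_dataframe().to_csv(os.path.join(out_dir, "trials.csv"), index=False)
OmegaConf.save(
    OmegaConf.create(study.best_trial.params), os.path.join(out_dir, "best_params.yaml")
)

Alternatively, passing a storage URL (e.g. "sqlite:///optuna.db") and a study_name to optuna.create_study persists the study in a database, so it can be resumed across runs or explored with Optuna's visualization tooling.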