Hyperparameter tuning with Optuna
experiment:
  experiment_name: "eurosat"
  exp_dir: "/mnt/SSD2/nils/projects/hyper_param/results"
wandb:
  project: eurosat_hparam
  entity: nleh
  mode: offline
model:
  _target_: lightning_uq_box.uq_methods.DeterministicClassification
  model:
    _target_: timm.create_model
    model_name: resnet18
    num_classes: 10
    drop_rate: 0.1
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: True
    lr: 0.003
    weight_decay: 0.0001
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: True
    T_max: 100
datamodule:
  _target_: torchgeo.datamodules.EuroSATDataModule
  root: /mnt/SSD2/nils/projects/hyper_param/data
  batch_size: 128
  num_workers: 4
  download: True
  bands: ["B04", "B03", "B02"]
trainer:
  _target_: lightning.Trainer
  max_epochs: 100
  accelerator: gpu
  devices: [4]
  gradient_clip_val: 1.0
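Every block in this config with a _target_ key is built with Hydra's instantiate, and _partial_: True makes instantiate return a functools.partial so the optimizer and scheduler can be constructed later, once the model parameters exist. A minimal sketch of how the script below consumes the config; the file name config.yaml is an assumption, and lightning_uq_box and timm need to be installed:

from functools import partial

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # hypothetical file name for the YAML above

# "_partial_: True" yields a functools.partial instead of a constructed object.
optimizer_factory = instantiate(cfg.model.optimizer)
assert isinstance(optimizer_factory, partial)

# Recursively builds the DeterministicClassification module, the timm backbone,
# the loss function, and the optimizer/scheduler partials.
model = instantiate(cfg.model)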
"""Hyperparameter optimization using Optuna and PyTorch Lightning.""" | |
import argparse | |
import os | |
from typing import List | |
from typing import Optional | |
import optuna | |
from optuna.integration import PyTorchLightningPruningCallback | |
from packaging import version | |
import torch | |
from torch import nn | |
from torch import optim | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from torch.utils.data import random_split | |
from torchvision import datasets | |
from torchvision import transforms | |
from optuna.samplers import TPESampler | |
from lightning import Trainer | |
import argparse | |
import os | |
from datetime import datetime | |
from typing import Any | |
import torch | |
from hydra.utils import instantiate | |
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor | |
from lightning.pytorch.loggers import CSVLogger, WandbLogger | |
from omegaconf import OmegaConf | |

def create_experiment_dir(config: dict[str, Any]) -> dict[str, Any]:
    """Create the experiment directory.

    Args:
        config: experiment config

    Returns:
        config with updated save_dir and default_root_dir
    """
    os.makedirs(config["experiment"]["exp_dir"], exist_ok=True)
    exp_dir_name = (
        f"{config['experiment']['experiment_name']}"
        f"_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S-%f')}"
    )
    config["experiment"]["experiment_name"] = exp_dir_name
    exp_dir_path = os.path.join(config["experiment"]["exp_dir"], exp_dir_name)
    os.makedirs(exp_dir_path)
    config["experiment"]["save_dir"] = exp_dir_path
    config["trainer"]["default_root_dir"] = exp_dir_path
    return config

def generate_trainer(config: dict[str, Any], trial: optuna.trial.Trial) -> Trainer:
    """Generate a PyTorch Lightning trainer with loggers and the pruning callback."""
    loggers = [
        CSVLogger(config["experiment"]["save_dir"], name="csv_logs"),
        WandbLogger(
            name=config["experiment"]["experiment_name"],
            save_dir=config["experiment"]["save_dir"],
            project=config["wandb"]["project"],
            entity=config["wandb"]["entity"],
            resume="allow",
            mode=config["wandb"]["mode"],
        ),
    ]
    lr_monitor_callback = LearningRateMonitor(logging_interval="step")
    callbacks = [
        lr_monitor_callback,
        # The monitored name must match a metric the model logs during validation.
        PyTorchLightningPruningCallback(trial, monitor="valAcc"),
    ]
    return instantiate(
        config.trainer,
        default_root_dir=config["experiment"]["save_dir"],
        callbacks=callbacks,
        logger=loggers,
    )

def objective(trial: optuna.trial.Trial, config) -> float:
    """Run one trial: sample hyperparameters, train, and return validation accuracy."""
    # Work on a per-trial copy so sampled keys (e.g. SGD momentum) do not leak into later trials.
    config = copy.deepcopy(config)

    # Model hyperparameters to tune with Optuna
    config.model.model.drop_rate = trial.suggest_float("dropout", 0.1, 0.5)
    config.model.optimizer.lr = trial.suggest_float(
        "learning_rate", 1e-5, 1e-1, log=True
    )

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"])
    if optimizer_name == "Adam":
        config.model.optimizer._target_ = "torch.optim.Adam"
    elif optimizer_name == "SGD":
        config.model.optimizer._target_ = "torch.optim.SGD"
        config.model.optimizer.weight_decay = trial.suggest_float(
            "weight_decay", 1e-5, 1e-3, log=True
        )
        config.model.optimizer.momentum = trial.suggest_float("momentum", 0.8, 0.99)
    elif optimizer_name == "AdamW":
        config.model.optimizer._target_ = "torch.optim.AdamW"
        config.model.optimizer.weight_decay = trial.suggest_float(
            "weight_decay", 1e-5, 1e-3, log=True
        )

    lr_schedule = trial.suggest_categorical(
        "lr_schedule", ["StepLR", "ExponentialLR", "CosineAnnealingLR", "null"]
    )
    if lr_schedule == "StepLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.StepLR",
            "_partial_": True,
            "step_size": trial.suggest_int("step_size", 1, 10),
            "gamma": trial.suggest_float("gamma", 0.1, 0.9),
        }
    elif lr_schedule == "ExponentialLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.ExponentialLR",
            "_partial_": True,
            "gamma": trial.suggest_float("gamma", 0.1, 0.9),
        }
    elif lr_schedule == "CosineAnnealingLR":
        config.model.lr_scheduler = {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "_partial_": True,
            "T_max": trial.suggest_int("T_max", 10, 50),
        }
    elif lr_schedule == "null":
        config.model.lr_scheduler = None

    model = instantiate(config.model)
    model.input_key = "image"
    model.target_key = "label"

    # DataModule and trainer
    datamodule = instantiate(config.datamodule)
    trainer = generate_trainer(config, trial)
    trainer.fit(model, datamodule=datamodule)

    return trainer.callback_metrics["valAcc"].item()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparameter search with Optuna.")
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    parser.add_argument(
        "--config",
        "-c",
        help="Path to the experiment config file.",
    )
    args = parser.parse_args()

    pruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )

    config = OmegaConf.load(args.config)
    config = create_experiment_dir(config)

    study = optuna.create_study(
        direction="maximize", pruner=pruner, sampler=TPESampler()
    )
    study.optimize(lambda trial: objective(trial, config), n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))
    print("Best trial:")
    trial = study.best_trial
    print("  Value: {}".format(trial.value))
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))