Learning Rate Finder in PyTorch

@colllin · Last active June 4, 2023 11:00

Notes

  • You'll need to bring your own functions that initialize a fresh optimizer, build the dataloaders, return a loss function, etc.
  • You'll want to look through the cell that runs the LR finder (a rough sketch of such a cell follows this list) and consider adjusting...
    • which parameters are tested, e.g. beta1 and wd might not be valid arguments for your optimizer. I recommend varying only one or two parameters at a time.
    • which parameter values are tested, e.g. beta1 in [0.85, 0.95] and wd=0.1. The LR finder is run 3 times for each combination of parameter values, so I recommend restricting to 4 combinations at a time and repeating as necessary.
    • which range of learning rates is tested, e.g. start_lr=1e-6 and end_lr=1e-3. I recommend starting with a wider range for a small initial test, e.g. 1e-6 to 1e0, and then narrowing to the useful range for further tests.
    • how many steps are taken across this range, e.g. steps=100. I recommend roughly 50 steps per order of magnitude, but in general fewer steps run faster, so choose the lowest value that gives you useful results.
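
For concreteness, here is a rough sketch of what that cell might look like. The helper names make_model(), make_optimizer(), make_dataloader(), and make_loss_fn() are hypothetical stand-ins for the setup functions you bring yourself (they are not part of this gist); find_lr is the function defined in the file below.

    import itertools
    import pandas as pd

    results = []
    for beta1, wd in itertools.product([0.85, 0.95], [0.1]):   # keep this to ~4 combinations at a time
        for trial in range(3):                                  # run the finder 3x per combination
            model = make_model()                                # hypothetical: fresh model each run
            optimizer = make_optimizer(model, beta1=beta1, wd=wd)
            history = find_lr(model, make_dataloader(), optimizer, make_loss_fn(),
                              start_lr=1e-6, end_lr=1e-3, steps=100)
            history['beta1'], history['wd'], history['trial'] = beta1, wd, trial
            results.append(history)
    results = pd.concat(results, ignore_index=True)             # then plot loss vs. lr, grouped by (beta1, wd)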

TODO

  • Remove ignite dependency
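
If it helps with this TODO: the core of find_lr below is just one optimizer step per batch while stepping the learning rate along a schedule and recording the loss. A rough, untested sketch of an ignite-free version might look like the following; find_lr_no_ignite is a hypothetical name, the sketch reuses the smooth_curve helper and the torch/numpy/pandas imports from the file below, uses the same loss-explosion stopping rule, and omits the progress bar.

    def find_lr_no_ignite(model, dataloader, optimizer, loss_fn, start_lr=1e-5, end_lr=10, steps=100, beta=0.98,
                          device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        # Learning rates to try, spaced evenly on a log scale.
        lr_schedule = np.logspace(np.log10(start_lr), np.log10(end_lr), num=steps)
        records = []
        model.to(device)
        model.train()
        data_iter = iter(dataloader)
        for lr in lr_schedule:
            try:
                x, y = next(data_iter)
            except StopIteration:
                break                                   # ran out of batches
            for group in optimizer.param_groups:        # apply this step's learning rate
                group['lr'] = lr
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            records.append({'lr': lr, 'loss': loss.item()})
            smoothed = smooth_curve([r['loss'] for r in records], beta)
            if smoothed[-1] > 4 * min(smoothed):        # same "loss is exploding" rule as find_lr
                break
        history = pd.DataFrame(records)
        history['loss_smoothed'] = smooth_curve(history['loss'].tolist(), beta)
        return history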
# See https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

import copy

import torch
import ignite
import numpy as np
import pandas as pd

# Usage:
#     orig_state = serialize_state(model, optimizer)
def serialize_state(model, optimizer):
    # Deep-copy the state dicts: state_dict() returns references to the live tensors,
    # so without a copy the "snapshot" would change as training continues.
    return {
        'model_training': model.training,
        'model_state': copy.deepcopy(model.state_dict()),
        'optim_state': copy.deepcopy(optimizer.state_dict()),
    }


# Usage:
#     restore_state(model, optimizer, orig_state)
def restore_state(model, optimizer, state):
    model.train(state['model_training'])
    model.load_state_dict(state['model_state'])
    optimizer.load_state_dict(state['optim_state'])


def update_lrs(optimizer, lrs):
    # `lrs` may be a scalar or one learning rate per param group.
    lrs = np.broadcast_to(lrs, len(optimizer.param_groups))
    for group, lr in zip(optimizer.param_groups, lrs):
        group['lr'] = lr


def smooth_curve(vals, beta):
    avg_val = 0
    smoothed = []
    for (i, v) in enumerate(vals):
        avg_val = beta * avg_val + (1 - beta) * v
        smoothed.append(avg_val / (1 - beta**(i + 1)))
    return smoothed
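
# Note: smooth_curve() above is a bias-corrected exponential moving average (the same
# debiasing trick used by Adam). Dividing by (1 - beta**(i+1)) keeps the first few points
# from being dragged toward the zero initialization: e.g. smooth_curve([1.0, 1.0, 1.0], 0.98)
# returns [1.0, 1.0, 1.0], whereas the raw running average would give [0.02, 0.0396, 0.0588...].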


def find_lr(model, dataloader, optimizer, loss_fn, start_lr=1e-5, end_lr=10, steps='auto', linear=False, beta=0.98,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), **kwargs):
    """Helps you find an optimal learning rate for a model.

    It uses the technique developed in the 2015 paper
    `Cyclical Learning Rates for Training Neural Networks`_, where
    we simply keep increasing the learning rate from a very small value
    until the loss stops decreasing and starts to explode.

    Args:
        start_lr (float/numpy array): Passing in a numpy array allows you
            to specify learning rates for a learner's layer_groups.
        end_lr (float): The maximum learning rate to try.
        steps (int, optional): How many steps to take while incrementing the LR.
            Defaults to the length of the dataloader.

    Examples:
        As training moves us closer to the optimal weights for a model,
        the optimal learning rate will be smaller. We can take advantage of
        that knowledge and provide lr_find() with a starting learning rate
        1000x smaller than the model's current learning rate as such:

        >> learn.lr_find(lr/1000)

        >> lrs = np.array([1e-4, 1e-3, 1e-2])
        >> learn.lr_find(lrs / 1000)

    Notes:
        lr_find() may finish before going through every batch of examples if
        the loss "explodes" enough.

    .. _Cyclical Learning Rates for Training Neural Networks:
        http://arxiv.org/abs/1506.01186
    """
    trainer = ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=device)
    num_batches = steps if isinstance(steps, int) else len(dataloader)
    lrs = {}
    if linear:
        lrs['queue'] = np.linspace(start_lr, end_lr, num=num_batches)
    else:
        lrs['queue'] = np.logspace(np.log10(start_lr), np.log10(end_lr), num=num_batches)
    lrs['current'] = None
    lrs['history'] = pd.DataFrame([], columns=['lr', 'loss'])

    def step_lr(optimizer):
        # Pop the next learning rate off the schedule and apply it to every param group.
        lrs['current'], lrs['queue'] = lrs['queue'][0], lrs['queue'][1:]
        update_lrs(optimizer, lrs['current'])

    def record_lr_loss(loss):
        record = {'lr': lrs['current'], 'loss': loss}
        # DataFrame.append() was removed in pandas 2.0; build the new row with pd.concat instead.
        lrs['history'] = pd.concat([lrs['history'], pd.DataFrame([record])], ignore_index=True)

    def terminate_on_loss_explosion(trainer):
        smoothed = smooth_curve(lrs['history']['loss'].tolist(), beta)
        if smoothed[-1] > 4 * np.array(smoothed).min():
            print(f'Terminating: Loss is exploding ({smoothed[-1]} > 4 * {np.array(smoothed).min()}).')
            trainer.terminate()

    def terminate_on_empty_queue(trainer):
        if len(lrs['queue']) == 0:
            print('Terminating: Reached end of dataloader or max batches.')
            trainer.terminate()

    trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, lambda trainer: step_lr(optimizer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: record_lr_loss(trainer.state.output))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_loss_explosion(trainer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_empty_queue(trainer))

    # Progress bar; bind_epoch_tqdm is a companion module expected to live alongside this file.
    from .bind_epoch_tqdm import bind_epoch_tqdm
    bind_epoch_tqdm(trainer, desc_fn=lambda trainer: f"lr={lrs['history'].tail(1)['lr'].tolist()[-1]:.3E} loss={lrs['history'].tail(1)['loss'].tolist()[-1]:.3f}")

    trainer.run(dataloader, max_epochs=10)

    lrs['history']['loss_smoothed'] = smooth_curve(lrs['history']['loss'].tolist(), beta)
    return lrs['history']
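

# A rough sketch of how the output might be used: run the finder, restore the original
# weights (find_lr takes real optimizer steps, so it changes the model), and plot the
# smoothed loss against the learning rate on a log axis, as in the blog post linked above.
# plot_lr_find is not part of the original gist; `model`, `dataloader`, `optimizer`, and
# `loss_fn` are assumed to come from your own setup code.
def plot_lr_find(history):
    """Plot smoothed loss vs. learning rate; a common heuristic is to pick an LR roughly an
    order of magnitude below the minimum of this curve, where the loss is still falling steeply."""
    import matplotlib.pyplot as plt
    plt.plot(history['lr'], history['loss_smoothed'])
    plt.xscale('log')
    plt.xlabel('learning rate (log scale)')
    plt.ylabel('smoothed loss')
    plt.show()


# Example:
#     orig_state = serialize_state(model, optimizer)
#     history = find_lr(model, dataloader, optimizer, loss_fn, start_lr=1e-6, end_lr=1e0, steps=100)
#     restore_state(model, optimizer, orig_state)
#     plot_lr_find(history)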