@ejmejm
Created January 22, 2025 04:24
Autostep
import torch
from torch.optim.optimizer import Optimizer
from typing import Callable, Iterator, Optional


class IDBD(Optimizer):
    """Incremental Delta-Bar-Delta optimizer.

    This is an implementation of the IDBD algorithm adapted for deep neural networks.
    Instead of working with input features directly, it uses gradients with respect
    to parameters and maintains a separate learning rate for each parameter.

    Args:
        params: Iterable of parameters to optimize
        meta_lr: Meta learning rate (default: 0.01)
        init_lr: Initial learning rate (default: 0.01)
        tau: Time constant of the Autostep trace of |grad * h| (default: 1e4)
        weight_decay: L2 weight decay coefficient (default: 0.0)
        autostep: If True, use the normalized Autostep step-size update;
            if False, use the plain IDBD update (default: True)
    """
    def __init__(
        self,
        params: Iterator[torch.Tensor],
        meta_lr: float = 0.01,
        init_lr: float = 0.01,
        tau: float = 1e4,
        weight_decay: float = 0.0,
        autostep: bool = True,
    ):
        # Convert iterator to list so we can check all params
        param_list = list(params)

        # Check that parameters match a linear layer structure
        weights = [p for p in param_list if len(p.shape) == 2]
        biases = [p for p in param_list if len(p.shape) == 1]
        assert len(weights) == 1, 'IDBD optimizer expects exactly one weight matrix (2D tensor)'
        assert len(biases) <= 1, 'IDBD optimizer expects at most one bias vector (1D tensor)'
        if biases:
            assert biases[0].shape[0] == weights[0].shape[0], (
                'Weight matrix and bias vector dimensions do not match')

        defaults = dict(meta_lr=meta_lr, tau=tau)
        super().__init__(param_list, defaults)
        self.weight_decay = weight_decay
        self.autostep = autostep

        # Initialize the per-parameter step-sizes and the h and v traces
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step_size'] = torch.full_like(p.data, init_lr)
                state['h'] = torch.zeros_like(p.data)
                state['v'] = torch.zeros_like(p.data)
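
    # The per-element update rule that step() implements below, with x the input to the
    # linear layer, g the gradient of the loss, and h a decaying trace of recent updates:
    #
    #   IDBD:      alpha <- alpha * exp(meta_lr * g * h)
    #   Autostep:  v     <- max(|g * h|, v + (alpha * x^2 / tau) * (|g * h| - v))
    #              alpha <- alpha * exp(meta_lr * g * h / v)
    #              alpha <- alpha / max(1, sum_j alpha_j * x_j^2)    (weights: per output row; bias: per element)
    #   Both:      h     <- h * max(0, 1 - alpha * x^2) + alpha * g
    #              w     <- w - alpha * (g + weight_decay * w)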
    @torch.no_grad()
    def step(self, inputs: torch.Tensor, closure: Optional[Callable] = None) -> Optional[float]:
        """Performs a single optimization step.

        Args:
            inputs: Input tensor to the linear model (a linear model is assumed by this optimizer)
            closure: A closure that reevaluates the model and returns the loss

        Returns:
            Optional computed loss from closure
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Note: a faithful deep IDBD implementation would require the right-hand side of
        # equation 12 to be alpha times the second-order derivative of the loss with respect to beta.
        param_updates = []
        for group in self.param_groups:
            meta_lr = group['meta_lr']
            tau = group['tau']

            for p in group['params']:
                if p.grad is None:
                    continue

                # For biases, use 1s for the inputs; for weights, use the actual inputs.
                # Use a per-parameter copy so the original `inputs` tensor is not overwritten.
                if len(p.grad.shape) == 1:
                    param_inputs = torch.ones_like(p.grad)
                elif len(p.grad.shape) == 2:
                    param_inputs = inputs.unsqueeze(0)
                else:
                    raise ValueError(f"Invalid gradient shape: {p.grad.shape}")
                squared_inputs = param_inputs.pow(2)

                grad = p.grad
                state = self.state[p]

                # Get state variables
                step_size = state['step_size']
                h = state['h']
                v = state['v']
                # Calculate and update the step-size (learning rate / alpha)
                if self.autostep:
                    v = torch.max(
                        torch.abs(grad * h),
                        v + 1.0 / tau * step_size * squared_inputs * (torch.abs(grad * h) - v),
                    )
                    new_step_size = step_size * torch.exp(meta_lr * grad * h / v)
                    step_size = torch.where(
                        v != 0,
                        new_step_size,
                        step_size,
                    )

                    # Normalize the step-sizes so the effective step-size is at most 1.
                    # For the 2D weight matrix, sum over the input dimension of each output row;
                    # for the 1D bias, normalize each element on its own.
                    if step_size.dim() == 2:
                        effective_step_size = torch.clamp(
                            torch.sum(step_size * squared_inputs, dim=1, keepdim=True), min=1.0)
                    else:
                        effective_step_size = torch.clamp(step_size * squared_inputs, min=1.0)
                    step_size = step_size / effective_step_size
                else:
                    # Plain IDBD: multiplicative step-size update, equivalent to
                    # beta += meta_lr * grad * h with step_size = exp(beta)
                    step_size = step_size * torch.exp(meta_lr * grad * h)
                # Queue parameter update
                weight_decay_term = self.weight_decay * p.data if self.weight_decay != 0 else 0
                param_update = -step_size * (grad + weight_decay_term)
                param_updates.append((p, param_update))

                # Update h (decaying trace of recent updates)
                h = h * (1 - step_size * squared_inputs).clamp(min=0) + step_size * grad

                # Store updated states
                state['step_size'] = step_size
                state['h'] = h
                state['v'] = v

        # Apply the queued updates and clear gradients
        for p, param_update in param_updates:
            p.add_(param_update)
            p.grad = None

        return loss
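
Below is a minimal usage sketch (not part of the original gist): it assumes the optimizer is applied to a single nn.Linear layer trained online with a squared-error loss, and the layer size, data, and regression target are illustrative stand-ins.

import torch
import torch.nn as nn

# Hypothetical example: a 10-input linear model trained online; the data and target
# below are illustrative only.
model = nn.Linear(10, 1)
optimizer = IDBD(model.parameters(), meta_lr=0.01, init_lr=0.01, autostep=True)

for t in range(1000):
    x = torch.randn(10)                # one input vector per step (a linear model is assumed)
    y = x.sum(dim=0, keepdim=True)     # dummy regression target

    def closure():
        # The optimizer clears gradients itself after each step, so no zero_grad() here.
        loss = 0.5 * (model(x) - y).pow(2).sum()
        loss.backward()
        return loss

    loss = optimizer.step(inputs=x, closure=closure)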