Autostep
import torch
from torch.optim.optimizer import Optimizer
from typing import Callable, Iterator, Optional


class IDBD(Optimizer):
    """Incremental Delta-Bar-Delta optimizer.

    This is an implementation of the IDBD algorithm adapted for deep neural networks.
    Instead of working with input features directly, it uses gradients with respect
    to parameters and maintains separate learning rates for each parameter.

    Args:
        params: Iterable of parameters to optimize (expects a single linear layer)
        meta_lr: Meta learning rate (default: 0.01)
        init_lr: Initial learning rate (default: 0.01)
        tau: Time constant of the Autostep normalization trace v (default: 1e4)
        weight_decay: L2 weight decay coefficient (default: 0.0)
        autostep: If True, use the Autostep step-size update; otherwise use plain IDBD (default: True)
    """
    def __init__(
        self,
        params: Iterator[torch.Tensor],
        meta_lr: float = 0.01,
        init_lr: float = 0.01,
        tau: float = 1e4,
        weight_decay: float = 0.0,
        autostep: bool = True,
    ):
        # Convert iterator to list so we can check all params
        param_list = list(params)

        # Check that parameters match a linear layer structure
        weights = [p for p in param_list if len(p.shape) == 2]
        biases = [p for p in param_list if len(p.shape) == 1]
        assert len(weights) == 1, 'IDBD optimizer expects exactly one weight matrix (2D tensor)'
        assert len(biases) <= 1, 'IDBD optimizer expects at most one bias vector (1D tensor)'
        if biases:
            assert biases[0].shape[0] == weights[0].shape[0], \
                'Weight matrix and bias vector dimensions do not match'

        defaults = dict(meta_lr=meta_lr, tau=tau)
        super().__init__(param_list, defaults)
        self.weight_decay = weight_decay
        self.autostep = autostep

        # Initialize the per-parameter step-size and the h and v traces
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step_size'] = torch.full_like(p.data, init_lr)
                state['h'] = torch.zeros_like(p.data)
                state['v'] = torch.zeros_like(p.data)

    @torch.no_grad()
    def step(self, inputs: torch.Tensor, closure: Optional[Callable] = None) -> Optional[float]:
        """Performs a single optimization step.

        Args:
            inputs: Input tensor to the linear model (a linear model is assumed by this optimizer)
            closure: A closure that reevaluates the model and returns the loss

        Returns:
            Optional computed loss from closure
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Note: a faithful deep IDBD implementation would require the right-hand side
        # of equation 12 to be alpha times the second-order derivative of the loss
        # with respect to beta.

        param_updates = []
        for group in self.param_groups:
            meta_lr = group['meta_lr']
            tau = group['tau']
            for p in group['params']:
                if p.grad is None:
                    continue

                # For biases, broadcast against an input of 1s; for weights, use the
                # actual inputs. A local variable keeps the original `inputs` intact
                # for the other parameter in the group.
                if len(p.grad.shape) == 1:
                    p_inputs = torch.ones_like(p.grad)
                elif len(p.grad.shape) == 2:
                    p_inputs = inputs.unsqueeze(0)
                else:
                    raise ValueError(f"Invalid gradient shape: {p.grad.shape}")
                squared_inputs = p_inputs.pow(2)

                grad = p.grad
                state = self.state[p]

                # Get state variables
                step_size = state['step_size']
                h = state['h']
                v = state['v']
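
                # Autostep update, written per element with gradient g, trace h,
                # input x, step-size alpha, and meta step-size mu (matching the
                # code below):
                #   v     <- max(|g * h|, v + (1 / tau) * alpha * x^2 * (|g * h| - v))
                #   alpha <- alpha * exp(mu * g * h / v)              (where v != 0)
                #   alpha <- alpha / max(1, sum_i alpha_i * x_i^2)    (per output)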

                # Calculate and update step-size (learning rate / alpha)
                if self.autostep:
                    v = torch.max(
                        torch.abs(grad * h),
                        v + 1.0 / tau * step_size * squared_inputs * (torch.abs(grad * h) - v),
                    )
                    new_step_size = step_size * torch.exp(meta_lr * grad * h / v)
                    step_size = torch.where(
                        v != 0,
                        new_step_size,
                        step_size,
                    )

                    # Normalize so the effective step-size per output does not exceed 1
                    if len(p.grad.shape) == 2:
                        # Weights: sum the effective step-size over the input features
                        effective_step_size = torch.clamp(
                            torch.sum(step_size * squared_inputs, dim=1, keepdim=True), min=1.0)
                    else:
                        # Biases: each output has a single unit input
                        effective_step_size = torch.clamp(step_size * squared_inputs, min=1.0)
                    step_size = step_size / effective_step_size
                else:
                    # Plain IDBD: multiplicative (exponentiated-gradient) step-size update
                    step_size = step_size * torch.exp(meta_lr * grad * h)

                # Queue parameter update
                weight_decay_term = self.weight_decay * p.data if self.weight_decay != 0 else 0
                param_update = -step_size * (grad + weight_decay_term)
                param_updates.append((p, param_update))

                # Update h (activation trace)
                h = h * (1 - step_size * squared_inputs).clamp(min=0) + step_size * grad

                # Store updated states
                state['step_size'] = step_size
                state['h'] = h
                state['v'] = v

        for p, param_update in param_updates:
            p.add_(param_update)
            p.grad = None

        return loss
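

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): online linear regression on
    # synthetic data. The dimensions, target function, and hyperparameters
    # below are arbitrary choices, not part of the optimizer itself.
    torch.manual_seed(0)

    in_features = 8
    true_w = torch.randn(in_features)  # hypothetical ground-truth weights
    model = torch.nn.Linear(in_features, 1)
    optimizer = IDBD(model.parameters(), meta_lr=0.01, init_lr=0.05)

    for step_idx in range(1000):
        x = torch.randn(in_features)
        target = true_w @ x

        def closure():
            # The optimizer clears gradients itself after each step, so the
            # closure only needs to compute the loss and call backward().
            pred = model(x).squeeze()
            loss = 0.5 * (pred - target) ** 2
            loss.backward()
            return loss

        # `step` takes the raw inputs because the step-size normalization
        # depends on the squared input features.
        loss = optimizer.step(x, closure)
        if step_idx % 200 == 0:
            print(f'step {step_idx}: loss = {loss.item():.4f}')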