import torch
from torch.optim.optimizer import Optimizer, required


class LARS(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eta=0.001):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, eta=eta)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            eta = group['eta']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                d_pn = d_p.norm()  # new vs. SGD: gradient norm, used for the trust ratio
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                    d_pn.add_(p.data.norm(), alpha=weight_decay)  # new vs. SGD: include weight decay in the norm
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # new vs. SGD: layer-wise trust ratio eta * ||w|| / (||grad|| + weight_decay * ||w||)
                rho = eta * p.data.norm() / (1e-15 + d_pn)
                # changed vs. SGD: scale the update by the trust ratio
                p.data.add_(d_p, alpha=-group['lr'] * rho)

        return loss
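A minimal usage sketch, not part of the original gist: it wraps a toy linear model's parameters in the LARS optimizer above and runs a single training step. The model, data, and hyperparameter values are placeholders chosen only for illustration.

import torch
import torch.nn as nn

# Hypothetical toy setup to exercise the optimizer; any model and loss would do.
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
                 weight_decay=1e-4, eta=0.001)

inputs = torch.randn(4, 10)
targets = torch.randn(4, 2)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()  # applies the layer-wise scaled update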