Reformulating the Adam optimizer to gain an intuition about what it's doing.
import jax.numpy as jnp
# Note: `optimizer`, `make_schedule`, and `OptimizerResult` are assumed to come from the
# surrounding code (e.g. jax.example_libraries.optimizers provides `optimizer` and
# `make_schedule`); they aren't defined in this file.

def lerp(a, b, t):
  # Linear interpolation: returns a at t == 0 and b at t == 1.
  return (b - a) * t + a

def bias(i, x, beta):
  # Bias-correction denominator for step i: 1 - beta ** (i + 1), in x's dtype.
  return 1 - jnp.asarray(beta, x.dtype) ** (i + 1)
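
# A quick sanity check that these helpers really are the standard Adam building blocks:
# lerp(g, m, b1) expands to g + (m - g) * b1 == (1 - b1) * g + b1 * m, the usual
# exponential moving average, and bias(i, x, beta) is the usual 1 - beta ** (i + 1)
# correction denominator. The array values below are arbitrary.
def _check_helpers():
  g = jnp.array([0.5, -2.0, 3.0])
  m = jnp.array([0.1, 0.1, 0.1])
  b1 = 0.9
  assert jnp.allclose(lerp(g, m, b1), (1 - b1) * g + b1 * m)
  assert jnp.allclose(bias(0, m, b1), 1 - b1)
  assert jnp.allclose(bias(9, m, b1), 1 - b1 ** 10)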

@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    M0 = jnp.zeros_like(x0)
    return x0, m0, M0
  def update(i, g, state):
    # Think of the gradient like a force (in the Newtonian physics sense).
    # It accelerates each weight in the direction of the gradient.
    # The larger the gradient, the faster the acceleration.
    #
    # Per-weight state:
    #
    # Lowercase letters represent values.
    # Uppercase letters represent squared values.
    #
    # x is the position (i.e. the current weight value).
    #
    # g is the gradient.
    # G is the gradient squared.
    #
    # m is a measurement of the gradient over time (i.e. a moving average of the gradient).
    # M is a measurement of the squared gradient over time (i.e. a moving average of the squared gradient).
    #
    # b1 controls how quickly m accelerates towards g. Defaults to 0.9.
    # b2 controls how quickly M accelerates towards G. Defaults to 0.999.
    #
    x, m, M = state
    G = jnp.square(g)
    # Accelerate each weight by pushing each weight's velocity along its gradient vector.
    m = lerp(g, m, b1)  # Push the velocity (m) toward its gradient (i.e. first moment estimate).
    M = lerp(G, M, b2)  # Push the squared velocity (M) toward its squared gradient (i.e. second moment estimate).
    m_ = m / bias(i, m, b1)  # Bias correction, since the moving averages start at zero.
    M_ = M / bias(i, M, b2)
    # Calculate each weight's new velocity by measuring the gradient (m) and squared gradient (M) over time.
    # Velocity is a change in position (dx) over a change in time (dt).
    dx = m_ / (jnp.sqrt(M_) + eps)  # A change in position (dx).
    dt = -step_size(i)  # A change in time (dt).
    # Calculate each weight's new position by pushing each weight's position along its velocity vector.
    # The position offset is calculated by multiplying the change in position by the change in time (dx * dt).
    # Since it's an offset, we can just add it to the old position to get the new position.
    x = x + dx * dt
    # We're done; return the new state.
    return x, m, M
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
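
# Usage sketch: assuming the @optimizer decorator and make_schedule behave like the ones
# in jax.example_libraries.optimizers (i.e. the decorator wraps the (init, update,
# get_params) triple so the optimizer state is handled as a pytree), the optimizer above
# is driven like any stock JAX optimizer. The toy loss and step size here are arbitrary.
def _usage_example():
  import jax

  def loss(params):
    return jnp.sum(jnp.square(params - 3.0))  # toy quadratic objective

  opt_init, opt_update, get_params = adam(step_size=1e-1)
  opt_state = opt_init(jnp.zeros(4))
  for i in range(200):
    g = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)
  return get_params(opt_state)  # should end up near [3., 3., 3., 3.]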

# A second revision, with fewer comments but more descriptive variable names:
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    M0 = jnp.zeros_like(x0)
    return x0, m0, M0
  def update(i, g, state):
    x, m, M = state
    G = jnp.square(g)
    # Calculate acceleration.
    m = lerp(g, m, b1)  # First moment estimate.
    M = lerp(G, M, b2)  # Second moment estimate.
    m_ = m / bias(i, m, b1)
    M_ = M / bias(i, M, b2)
    # m_ is velocity (a vector).
    # M_ is squared speed (a directionless quantity).
    # sqrt(M_) is the average speed over time (a root-mean-square average).
    velocity = m_
    speed = jnp.sqrt(M_)
    # Divide velocity by speed to get a normalized direction.
    normal = velocity / (speed + eps)
    # Push the weights in the direction of the (normalized) gradient.
    scale = -step_size(i)
    offset = normal * scale
    x = x + offset
    # Return the new state.
    return x, m, M
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
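
# One payoff of the normal * scale reading: because the gradient's magnitude is divided
# out, each weight moves by at most roughly step_size per step, no matter how large the
# raw gradient is. A rough illustration on the very first step (i = 0), where the bias
# corrections cancel and normal is approximately sign(g); the gradient values below are
# arbitrary.
def _step_size_bound_example(b1=0.9, b2=0.999, eps=1e-8):
  g = jnp.array([1e-6, 1.0, 1e6])
  m = lerp(g, jnp.zeros(3), b1)
  M = lerp(jnp.square(g), jnp.zeros(3), b2)
  m_ = m / bias(0, m, b1)
  M_ = M / bias(0, M, b2)
  normal = m_ / (jnp.sqrt(M_) + eps)
  return normal  # roughly [1., 1., 1.]: the gradient's scale has been divided out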

# The original Adam code, for comparison. Functionally identical to the above two versions; it's what I started with.
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    v0 = jnp.zeros_like(x0)
    return x0, m0, v0
  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
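
# A quick way to confirm the reformulation matches the original: run a few steps of both
# update rules on the same synthetic gradients and compare the resulting states. The two
# step functions below inline the update rules from above (the three adam definitions all
# share a name, so they can't be compared directly) and reuse the lerp/bias helpers from
# the top of the file; the inputs are arbitrary.
def _check_equivalence(b1=0.9, b2=0.999, eps=1e-8, lr=1e-3):
  def reformulated_step(i, g, x, m, M):
    m = lerp(g, m, b1)
    M = lerp(jnp.square(g), M, b2)
    m_ = m / bias(i, m, b1)
    M_ = M / bias(i, M, b2)
    return x - lr * m_ / (jnp.sqrt(M_) + eps), m, M

  def original_step(i, g, x, m, v):
    m = (1 - b1) * g + b1 * m
    v = (1 - b2) * jnp.square(g) + b2 * v
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    return x - lr * mhat / (jnp.sqrt(vhat) + eps), m, v

  x0 = jnp.array([0.3, -1.2, 2.5])
  state_a = (x0, jnp.zeros(3), jnp.zeros(3))
  state_b = (x0, jnp.zeros(3), jnp.zeros(3))
  for i in range(5):
    g = jnp.sin(x0 * (i + 1))  # arbitrary stand-in for a real gradient
    state_a = reformulated_step(i, g, *state_a)
    state_b = original_step(i, g, *state_b)
    assert all(jnp.allclose(a, b) for a, b in zip(state_a, state_b))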