Forked from vankesteren/Adamopt.jl by @mschauer, February 16, 2022.
Julia implementation of Adam optimizer
module Adamopt
# This is a module implementing vanilla Adam (https://arxiv.org/abs/1412.6980).

export Adam, step!

# Struct containing all necessary info
mutable struct Adam
    theta::AbstractArray{Float64} # Parameter array
    loss::Function                # Loss function
    grad::Function                # Gradient function
    m::AbstractArray{Float64}     # First moment
    v::AbstractArray{Float64}     # Second moment
    b1::Float64                   # Exp. decay rate, first moment
    b2::Float64                   # Exp. decay rate, second moment
    a::Float64                    # Step size
    eps::Float64                  # Epsilon for numerical stability
    t::Int                        # Time step (iteration)
end
# Outer constructor; defaults are those recommended in the Adam paper
function Adam(theta::AbstractArray{Float64}, loss::Function, grad::Function)
    m   = zeros(size(theta)) # First moment, initialized to zero
    v   = zeros(size(theta)) # Second moment, initialized to zero
    b1  = 0.9
    b2  = 0.999
    a   = 0.001
    eps = 1e-8
    t   = 0
    Adam(theta, loss, grad, m, v, b1, b2, a, eps, t)
end
# Step function with optional keyword arguments for the data passed to grad()
function step!(opt::Adam; data...)
    opt.t += 1
    gt = opt.grad(opt.theta; data...)                         # Gradient at current parameters
    opt.m = opt.b1 .* opt.m + (1 - opt.b1) .* gt              # Update biased first moment
    opt.v = opt.b2 .* opt.v + (1 - opt.b2) .* gt .^ 2         # Update biased second moment
    mhat = opt.m ./ (1 - opt.b1^opt.t)                        # Bias-corrected first moment
    vhat = opt.v ./ (1 - opt.b2^opt.t)                        # Bias-corrected second moment
    opt.theta -= opt.a .* (mhat ./ (sqrt.(vhat) .+ opt.eps))  # Parameter update
end
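
# For reference, the update implemented above in the paper's notation,
# where g_t is the gradient at step t and all operations are elementwise:
#   m_t     = b1 * m_{t-1} + (1 - b1) * g_t
#   v_t     = b2 * v_{t-1} + (1 - b2) * g_t^2
#   mhat_t  = m_t / (1 - b1^t)
#   vhat_t  = v_t / (1 - b2^t)
#   theta_t = theta_{t-1} - a * mhat_t / (sqrt(vhat_t) + eps)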
end
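
A minimal usage sketch, not from the original gist: the quadratic loss, its analytic gradient, and the target vector c below are made up purely to illustrate the API.

using .Adamopt  # assumes the module above was included in the current session

c = [1.0, -2.0, 3.0]              # hypothetical target vector
loss(x) = sum((x .- c) .^ 2)      # squared distance to c
grad(x) = 2 .* (x .- c)           # its analytic gradient

opt = Adam(zeros(3), loss, grad)  # start from the origin with default hyperparameters
for _ in 1:5000
    step!(opt)                    # one Adam update per call
end
opt.theta                         # should be close to c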