Julia implementation of Adam optimizer
module Adamopt
# This module implements vanilla Adam (https://arxiv.org/abs/1412.6980).

export Adam, step!

# Struct containing all necessary info
mutable struct Adam
  theta::AbstractArray{Float64} # Parameter array
  loss::Function                # Loss function
  grad::Function                # Gradient function
  m::AbstractArray{Float64}     # First moment
  v::AbstractArray{Float64}     # Second moment
  b1::Float64                   # Exp. decay first moment
  b2::Float64                   # Exp. decay second moment
  a::Float64                    # Step size
  eps::Float64                  # Epsilon for stability
  t::Int                        # Time step (iteration)
end

# Outer constructor with the default hyperparameters from the paper
function Adam(theta::AbstractArray{Float64}, loss::Function, grad::Function)
  m   = zeros(size(theta))
  v   = zeros(size(theta))
  b1  = 0.9
  b2  = 0.999
  a   = 0.001
  eps = 1e-8
  t   = 0
  Adam(theta, loss, grad, m, v, b1, b2, a, eps, t)
end

# Step function with optional keyword arguments for the data passed to grad()
function step!(opt::Adam; data...)
  opt.t += 1
  gt = opt.grad(opt.theta; data...)
  # Update biased first and second moment estimates
  opt.m = opt.b1 .* opt.m + (1 - opt.b1) .* gt
  opt.v = opt.b2 .* opt.v + (1 - opt.b2) .* gt .^ 2
  # Bias-corrected moment estimates
  mhat = opt.m ./ (1 - opt.b1^opt.t)
  vhat = opt.v ./ (1 - opt.b2^opt.t)
  # Parameter update
  opt.theta -= opt.a .* (mhat ./ (sqrt.(vhat) .+ opt.eps))
end

end # module
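
A quick usage sketch (not part of the original gist): minimizing a mean-squared-error loss for a small linear regression. The names X, y, loss, and grad are hypothetical choices for illustration; any gradient function that accepts the parameter array plus keyword data will work, since step! forwards its keyword arguments straight to grad.

using .Adamopt  # assumes the Adamopt module above was included in the same session

X = randn(100, 2)                        # hypothetical design matrix
y = X * [1.0, -2.0] .+ 0.1 .* randn(100) # hypothetical noisy targets

# MSE loss and its gradient with respect to theta
loss(theta; X, y) = sum(abs2, y .- X * theta) / length(y)
grad(theta; X, y) = -2.0 .* X' * (y .- X * theta) ./ length(y)

opt = Adam(zeros(2), loss, grad)
for _ in 1:10_000
  step!(opt; X = X, y = y)
end
opt.theta  # should end up near [1.0, -2.0]

Passing the data as keyword arguments means the same optimizer object can be driven with a different mini-batch on each call to step!.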