Julia implementation of Adam optimizer
module Adamopt
# This is a module implementing vanilla Adam (https://arxiv.org/abs/1412.6980).

export Adam, step!

# Struct containing all necessary info
mutable struct Adam
  theta::AbstractArray{Float64} # Parameter array
  loss::Function                # Loss function
  grad::Function                # Gradient function
  m::AbstractArray{Float64}     # First moment
  v::AbstractArray{Float64}     # Second moment
  b1::Float64                   # Exp. decay first moment
  b2::Float64                   # Exp. decay second moment
  a::Float64                    # Step size
  eps::Float64                  # Epsilon for stability
  t::Int                        # Time step (iteration)
end

# Outer constructor
function Adam(theta::AbstractArray{Float64}, loss::Function, grad::Function)
  m   = zeros(size(theta))
  v   = zeros(size(theta))
  b1  = 0.9
  b2  = 0.999
  a   = 0.001
  eps = 1e-8
  t   = 0
  Adam(theta, loss, grad, m, v, b1, b2, a, eps, t)
end

# Step function with optional keyword arguments for the data passed to grad()
function step!(opt::Adam; data...)
  opt.t += 1
  gt    = opt.grad(opt.theta; data...)
  opt.m = opt.b1 .* opt.m + (1 - opt.b1) .* gt
  opt.v = opt.b2 .* opt.v + (1 - opt.b2) .* gt .^ 2
  mhat = opt.m ./ (1 - opt.b1^opt.t)
  vhat = opt.v ./ (1 - opt.b2^opt.t)
  opt.theta -= opt.a .* (mhat ./ (sqrt.(vhat) .+ opt.eps))
end

end
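For reference, step! performs the bias-corrected Adam update from the paper linked above, where g_t is the gradient at the current parameters and the symbols β₁, β₂, α, and ε correspond to the fields b1, b2, a, and eps:

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
\theta_t &= \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
\end{aligned}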
#### EXAMPLE ####
# let's use this Adam implementation for linear regression
using .Adamopt
using Random

N = 100
x = randn(N, 2)
y = x * [9, -3] + randn(N)

# loss function with data kwargs
function mse(b; x = x, y = y)
  res = y - x * b
  res'res
end

# gradient function with data kwargs
function grad(b; x = x, y = y)
  -2 .* x' * (y - x * b)
end

# deterministic adam: full data at every step
dopt = Adam([0.0, 0.0], mse, grad)
dopt.a = 0.01
for i = 1:5000
  step!(dopt)
  println("Step: ", dopt.t, " | Loss: ", dopt.loss(dopt.theta))
end

# stochastic adam: random minibatches of size `batch`
sopt = Adam([0.0, 0.0], mse, grad)
sopt.a = 0.01
batch = 12
epochs = 350
for e = 1:epochs
  pidx = Random.randperm(N)
  while length(pidx) > 0
    idx = [pop!(pidx) for i in 1:batch if length(pidx) > 0]
    step!(sopt; x = x[idx, :], y = y[idx]) # y is a vector, so index without a trailing colon to keep theta a vector
  end
  println("Step: ", sopt.t, " | Loss: ", sopt.loss(sopt.theta))
end
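As a quick sanity check (illustrative, not part of the original gist), both optimizers should recover coefficients close to the generating values [9, -3]; the least-squares solution x \ y gives a reference point:

# Sanity check (illustrative): compare fitted coefficients
println("OLS solution:       ", x \ y)
println("Deterministic Adam: ", dopt.theta)
println("Stochastic Adam:    ", sopt.theta)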
To use the Adamopt module from a separate script, bring its definition into the session first, for example by saving the module code to a file and calling include() on it before the using .Adamopt line.
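A minimal sketch of that setup, assuming the module code above is saved as "Adamopt.jl" in the working directory (the file name and the toy problem are assumptions for illustration, not from the gist):

# Assumes the module above was saved as "Adamopt.jl" (hypothetical file name)
include("Adamopt.jl")   # defines the Adamopt module in the current session
using .Adamopt          # brings Adam and step! into scope

# toy problem: minimize f(theta) = sum(theta .^ 2)
f(theta) = sum(theta .^ 2)
g(theta) = 2 .* theta   # gradient of f
opt = Adam([1.0, -2.0], f, g)
opt.a = 0.01            # larger step size than the default, as in the example above
for i = 1:2000
  step!(opt)
end
println(opt.theta)      # should be close to [0.0, 0.0]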