
@acl21
Created September 27, 2019 02:38
Adam
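The function below implements the standard Adam update. For reference, here are the rules it computes. For each parameter $\theta$ (here `w` or `b`), let $g_t$ be its full-batch gradient at step $t$ (`dw` or `db`). The moment estimates start at $m_0 = v_0 = 0$, and the step counter corresponds to `i + 1` in the code:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
$$

The bias-corrected values $\hat m_t$ and $\hat v_t$ are used only in the parameter step. The recursions for $m_t$ and $v_t$ carry forward the uncorrected averages.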
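```python
import numpy as np

# Assumes `data` (an iterable of (x, y) pairs) and the helpers
# `grad_w`, `grad_b` and `error` are defined elsewhere; the gist does not show them.
def do_adam():
    w, b, eta, max_epochs = 1, 1, 0.01, 100
    # First (m) and second (v) moment estimates for each parameter
    m_w, m_b, v_w, v_b = 0, 0, 0, 0
    eps, beta1, beta2 = 1e-8, 0.9, 0.99
    for i in range(max_epochs):
        # Accumulate the full-batch gradient over all data points
        dw, db = 0, 0
        for x, y in data:
            dw += grad_w(w, b, x, y)
            db += grad_b(w, b, x, y)
        # Exponentially decaying averages of the gradient and the squared gradient
        m_w = beta1 * m_w + (1 - beta1) * dw
        m_b = beta1 * m_b + (1 - beta1) * db
        v_w = beta2 * v_w + (1 - beta2) * dw**2
        v_b = beta2 * v_b + (1 - beta2) * db**2
        # Bias correction goes into new variables so the running averages are not overwritten
        m_w_hat = m_w / (1 - beta1**(i + 1))
        m_b_hat = m_b / (1 - beta1**(i + 1))
        v_w_hat = v_w / (1 - beta2**(i + 1))
        v_b_hat = v_b / (1 - beta2**(i + 1))
        # Parameter step; eps keeps the denominator away from zero
        w = w - eta * m_w_hat / (np.sqrt(v_w_hat) + eps)
        b = b - eta * m_b_hat / (np.sqrt(v_b_hat) + eps)
        print(error(w, b))
```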
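To try the update end to end without the rest of the notebook, here is a minimal, self-contained sketch. The toy data `X`, `Y` and the helpers `f`, `error`, `grad_w` and `grad_b` are hypothetical stand-ins (a single sigmoid neuron with squared-error loss). They replace the definitions the gist relies on but does not show, so they may differ from the author's.

```python
import numpy as np

# Hypothetical toy data: two (x, y) points for a single sigmoid neuron
X = [0.5, 2.5]
Y = [0.2, 0.9]

def f(w, b, x):
    # Sigmoid neuron output
    return 1.0 / (1.0 + np.exp(-(w * x + b)))

def error(w, b):
    # Half squared error summed over the toy data set
    return sum(0.5 * (f(w, b, x) - y) ** 2 for x, y in zip(X, Y))

def grad_w(w, b, x, y):
    # d/dw of 0.5 * (f - y)^2 for one point
    fx = f(w, b, x)
    return (fx - y) * fx * (1 - fx) * x

def grad_b(w, b, x, y):
    # d/db of 0.5 * (f - y)^2 for one point
    fx = f(w, b, x)
    return (fx - y) * fx * (1 - fx)

def adam(w=1.0, b=1.0, eta=0.01, max_epochs=100,
         beta1=0.9, beta2=0.99, eps=1e-8):
    m_w = m_b = v_w = v_b = 0.0
    for t in range(1, max_epochs + 1):  # t plays the role of i + 1
        # Full-batch gradients
        dw = sum(grad_w(w, b, x, y) for x, y in zip(X, Y))
        db = sum(grad_b(w, b, x, y) for x, y in zip(X, Y))
        # Running (biased) first and second moments
        m_w = beta1 * m_w + (1 - beta1) * dw
        m_b = beta1 * m_b + (1 - beta1) * db
        v_w = beta2 * v_w + (1 - beta2) * dw ** 2
        v_b = beta2 * v_b + (1 - beta2) * db ** 2
        # Bias-corrected copies, used only for the step
        m_w_hat = m_w / (1 - beta1 ** t)
        m_b_hat = m_b / (1 - beta1 ** t)
        v_w_hat = v_w / (1 - beta2 ** t)
        v_b_hat = v_b / (1 - beta2 ** t)
        w -= eta * m_w_hat / (np.sqrt(v_w_hat) + eps)
        b -= eta * m_b_hat / (np.sqrt(v_b_hat) + eps)
    return w, b

if __name__ == "__main__":
    w, b = adam()
    print(error(1.0, 1.0), "->", error(w, b))
```

The sketch returns the final parameters instead of printing the error every epoch, which makes it easier to compare the starting and final loss. Because the bias-corrected ratio $\hat m / \sqrt{\hat v}$ is roughly $\pm 1$ in early steps, each parameter moves by about `eta` per epoch regardless of the raw gradient scale.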