import keras.backend as K
from keras.optimizers import Optimizer
from keras.legacy import interfaces


class WNAdam(Optimizer):
    """WNAdam optimizer.

    Default parameters follow those provided in the original paper.
    A plain-NumPy sketch of the per-parameter update appears after the
    class definition.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 1.

    # References
        - WNGrad paper - https://arxiv.org/abs/1803.02865
    """

    def __init__(self, lr=0.1, beta_1=0.9, **kwargs):
        super(WNAdam, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        t = K.cast(self.iterations, K.floatx()) + 1

        # Algorithm 4 initializations:
        # momentum accumulator is initialized with 0s
        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        # b parameter is initialized with 1s
        bs = [K.ones(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        self.weights = [self.iterations] + ms + bs

        for p, g, m, b in zip(params, grads, ms, bs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            b_t = b + K.square(lr) * K.square(g) / b
            # note: paper has the K.pow as t - 1, but this nans out when t = 1
            p_t = p - (lr / b_t) * m_t / (1 - K.pow(self.beta_1, t))

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(b, b_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1))}
        base_config = super(WNAdam, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
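

'''
A minimal plain-NumPy sketch (not part of the optimizer itself) of the same
per-parameter update that get_updates applies above, handy for sanity-checking
the recursion b_{t+1} = b_t + lr^2 * g_t^2 / b_t. The toy objective
f(p) = p**2 and the step count are assumptions chosen only for illustration.
'''
def wnadam_scalar_demo(lr=0.1, beta_1=0.9, steps=20):
    p, m, b = 5.0, 0.0, 1.0              # parameter, momentum, WNGrad "b" accumulator
    for t in range(1, steps + 1):
        g = 2.0 * p                      # gradient of f(p) = p**2
        m = beta_1 * m + (1.0 - beta_1) * g
        b = b + lr ** 2 * g ** 2 / b
        p = p - (lr / b) * m / (1.0 - beta_1 ** t)
    return p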


'''
Test it on a poorly conditioned problem, a la Ali Rahimi.
A small shape check follows make_data below.
'''
import numpy as np
from keras.models import Model
from keras.layers import Dense, Input


def make_data(xdim=6, ydim=10, nsamples=1000,
              A_condition_number=1e-5):
    # the true map from x to y
    Atrue = np.linspace(1, A_condition_number, ydim
                        ).reshape(-1, 1) * np.random.rand(ydim, xdim)
    # the inputs
    X = np.random.randn(xdim, nsamples)
    # the y's to fit
    Y = Atrue.dot(X)
    return X.T, Y.T
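
'''
A small sanity check (an addition, assuming the default make_data arguments):
rows are samples, so X comes back as (nsamples, xdim) and Y as (nsamples, ydim).
'''
_X, _Y = make_data()
assert _X.shape == (1000, 6) and _Y.shape == (1000, 10)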

def make_model(xdim, wdim, ydim, lr):
    # two stacked linear layers (no bias, no nonlinearity) fit the linear map
    main_input = Input(shape=(xdim,))
    W2x = Dense(wdim, use_bias=False)(main_input)
    W1W2x = Dense(ydim, use_bias=False)(W2x)
    model = Model(inputs=main_input, outputs=W1W2x)
    model.compile(optimizer=WNAdam(lr), loss='mse')
    return model

'''
main
'''
# data
ydim = 10
xdim = 6
wdim = 6
lr = 10

X, Y = make_data(xdim, ydim)
model1 = make_model(xdim, wdim, ydim, lr)
history = model1.fit(X, Y, batch_size=16, epochs=50)
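
# A small follow-up (not in the original script): report the final MSE so the run
# on this ill-conditioned problem can be checked at a glance.
print('final training loss: %g' % history.history['loss'][-1])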