@antorsae
Last active May 6, 2020 08:54
# Keras layer implementation of "Fix your classifier: the marginal value of training the last weight layer"
# by Andres Torrubia, licensed under GPL 3: https://www.gnu.org/licenses/gpl-3.0.en.html
# https://arxiv.org/abs/1801.04540
from keras import backend as K
from keras.engine.topology import Layer
from keras import activations
from keras.initializers import Constant, RandomUniform
import numpy as np
from scipy.linalg import hadamard
import math
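# For reference (illustrative note, not part of the original file): scipy.linalg.hadamard(n)
# returns an n x n matrix of +/-1 entries with mutually orthogonal rows (n must be a power
# of two), e.g.:
#   hadamard(4) == [[ 1,  1,  1,  1],
#                   [ 1, -1,  1, -1],
#                   [ 1,  1, -1, -1],
#                   [ 1, -1, -1,  1]]
# build() below slices the first input_dim rows and output_dim columns of such a matrix and
# keeps it fixed; only a scalar scale (and an optional bias) are trained.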
class HadamardClassifier(Layer):
    def __init__(self, output_dim, activation=None, use_bias=True,
                 l2_normalize=True, output_raw_logits=False, **kwargs):
        self.output_dim = output_dim
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.l2_normalize = l2_normalize
        self.output_raw_logits = output_raw_logits
        super(HadamardClassifier, self).__init__(**kwargs)

    def build(self, input_shape):
        # Smallest power-of-two Hadamard matrix covering both the input and output dims,
        # sliced down to (input_dim, output_dim). The projection is a non-trainable constant.
        hadamard_size = 2 ** int(math.ceil(math.log(max(input_shape[1], self.output_dim), 2)))
        self.hadamard = K.constant(
            value=hadamard(hadamard_size, dtype=np.int8)[:input_shape[1], :self.output_dim])

        init_scale = 1. / math.sqrt(self.output_dim)

        # Only a scalar scale (and optionally a per-class bias) are trained.
        self.scale = self.add_weight(name='scale',
                                     shape=(1,),
                                     initializer=Constant(init_scale),
                                     trainable=True)

        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=(self.output_dim,),
                                        initializer=RandomUniform(-init_scale, init_scale),
                                        trainable=True)

        super(HadamardClassifier, self).build(input_shape)

    def call(self, x, training=None):
        # The training flag is accepted for API compatibility but not used by this layer.
        output = K.l2_normalize(x, axis=-1) if self.l2_normalize else x
        # Pity that K.dot requires both operands to share a dtype; otherwise the
        # Hadamard matrix could stay int8 since its entries are only +/-1.
        output = -self.scale * K.dot(output, self.hadamard)
        if self.output_raw_logits:
            # Un-normalized logits; reusing `output` and the l2 norm would likely be cheaper
            # than a second dot product.
            output_logits = -self.scale * K.dot(x, self.hadamard)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
            if self.output_raw_logits:
                output_logits = K.bias_add(output_logits, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        if self.output_raw_logits:
            return [output, output_logits]
        return output

    def compute_output_shape(self, input_shape):
        if self.output_raw_logits:
            return [(input_shape[0], self.output_dim)] * 2
        return (input_shape[0], self.output_dim)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'l2_normalize': self.l2_normalize,
            'output_raw_logits': self.output_raw_logits,
        }
        base_config = super(HadamardClassifier, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
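A minimal usage sketch (the feature extractor, shapes, and hyperparameters below are illustrative only, assuming the same Keras 2.x API as the layer above):

from keras.models import Model
from keras.layers import Input, Dense

inputs = Input(shape=(784,))                       # e.g. flattened 28x28 images
features = Dense(256, activation='relu')(inputs)   # trainable feature extractor
# The fixed Hadamard projection acts as the final classification layer;
# only its scale and bias are learned.
outputs = HadamardClassifier(10, activation='softmax')(features)

model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])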
@rabahFr

rabahFr commented May 6, 2020

Hello,
Is this up to date with the latest Keras version?
Thank you.

@antorsae
Author

antorsae commented May 6, 2020

No. I've since moved to PyTorch.
