Skip to content

Instantly share code, notes, and snippets.

@JanSchm
Last active May 28, 2022 07:54
Show Gist options
  • Save JanSchm/b7986ed4809ec768af0a264a162760a8 to your computer and use it in GitHub Desktop.
Save JanSchm/b7986ed4809ec768af0a264a162760a8 to your computer and use it in GitHub Desktop.
class Time2Vector(Layer):
    """Keras layer embedding a sequence as one linear and one periodic feature per step.

    Resembles Time2Vec: for every timestep ``t`` the input features are first
    collapsed to a scalar ``x_t`` (the mean of the first four feature
    columns), then the layer emits::

        linear_t   = weights_linear[t] * x_t + bias_linear[t]
        periodic_t = sin(weights_periodic[t] * x_t + bias_periodic[t])

    concatenated on the last axis, giving an output of shape
    ``(batch, seq_len, 2)``.

    Args:
        seq_len: Number of timesteps. Each trainable vector has shape
            ``(seq_len,)`` and is multiplied element-wise with the
            ``(batch, time)`` series, so the input's time axis must have
            length ``seq_len`` (or be broadcast-compatible with it).
        **kwargs: Base ``Layer`` keyword arguments such as ``name``,
            ``dtype`` or ``trainable``; forwarded to ``Layer.__init__``.
    """

    def __init__(self, seq_len, **kwargs):
        # Forward kwargs. Previously they were dropped, so `name=`/`dtype=`
        # were silently ignored and `from_config` could not restore them.
        super().__init__(**kwargs)
        self.seq_len = seq_len

    def _add_step_weight(self, name):
        """Create one trainable vector of shape (seq_len,) with Keras' 'uniform' initializer."""
        return self.add_weight(
            name=name,
            shape=(int(self.seq_len),),
            initializer='uniform',
            trainable=True,
        )

    def build(self, input_shape):
        """Create the linear and periodic weight/bias vectors (one entry per timestep)."""
        # Creation order matches the original so existing checkpoints load unchanged.
        self.weights_linear = self._add_step_weight('weight_linear')
        self.bias_linear = self._add_step_weight('bias_linear')
        self.weights_periodic = self._add_step_weight('weight_periodic')
        self.bias_periodic = self._add_step_weight('bias_periodic')
        super().build(input_shape)

    def call(self, x):
        """Return the ``(batch, seq_len, 2)`` linear + periodic time features for ``x``.

        ``x`` must be rank 3 because it is sliced as ``x[:, :, :4]``. The
        original code documents it as ``(batch, seq_len, 5)``.
        """
        # Collapse features to one value per step by averaging columns 0-3.
        # Any column at index 4 or above is ignored.
        # NOTE(review): presumably columns 0-3 are OHLC prices and column 4 is
        # volume, deliberately excluded. Confirm against the data pipeline.
        x = tf.math.reduce_mean(x[:, :, :4], axis=-1)  # (batch, seq_len)

        time_linear = self.weights_linear * x + self.bias_linear
        time_periodic = tf.math.sin(self.weights_periodic * x + self.bias_periodic)

        # Give each a trailing feature axis, then stack: (batch, seq_len, 2).
        return tf.concat(
            [tf.expand_dims(time_linear, axis=-1),
             tf.expand_dims(time_periodic, axis=-1)],
            axis=-1,
        )

    def get_config(self):
        """Return the layer config including ``seq_len`` so saved models can be reloaded."""
        config = super().get_config()
        config.update({'seq_len': self.seq_len})
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment