@JanSchm
Created June 25, 2020 13:26
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

# Multi-head attention: runs n_heads SingleAttention layers in parallel and
# projects their concatenated outputs back to the model width.
# SingleAttention is defined in the companion gist.
class MultiAttention(Layer):
    def __init__(self, d_k, d_v, n_heads):
        super(MultiAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.attn_heads = list()

    def build(self, input_shape):
        for n in range(self.n_heads):
            self.attn_heads.append(SingleAttention(self.d_k, self.d_v))
        # Project the concatenated head outputs back to the input width (7 features here).
        self.linear = Dense(7, kernel_initializer='glorot_uniform',
                            bias_initializer='glorot_uniform')

    def call(self, inputs):
        # Each head attends over the same inputs; concatenate along the feature axis.
        attn = [self.attn_heads[i](inputs) for i in range(self.n_heads)]
        concat_attn = tf.concat(attn, axis=-1)
        multi_linear = self.linear(concat_attn)
        return multi_linear
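
MultiAttention delegates the actual attention computation to SingleAttention, which lives in a companion gist. As a quick sanity check, here is a minimal sketch of one way to exercise the layer, assuming a scaled dot-product SingleAttention with learned query/key/value projections; this stand-in is an assumption for illustration, not the gist author's exact implementation.

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class SingleAttention(Layer):
    # Hypothetical stand-in: scaled dot-product attention over a single input tensor.
    def __init__(self, d_k, d_v):
        super(SingleAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v

    def build(self, input_shape):
        self.query = Dense(self.d_k)
        self.key = Dense(self.d_k)
        self.value = Dense(self.d_v)

    def call(self, inputs):
        q = self.query(inputs)   # (batch, seq_len, d_k)
        k = self.key(inputs)     # (batch, seq_len, d_k)
        v = self.value(inputs)   # (batch, seq_len, d_v)
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(self.d_k, tf.float32))
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, v)

# Smoke test: a batch of 32 sequences, 128 time steps, 7 features.
x = tf.random.normal((32, 128, 7))
mha = MultiAttention(d_k=64, d_v=64, n_heads=3)
print(mha(x).shape)  # (32, 128, 7): Dense(7) maps the 3*64 concatenated features back to 7

Note that Dense(7) hardcodes the output width, presumably to match the 7 input features of the time-series dataset in the accompanying post; a more general layer would read this width from input_shape in build.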