@tdiggelm
Last active October 25, 2019 08:10
TensorFlow 2.0 compatible version of MultiHeadAttention as used in BERT, from the paper "Attention Is All You Need"; see https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
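For reference (not stated in the gist itself), each attention head computes scaled dot-product attention over learned query, key and value projections, with size_per_head playing the role of d_k:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
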
import tensorflow as tf


class MultiHeadAttention(tf.keras.layers.Layer):
    """Mirrors the implementation from the paper 'Attention Is All You Need'
    and the corresponding source code in
    https://github.com/google-research/bert/blob/master/modeling.py."""

    def __init__(self,
                 size_per_head=16,
                 num_attention_heads=12,
                 dropout_rate=0,
                 activation=None,
                 return_attention_probs=False, **kwargs):
        self.activation = activation
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.dropout_rate = dropout_rate
        self.return_attention_probs = return_attention_probs
        super(MultiHeadAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # One dense projection per role (key/value/query), each producing
        # num_attention_heads * size_per_head features.
        self.K_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.V_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.Q_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.dropout_layer = tf.keras.layers.Dropout(rate=self.dropout_rate)
        super(MultiHeadAttention, self).build(input_shape)

    def call(self, from_tensor, to_tensor=None, attention_mask=None, training=False):
        def reshape_to_matrix(input_tensor):
            """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
            ndims = input_tensor.shape.ndims
            if ndims < 2:
                raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                                 (input_tensor.shape))
            if ndims == 2:
                return input_tensor
            width = input_tensor.shape[-1]
            output_tensor = tf.reshape(input_tensor, [-1, width])
            return output_tensor

        def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                                 seq_length, width):
            output_tensor = tf.reshape(
                input_tensor, [batch_size, seq_length, num_attention_heads, width])
            output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
            return output_tensor

        if to_tensor is None:
            to_tensor = from_tensor

        from_shape = tf.shape(from_tensor)
        to_shape = tf.shape(to_tensor)
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        num_attention_heads = self.num_attention_heads
        size_per_head = self.size_per_head

        from_tensor_2d = reshape_to_matrix(from_tensor)
        to_tensor_2d = reshape_to_matrix(to_tensor)

        # `query_layer` = [B*F, N*H]; `key_layer`, `value_layer` = [B*T, N*H]
        query_layer = self.Q_layer(from_tensor_2d)
        key_layer = self.K_layer(to_tensor_2d)
        value_layer = self.V_layer(to_tensor_2d)

        # `query_layer` = [B, N, F, H]
        query_layer = transpose_for_scores(query_layer, batch_size,
                                           num_attention_heads, from_seq_length,
                                           size_per_head)

        # `key_layer` = [B, N, T, H]
        key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                         to_seq_length, size_per_head)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        # `attention_scores` = [B, N, F, T]
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / tf.math.sqrt(float(size_per_head)))

        if attention_mask is not None:
            # `attention_mask` is expected to have shape [B, T], with 1 for
            # positions to attend to and 0 for masked positions. It is
            # broadcast against the [B, N, F, T] scores, and masked positions
            # receive a large negative bias before the softmax.
            attention_mask = tf.expand_dims(attention_mask, axis=1)
            adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
            adder = tf.expand_dims(adder, 1)
            attention_scores += adder

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
        attention_probs = tf.nn.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout_layer(attention_probs, training=training)

        # `value_layer` = [B, T, N, H]
        value_layer = tf.reshape(
            value_layer,
            [batch_size, to_seq_length, num_attention_heads, size_per_head])

        # `value_layer` = [B, N, T, H]
        value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

        # `context_layer` = [B, N, F, H]
        context_layer = tf.matmul(attention_probs, value_layer)

        # `context_layer` = [B, F, N, H]
        context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

        if not self.return_attention_probs:
            return context_layer
        else:
            return context_layer, attention_probs

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_attention_heads * self.size_per_head)

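
# Illustrative shape check (not part of the original gist; shapes are
# arbitrary). With the defaults above, self-attention over an input of shape
# [batch, seq_len, hidden] yields [batch, seq_len, num_attention_heads * size_per_head]:
#
#   mha = MultiHeadAttention(size_per_head=16, num_attention_heads=12)
#   x = tf.random.normal([2, 10, 64])   # [B, F, hidden]
#   y = mha(x)                          # [2, 10, 12 * 16] = [2, 10, 192]
#
#   mha_p = MultiHeadAttention(return_attention_probs=True)
#   y, probs = mha_p(x)                 # probs: [B, N, F, T] = [2, 12, 10, 10]
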
class MultiHeadAttentionPooler(MultiHeadAttention):
    """Implements a pooling operation based on the MultiHeadAttention
    mechanism."""

    def __init__(self, pooling='max', **kwargs):
        super(MultiHeadAttentionPooler, self).__init__(**kwargs)
        self.pooling = pooling

    def build(self, input_shape):
        if self.pooling == 'average':
            self.pool_layer = tf.keras.layers.GlobalAveragePooling1D()
        elif self.pooling == 'max':
            self.pool_layer = tf.keras.layers.GlobalMaxPooling1D()
        elif self.pooling == 'last':
            self.pool_layer = tf.keras.layers.Lambda(lambda x: x[:, -1, :])
        elif self.pooling == 'first':
            self.pool_layer = tf.keras.layers.Lambda(lambda x: x[:, 0, :])
        else:
            raise ValueError(f'pooling operation {self.pooling} not supported.')
        super(MultiHeadAttentionPooler, self).build(input_shape)

    def call(self, sequences, attention_mask=None, training=False):
        shape = tf.shape(sequences)
        # Pool the sequence into a single "query" token of shape [B, 1, D],
        # then attend from that query back over the full sequence.
        pooled = self.pool_layer(sequences)
        pooled = tf.expand_dims(pooled, 1)
        context = super(MultiHeadAttentionPooler, self).call(
            pooled, sequences, attention_mask, training=training)
        if not self.return_attention_probs:
            context = tf.reshape(
                context, [-1, self.num_attention_heads * self.size_per_head])
            return context
        else:
            context, attention_probs = context
            context = tf.reshape(
                context, [-1, self.num_attention_heads * self.size_per_head])
            attention_probs = tf.reshape(
                attention_probs, [-1, self.num_attention_heads, shape[1]])
            return context, attention_probs
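
Example usage (not from the original gist; a minimal sketch with arbitrary shapes and hyper-parameters) showing the pooler as a sequence encoder inside a small Keras model:

# Hypothetical end-to-end sketch: pool a sequence of 64-dimensional feature
# vectors into a single fixed-size vector, then classify it.
inputs = tf.keras.Input(shape=(None, 64))                # [B, T, 64]
pooled = MultiHeadAttentionPooler(size_per_head=16,
                                  num_attention_heads=4,
                                  pooling='max')(inputs)  # [B, 4 * 16] = [B, 64]
outputs = tf.keras.layers.Dense(2, activation='softmax')(pooled)
model = tf.keras.Model(inputs, outputs)
model.summary()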