TensorFlow 2.0 compatible version of MultiHeadAttention as used in BERT, from the paper "Attention Is All You Need"; see https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
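For reference, each attention head computes scaled dot-product attention, softmax(QK^T / sqrt(H))V, and the per-head results are concatenated. A minimal single-head sketch of that core computation (the function name and shapes below are illustrative, not part of the gist):

import tensorflow as tf

def scaled_dot_product_attention(q, k, v):
    # q: [batch, from_len, head_size]; k, v: [batch, to_len, head_size]
    head_size = tf.cast(tf.shape(k)[-1], tf.float32)
    # Raw scores scaled by sqrt(head_size): [batch, from_len, to_len]
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(head_size)
    probs = tf.nn.softmax(scores, axis=-1)
    # Weighted sum of values: [batch, from_len, head_size]
    return tf.matmul(probs, v)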
import tensorflow as tf


class MultiHeadAttention(tf.keras.layers.Layer):
    """Mirrors the implementation from the paper 'Attention Is All You Need' and
    the corresponding source code in
    https://github.com/google-research/bert/blob/master/modeling.py."""

    def __init__(self,
                 size_per_head=16,
                 num_attention_heads=12,
                 dropout_rate=0,
                 activation=None,
                 return_attention_probs=False,
                 **kwargs):
        self.activation = activation
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.dropout_rate = dropout_rate
        self.return_attention_probs = return_attention_probs
        super(MultiHeadAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.K_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.V_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.Q_layer = tf.keras.layers.Dense(
            self.size_per_head * self.num_attention_heads,
            activation=self.activation)
        self.dropout_layer = tf.keras.layers.Dropout(rate=self.dropout_rate)
        super(MultiHeadAttention, self).build(input_shape)
    def call(self, from_tensor, to_tensor=None, attention_mask=None, training=False):

        def reshape_to_matrix(input_tensor):
            """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
            ndims = input_tensor.shape.ndims
            if ndims < 2:
                raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                                 (input_tensor.shape))
            if ndims == 2:
                return input_tensor
            width = input_tensor.shape[-1]
            output_tensor = tf.reshape(input_tensor, [-1, width])
            return output_tensor

        def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                                 seq_length, width):
            output_tensor = tf.reshape(
                input_tensor, [batch_size, seq_length, num_attention_heads, width])
            output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
            return output_tensor

        if to_tensor is None:
            to_tensor = from_tensor

        from_shape = tf.shape(from_tensor)
        to_shape = tf.shape(to_tensor)
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        num_attention_heads = self.num_attention_heads
        size_per_head = self.size_per_head

        from_tensor_2d = reshape_to_matrix(from_tensor)
        to_tensor_2d = reshape_to_matrix(to_tensor)

        # `query_layer` = [B*F, N*H]; `key_layer` and `value_layer` = [B*T, N*H]
        query_layer = self.Q_layer(from_tensor_2d)
        key_layer = self.K_layer(to_tensor_2d)
        value_layer = self.V_layer(to_tensor_2d)
        # `query_layer` = [B, N, F, H]
        query_layer = transpose_for_scores(query_layer, batch_size,
                                           num_attention_heads, from_seq_length,
                                           size_per_head)

        # `key_layer` = [B, N, T, H]
        key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                         to_seq_length, size_per_head)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        # `attention_scores` = [B, N, F, T]
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / tf.math.sqrt(float(size_per_head)))

        if attention_mask is not None:
            # `attention_mask` is expected to be [B, T] with 1 for positions to
            # attend to and 0 for masked positions; it is broadcast to
            # [B, 1, 1, T] and added as a large negative bias before the softmax.
            attention_mask = tf.expand_dims(attention_mask, axis=1)
            adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
            adder = tf.expand_dims(adder, 1)
            attention_scores += adder

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
        attention_probs = tf.nn.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout_layer(attention_probs, training=training)

        # `value_layer` = [B, T, N, H]
        value_layer = tf.reshape(
            value_layer,
            [batch_size, to_seq_length, num_attention_heads, size_per_head])

        # `value_layer` = [B, N, T, H]
        value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

        # `context_layer` = [B, N, F, H]
        context_layer = tf.matmul(attention_probs, value_layer)

        # `context_layer` = [B, F, N, H]
        context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

        if not self.return_attention_probs:
            return context_layer
        else:
            return context_layer, attention_probs
    def compute_output_shape(self, input_shape):
        # The layer returns `context_layer` = [B, F, N*H].
        return (input_shape[0], input_shape[1],
                self.num_attention_heads * self.size_per_head)
class MultiHeadAttentionPooler(MultiHeadAttention):
    """Implements a pooling operation based on the MultiHeadAttention
    mechanism."""

    def __init__(self, pooling='max', **kwargs):
        super(MultiHeadAttentionPooler, self).__init__(**kwargs)
        self.pooling = pooling

    def build(self, input_shape):
        if self.pooling == 'average':
            self.pool_layer = tf.keras.layers.GlobalAveragePooling1D()
        elif self.pooling == 'max':
            self.pool_layer = tf.keras.layers.GlobalMaxPooling1D()
        elif self.pooling == 'last':
            self.pool_layer = tf.keras.layers.Lambda(lambda x: x[:, -1, :])
        elif self.pooling == 'first':
            self.pool_layer = tf.keras.layers.Lambda(lambda x: x[:, 0, :])
        else:
            raise ValueError(f'pooling operation {self.pooling} not supported.')
        super(MultiHeadAttentionPooler, self).build(input_shape)

    def call(self, sequences, attention_mask=None, training=False):
        shape = tf.shape(sequences)
        # Pool the sequence into a single query of shape [B, 1, D], then attend
        # from that query back over the full sequence.
        pooled = self.pool_layer(sequences)
        pooled = tf.expand_dims(pooled, 1)
        context = super(MultiHeadAttentionPooler, self).call(
            pooled, sequences, attention_mask, training=training)
        if not self.return_attention_probs:
            context = tf.reshape(
                context, [-1, self.num_attention_heads * self.size_per_head])
            return context
        else:
            context, attention_probs = context
            context = tf.reshape(
                context, [-1, self.num_attention_heads * self.size_per_head])
            attention_probs = tf.reshape(
                attention_probs, [-1, self.num_attention_heads, shape[1]])
            return context, attention_probs

    def compute_output_shape(self, input_shape):
        # The pooler collapses the sequence dimension: output = [B, N*H].
        return (input_shape[0], self.num_attention_heads * self.size_per_head)
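A minimal usage sketch (not part of the original gist); the shapes, constructor arguments, and variable names below are arbitrary examples chosen for illustration:

if __name__ == '__main__':
    # Dummy batch: 8 sequences of length 20 with 64 features each.
    x = tf.random.normal([8, 20, 64])
    mask = tf.ones([8, 20], dtype=tf.int32)  # 1 = attend, 0 = masked out

    # Self-attention over the sequence: output shape [8, 20, 12 * 16].
    mha = MultiHeadAttention(size_per_head=16, num_attention_heads=12)
    context = mha(x, attention_mask=mask, training=False)

    # Pool the sequence to a fixed-size vector: output shape [8, 12 * 16].
    pooler = MultiHeadAttentionPooler(pooling='max', size_per_head=16,
                                      num_attention_heads=12)
    pooled = pooler(x, attention_mask=mask, training=False)

    print(context.shape, pooled.shape)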