# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""The base attention layer performs all the query key value projections and | |
output projections leaving the implementation of the attention to the inner | |
attention module. | |
The transformer layers, however, are agnostic of the attention implementation | |
and any layer that implements the same interface can substitute for the | |
attention layer. | |
""" | |
from torch.nn import Linear, Module | |
class AttentionLayer(Module): | |
"""Implement the attention layer. Namely project the inputs to multi-head | |
queries, keys and values, call the attention implementation and then | |
reproject the output. | |
    It can be thought of as a decorator (see the decorator design pattern) of
    an attention layer.
Arguments | |
--------- | |
attention: Specific inner attention implementation that just computes a | |
weighted average of values given a similarity of queries and | |
keys. | |
d_model: The input feature dimensionality | |
n_heads: The number of heads for the multi head attention | |
d_keys: The dimensionality of the keys/queries | |
(default: d_model/n_heads) | |
d_values: The dimensionality of the values (default: d_model/n_heads) | |
""" | |
def __init__(self, attention, d_model, n_heads, d_keys=None, | |
d_values=None): | |
super(AttentionLayer, self).__init__() | |
# Fill d_keys and d_values | |
d_keys = d_keys or (d_model//n_heads) | |
d_values = d_values or (d_model//n_heads) | |
self.inner_attention = attention | |
self.query_projection = Linear(d_model, d_keys * n_heads) | |
self.key_projection = Linear(d_model, d_keys * n_heads) | |
self.value_projection = Linear(d_model, d_values * n_heads) | |
self.out_projection = Linear(d_values * n_heads, d_model) | |
self.n_heads = n_heads | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
"""Apply attention to the passed in queries/keys/values after | |
projecting them to multiple heads. | |
In the argument description we make use of the following sizes | |
- N: the batch size | |
- L: The maximum length of the queries | |
- S: The maximum length of the keys (the actual length per sequence | |
is given by the length mask) | |
- D: The input feature dimensionality passed in the constructor as | |
'd_model' | |
Arguments | |
--------- | |
queries: (N, L, D) The tensor containing the queries | |
keys: (N, S, D) The tensor containing the keys | |
values: (N, S, D) The tensor containing the values | |
attn_mask: An implementation of BaseMask that encodes where each | |
query can attend to | |
query_lengths: An implementation of BaseMask that encodes how | |
many queries each sequence in the batch consists of | |
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
Returns | |
------- | |
The new value for each query as a tensor of shape (N, L, D). | |
""" | |
# Project the queries/keys/values | |
queries = self.query_projection(queries) | |
keys = self.key_projection(keys) | |
values = self.value_projection(values) | |
# Reshape them into many heads and compute the attention | |
N, L, D = queries.shape | |
_, S, _ = keys.shape | |
H = self.n_heads | |
new_values = self.inner_attention( | |
queries.view(N, L, H, -1), | |
keys.view(N, S, H, -1), | |
values.view(N, S, H, -1), | |
attn_mask, | |
query_lengths, | |
key_lengths | |
).view(N, L, -1) | |
# Project the output and return | |
return self.out_projection(new_values) | |
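

if __name__ == "__main__":
    # Usage sketch: wrap an inner attention with AttentionLayer and run a
    # forward pass. FullAttention, FullMask and LengthMask are the classes
    # used elsewhere in this gist; the import paths below are assumptions
    # based on the fast_transformers package layout.
    import torch
    from fast_transformers.attention import FullAttention
    from fast_transformers.masking import FullMask, LengthMask

    N, L, S, d_model, n_heads = 2, 10, 12, 64, 4
    layer = AttentionLayer(FullAttention(), d_model, n_heads)

    queries = torch.randn(N, L, d_model)
    keys = torch.randn(N, S, d_model)
    values = torch.randn(N, S, d_model)
    attn_mask = FullMask(L, S)  # every query may attend to every key
    query_lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    key_lengths = LengthMask(torch.full((N,), S, dtype=torch.int64))

    out = layer(queries, keys, values, attn_mask, query_lengths, key_lengths)
    print(out.shape)  # torch.Size([2, 10, 64])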
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement causally masked linear attention.""" | |
import torch | |
from torch.nn import Module | |
from fast_transformers.causal_product import causal_dot_product | |
def elu_feature_map(x): | |
return torch.nn.functional.elu(x) + 1 | |
def causal_linear(Q, K, V): | |
Q = Q.permute(0,2,1,3).contiguous() | |
K = K.permute(0,2,1,3).contiguous() | |
V = V.permute(0,2,1,3).contiguous() | |
V_new = causal_dot_product(Q, K, V) | |
return V_new.permute(0,2,1,3).contiguous() | |
class CausalLinearAttention(Module): | |
"""Implement causally masked attention using dot product of feature maps in | |
O(N D^2) complexity. | |
See fast_transformers.attention.linear_attention.LinearAttention for the | |
general concept of replacing the softmax with feature maps. In addition to | |
that, we also make use of the fact that causal masking is a triangular mask | |
which allows us to apply the masking and still compute the attention in O(N | |
D^2) complexity. | |
Arguments | |
--------- | |
feature_map: callable, a callable that applies the feature map to the | |
last dimension of a tensor (default: elu(x)+1) | |
eps: float, a small number to ensure the numerical stability of the | |
denominator (default: 1e-6) | |
""" | |
def __init__(self, feature_map=None, eps=1e-6): | |
super(CausalLinearAttention, self).__init__() | |
self.feature_map = feature_map or elu_feature_map | |
self.eps = eps | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Apply the feature map to the queries and keys | |
Q = self.feature_map(queries) | |
K = self.feature_map(keys) | |
# Apply the key padding mask and make sure the attn_mask is a | |
# lower triangular causal mask | |
if not attn_mask.lower_triangular: | |
raise RuntimeError(("CausalLinearAttention only supports full " | |
"lower triangular masks")) | |
K = K * key_lengths.float_matrix[:, :, None, None] | |
# TODO: Shall we divide the Q and K with a relatively large number to | |
# avoid numerical instabilities in computing the denominator? | |
# We used to divide each with the max norm of all q and k but | |
# that seems relatively costly for a simple normalization. | |
# Compute the normalizers | |
Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps) | |
# Compute the unnormalized result | |
V = causal_linear( | |
Q, | |
K, | |
values | |
) | |
return V * Z[:, :, :, None] | |
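

if __name__ == "__main__":
    # Reference sketch: a naive O(L^2) computation that should match
    # causal_linear() above, assuming causal_dot_product(Q, K, V) returns,
    # for every position i, the sum over j <= i of (Q_i . K_j) V_j, i.e. the
    # unnormalized numerator of causal linear attention.
    torch.manual_seed(0)
    N, L, H, E, M = 2, 8, 2, 4, 6
    Q = torch.rand(N, L, H, E) + 1.0  # positive, like elu(x) + 1 features
    K = torch.rand(N, L, H, E) + 1.0
    V = torch.randn(N, L, H, M)

    scores = torch.einsum("nlhe,nshe->nhls", Q, K)       # (Q_i . K_j)
    causal = torch.tril(torch.ones(L, L)).view(1, 1, L, L)
    reference = torch.einsum("nhls,nshm->nlhm", scores * causal, V)

    print(torch.allclose(reference, causal_linear(Q, K, V), atol=1e-5))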
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement clustered self attention.""" | |
from math import sqrt | |
import torch | |
import torch.autograd | |
from torch.nn import Dropout, Module | |
from torch.nn.init import normal_ | |
from ..masking import FullMask | |
from ..aggregate import aggregate, broadcast | |
from ..clustering.hamming import cluster | |
from ..hashing import compute_hashes | |
class _GroupQueries(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, Q, clusters, counts): | |
factors = 1/counts.float() | |
q_grouped = aggregate(Q, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return q_grouped | |
@staticmethod | |
def backward(ctx, grad_q_grouped): | |
clusters, factors = ctx.saved_tensors | |
grad_q = broadcast(grad_q_grouped, clusters, factors) | |
return grad_q, None, None | |
class _BroadcastValues(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, v_grouped, clusters, counts): | |
factors = torch.ones_like(counts, dtype=v_grouped.dtype) | |
V = broadcast(v_grouped, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return V | |
@staticmethod | |
def backward(ctx, grad_v): | |
clusters, factors = ctx.saved_tensors | |
grad_v_grouped = aggregate(grad_v, clusters, factors) | |
return grad_v_grouped, None, None | |
class ClusteredAttention(Module): | |
"""Use LSH and clustering in the resulting Hamming space to group queries | |
that will have minimal L2 distance from each other. | |
Given the queries, keys, and values as Q, K, and V respectively, we | |
first cluster the queries in "C" groups and compute the "C" query centroids | |
Q_c. | |
    We now use the centroids Q_c to compute the attention using:
V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V). | |
Now the computed values V'_c are "broadcasted" back to the query members | |
of the corresponding cluster. | |
Arguments | |
--------- | |
clusters: How many clusters to group the queries into | |
        iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32) | |
hash_bias: If true, hamming distance proportional to L2 distance | |
If false, hamming distance proportional to cosine distance | |
(default: True) | |
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, clusters, iterations=10, bits=32, | |
hash_bias=True, softmax_temp=None, dropout_rate=0.1): | |
super(ClusteredAttention, self).__init__() | |
self.clusters = clusters | |
self.iterations = iterations | |
self.bits = bits | |
self.hash_bias = hash_bias | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def _create_query_groups(self, Q, query_lengths): | |
N, H, L, E = Q.shape | |
# Compute the hashes for all the queries | |
planes = Q.new_empty((self.bits, E+1)) | |
normal_(planes) | |
if not self.hash_bias: | |
planes[:, -1] = 0 | |
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) | |
# Cluster the hashes and return the cluster index per query | |
groups = cluster( | |
hashes, | |
query_lengths._lengths.int(), | |
clusters=self.clusters, | |
iterations=self.iterations, | |
bits=self.bits | |
) | |
return groups | |
def _group_queries(self, Q, groups): | |
"""Aggregate the Qs based on the index of cluster they belong to. Make | |
sure to allow for gradient propagation backwards from the grouped | |
queries to each query.""" | |
q_grouped = _GroupQueries.apply(Q, *groups) | |
return q_grouped | |
def _broadcast_values(self, V, groups): | |
"""Broadcast the values back to the correct positions but make sure | |
that the gradient flows properly.""" | |
V_new = _BroadcastValues.apply(V.contiguous(), *groups) | |
V_new = V_new.permute(0, 2, 1, 3).contiguous() | |
return V_new | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Make sure that there is no attention mask | |
assert attn_mask.all_ones, ("Clustered attention cannot use an " | |
"arbitrary attention mask.") | |
queries = queries.permute(0,2,1,3).contiguous() | |
keys = keys.permute(0,2,1,3).contiguous() | |
values = values.permute(0,2,1,3).contiguous() | |
N, H, L, E = queries.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Cluster the queries into groups | |
groups = self._create_query_groups(queries, query_lengths) | |
Q_grouped = self._group_queries(queries, groups) | |
# Compute the attention | |
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) | |
QK = QK + key_lengths.additive_matrix[:, None, None, :] | |
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) | |
V = torch.einsum("nhls,nhsd->nhld", A, values) | |
# Broadcast grouped attention | |
return self._broadcast_values(V, groups) | |
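

if __name__ == "__main__":
    # Usage sketch (shapes only): running this requires the compiled
    # clustering/hashing extensions that ship with fast_transformers, and the
    # masking import path below is an assumption based on the package layout.
    from fast_transformers.masking import FullMask, LengthMask

    N, L, H, E, C = 2, 64, 4, 32, 8
    att = ClusteredAttention(clusters=C)
    q = torch.randn(N, L, H, E)
    k = torch.randn(N, L, H, E)
    v = torch.randn(N, L, H, E)
    attn_mask = FullMask(L, L)  # clustered attention needs an all-ones mask
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

    out = att(q, k, v, attn_mask, lengths, lengths)
    print(out.shape)  # torch.Size([2, 64, 4, 32])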
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement a self attention that delegates to full attention or another | |
attention depending on the input sequence length.""" | |
import torch | |
from torch.nn import Module | |
from .full_attention import FullAttention | |
class ConditionalFullAttention(Module): | |
""""Delegate to full attention if the input sequence is short. | |
Arguments | |
--------- | |
other_attention: Use the passed attention module if the sequence is | |
longer than 'length_limit'. | |
length_limit: An integer denoting the maximum sequence length to | |
consider. | |
softmax_temp: See fast_transformers.attention.full_attention. | |
dropout_rate: See fast_transformers.attention.full_attention. | |
""" | |
def __init__(self, other_attention, length_limit=512, softmax_temp=None, | |
dropout_rate=0.1): | |
super(ConditionalFullAttention, self).__init__() | |
self.full_attention = FullAttention(softmax_temp, dropout_rate) | |
self.other_attention = other_attention | |
self.length_limit = length_limit | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Extract some shapes to compare with the length limit | |
L = queries.shape[1] | |
S = values.shape[1] | |
if L > self.length_limit or S > self.length_limit: | |
return self.other_attention(queries, keys, values, attn_mask, | |
query_lengths, key_lengths) | |
else: | |
return self.full_attention(queries, keys, values, attn_mask, | |
query_lengths, key_lengths) | |
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement the oracle top-k attention. The top-k keys are exact ones. | |
MultiHeadAttention module. Note that this module is to be used in conjuction | |
with the AttentionLayer in order to work.""" | |
from math import sqrt | |
import torch | |
from torch.nn import Dropout, Module | |
class ExactTopKAttention(Module): | |
"""Implement the oracle top-k softmax attention. | |
Arguments | |
--------- | |
        topk: The number of top keys to attend to (default: 32)
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, topk=32, softmax_temp=None, dropout_rate=0.1): | |
super(ExactTopKAttention, self).__init__() | |
self.topk = topk | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Extract some shapes and compute the temperature | |
N, L, H, E = queries.shape | |
_, S, _, D = values.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Compute the unnormalized attention and apply the masks | |
QK = torch.einsum("nlhe,nshe->nhls", queries, keys) | |
topk = min(self.topk, L) | |
if not attn_mask.all_ones: | |
QK = QK + attn_mask.additive_matrix | |
QK = QK + key_lengths.additive_matrix[:, None, None] | |
topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1) | |
mask = QK.new_ones(QK.shape) * float("-inf") | |
mask[ | |
torch.arange(N, device=QK.device).view(N, 1, 1, 1), | |
torch.arange(H, device=QK.device).view(1, H, 1, 1), | |
torch.arange(L, device=QK.device).view(1, 1, L, 1), | |
topk_idx, | |
] = 0. | |
QK = QK + mask | |
# Compute the attention and the weighted average | |
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) | |
V = torch.einsum("nhls,nshd->nlhd", A, values) | |
# Make sure that what we return is contiguous | |
return V.contiguous() | |
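

if __name__ == "__main__":
    # Sanity sketch: when topk is at least the sequence length, the oracle
    # top-k attention keeps every key and should reduce to plain softmax
    # attention. FullMask and LengthMask are assumed to be importable from
    # fast_transformers.masking.
    from fast_transformers.masking import FullMask, LengthMask

    torch.manual_seed(0)
    N, L, H, E = 2, 6, 2, 8
    q = torch.randn(N, L, H, E)
    k = torch.randn(N, L, H, E)
    v = torch.randn(N, L, H, E)
    att = ExactTopKAttention(topk=L, dropout_rate=0.0)
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    out = att(q, k, v, FullMask(L, L), lengths, lengths)

    ref = torch.softmax(torch.einsum("nlhe,nshe->nhls", q, k) / sqrt(E), dim=-1)
    ref = torch.einsum("nhls,nshd->nlhd", ref, v)
    print(torch.allclose(out, ref, atol=1e-6))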
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement the full attention similar to the one implemented by PyTorch's | |
MultiHeadAttention module. Note that this module is to be used in conjunction
with the `fast_transformers.attention.attention_layer.AttentionLayer` in order | |
to work.""" | |
from math import sqrt | |
import torch | |
from torch.nn import Dropout, Module | |
class FullAttention(Module): | |
"""Implement the scaled dot product attention with softmax. | |
Arguments | |
--------- | |
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, softmax_temp=None, dropout_rate=0.1): | |
super(FullAttention, self).__init__() | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
"""Implements the multihead softmax attention. | |
Arguments | |
--------- | |
queries: (N, L, H, E) The tensor containing the queries | |
keys: (N, S, H, E) The tensor containing the keys | |
values: (N, S, H, D) The tensor containing the values | |
attn_mask: An implementation of BaseMask that encodes where each | |
query can attend to | |
query_lengths: An implementation of BaseMask that encodes how | |
many queries each sequence in the batch consists of | |
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
""" | |
# Extract some shapes and compute the temperature | |
N, L, H, E = queries.shape | |
_, S, _, D = values.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Compute the unnormalized attention and apply the masks | |
QK = torch.einsum("nlhe,nshe->nhls", queries, keys) | |
if not attn_mask.all_ones: | |
QK = QK + attn_mask.additive_matrix | |
QK = QK + key_lengths.additive_matrix[:, None, None] | |
# Compute the attention and the weighted average | |
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) | |
V = torch.einsum("nhls,nshd->nlhd", A, values) | |
# Make sure that what we return is contiguous | |
return V.contiguous() | |
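

if __name__ == "__main__":
    # Sanity sketch: with dropout disabled and all-ones masks, FullAttention
    # should match the textbook softmax(temp * Q K^T) V computed head by
    # head. FullMask and LengthMask are assumed to be importable from
    # fast_transformers.masking.
    from fast_transformers.masking import FullMask, LengthMask

    torch.manual_seed(0)
    N, L, S, H, E = 2, 5, 7, 3, 8
    q = torch.randn(N, L, H, E)
    k = torch.randn(N, S, H, E)
    v = torch.randn(N, S, H, E)
    att = FullAttention(dropout_rate=0.0)
    out = att(q, k, v,
              FullMask(L, S),
              LengthMask(torch.full((N,), L, dtype=torch.int64)),
              LengthMask(torch.full((N,), S, dtype=torch.int64)))

    ref = torch.softmax(torch.einsum("nlhe,nshe->nhls", q, k) / sqrt(E), dim=-1)
    ref = torch.einsum("nhls,nshd->nlhd", ref, v)
    print(torch.allclose(out, ref, atol=1e-6))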
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement improved clustered self attention.""" | |
from math import sqrt | |
import torch | |
import torch.autograd | |
from torch.nn import Dropout, Module | |
from torch.nn.init import normal_ | |
from ..masking import FullMask | |
from ..aggregate import aggregate, broadcast | |
from ..clustering.hamming import cluster | |
from ..hashing import compute_hashes | |
from ..sparse_product import sparse_dot_product, sparse_weighted_average | |
from ..sparse_product import clustered_sparse_dot_product, \ | |
clustered_sparse_weighted_average | |
class _GroupQueries(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, Q, clusters, counts): | |
factors = 1/counts.float() | |
q_grouped = aggregate(Q, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return q_grouped | |
@staticmethod | |
def backward(ctx, grad_q_grouped): | |
clusters, factors = ctx.saved_tensors | |
grad_q = broadcast(grad_q_grouped, clusters, factors) | |
return grad_q, None, None | |
class _BroadcastValues(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, v_grouped, clusters, counts): | |
factors = torch.ones_like(counts, dtype=v_grouped.dtype) | |
V = broadcast(v_grouped, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return V | |
@staticmethod | |
def backward(ctx, grad_v): | |
clusters, factors = ctx.saved_tensors | |
grad_v_grouped = aggregate(grad_v, clusters, factors) | |
return grad_v_grouped, None, None | |
class ImprovedClusteredAttention(Module): | |
""" | |
    Improved clustered attention approximation by recomputing the attention
    for each query with the top-k keys of the corresponding cluster.
Given the queries, keys, and values as Q, K, and V respectively, we | |
first cluster the queries in "C" groups and compute the "C" query centroids | |
Q_c. | |
    We now use the centroids Q_c to identify the top-k keys with the highest
    dot products.
Subsequently, for each query we compute the sparse dot product with | |
the corresponding top-k keys to improve the attention approximation. | |
Arguments | |
--------- | |
clusters: How many clusters to group the queries into | |
        iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32) | |
hash_bias: If true, hamming distance proportional to L2 distance | |
If false, hamming distance proportional to cosine distance | |
(default: True) | |
        topk: Number of top-k keys to use for the improved approximation
              (default: 32)
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, clusters, iterations=10, bits=32, | |
hash_bias=True, topk=32, softmax_temp=None, dropout_rate=0.1): | |
super(ImprovedClusteredAttention, self).__init__() | |
self.clusters = clusters | |
self.iterations = iterations | |
self.bits = bits | |
self.hash_bias = hash_bias | |
self.topk = topk | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def _create_query_groups(self, Q, query_lengths): | |
N, H, L, E = Q.shape | |
# Compute the hashes for all the queries | |
planes = Q.new_empty((self.bits, E+1)) | |
normal_(planes) | |
if not self.hash_bias: | |
planes[:, -1] = 0 | |
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) | |
# Cluster the hashes and return the cluster index per query | |
clusters, counts = cluster( | |
hashes, | |
query_lengths._lengths.int(), | |
clusters=self.clusters, | |
iterations=self.iterations, | |
bits=self.bits | |
) | |
return clusters, counts | |
def _topk_attention(self, Q, K, V, | |
clusters, counts, | |
topk, topk_values, | |
A_bottomk, softmax_temp, | |
query_lengths): | |
"""Return the attention with just the topk heads.""" | |
# Extract some indices | |
N, H, L, E = Q.shape | |
_, _, S, _ = K.shape | |
_, _, C, k = topk.shape | |
# We need to pass the output tensor to initialize to 0 | |
QK = clustered_sparse_dot_product( | |
Q, K, topk, | |
clusters, counts, | |
query_lengths._lengths.int() | |
) | |
# We need to mask the topk dot products if topk > input_length | |
QK = QK.masked_fill( | |
torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k), | |
float("-inf") | |
) | |
A = torch.softmax(softmax_temp * QK, dim=-1) | |
assert A_bottomk.is_contiguous() | |
A_bottomk = broadcast( | |
A_bottomk.unsqueeze(3), | |
clusters, | |
torch.ones_like(counts, dtype=torch.float32) | |
) | |
A = A * (1.0 - A_bottomk) | |
A = self.dropout(A) | |
assert A.is_contiguous() | |
V_new = clustered_sparse_weighted_average(A, V, topk, clusters) | |
return V_new | |
def _broadcast_values(self, V, clusters, counts): | |
"""Broadcast the values back to the correct positions but make sure | |
that the gradient flows properly.""" | |
V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts) | |
return V_new | |
def _bottomk_attention(self, QK, V, clusters, counts, topk, softmax_temp): | |
"""Return the attention with just the bottomk keys.""" | |
N, H, C, S = QK.shape | |
A = torch.softmax(softmax_temp * QK, dim=-1) | |
mask = QK.new_ones(QK.shape) | |
mask[ | |
torch.arange(N, device=QK.device).view(N, 1, 1, 1), | |
torch.arange(H, device=QK.device).view(1, H, 1, 1), | |
torch.arange(C, device=QK.device).view(1, 1, C, 1), | |
topk, | |
] = 0 | |
A = A * mask | |
A_bottomk = A.sum(-1) | |
A = self.dropout(A) | |
# Compute the values | |
V_new = torch.einsum("nhls,nhse->nhle", A, V) | |
# Broadcast the values back depending on the groups | |
V_new = self._broadcast_values(V_new, clusters, counts) | |
return V_new, A_bottomk | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Make sure that there is no attention mask | |
assert attn_mask.all_ones, ("Improved-clustered attention cannot " | |
"use an arbitrary attention mask.") | |
queries = queries.permute(0,2,1,3).contiguous() | |
keys = keys.permute(0,2,1,3).contiguous() | |
values = values.permute(0,2,1,3).contiguous() | |
N, H, L, E = queries.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Cluster the queries into groups | |
clusters, counts = self._create_query_groups(queries, query_lengths) | |
Q_grouped = _GroupQueries.apply(queries, clusters, counts) | |
# Compute the attention | |
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) | |
QK = QK + key_lengths.additive_matrix[:, None, None, :] | |
topk_values, topk = torch.topk(QK, self.topk, sorted=False, dim=-1) | |
assert topk.is_contiguous() | |
# Now compute the attention with only the bottom keys | |
V_bottomk, A_bottomk = self._bottomk_attention( | |
QK, values, | |
clusters, counts, | |
topk, | |
softmax_temp | |
) | |
# Now compute the attention with only the top keys | |
V_topk = self._topk_attention( | |
queries, keys, values, | |
clusters, counts, | |
topk, topk_values, | |
A_bottomk, | |
softmax_temp, | |
query_lengths | |
) | |
V_new = V_topk + V_bottomk | |
return V_new.permute(0, 2, 1, 3).contiguous() | |
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement improved clustered causal self attention.""" | |
from math import sqrt | |
import torch | |
import torch.autograd | |
from torch.nn import Dropout, Module | |
from torch.nn.init import normal_ | |
from ..masking import FullMask | |
from ..aggregate import aggregate, broadcast | |
from ..clustering.hamming import cluster | |
from ..hashing import compute_hashes | |
from ..sparse_product import sparse_dot_product, sparse_weighted_average | |
from ..sparse_product import clustered_sparse_dot_product, \ | |
clustered_sparse_weighted_average | |
class _GroupQueries(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, Q, clusters, counts): | |
factors = 1/counts.float() | |
q_grouped = aggregate(Q, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return q_grouped | |
@staticmethod | |
def backward(ctx, grad_q_grouped): | |
clusters, factors = ctx.saved_tensors | |
grad_q = broadcast(grad_q_grouped, clusters, factors) | |
return grad_q, None, None | |
class _BroadcastValues(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, v_grouped, clusters, counts): | |
factors = torch.ones_like(counts, dtype=v_grouped.dtype) | |
V = broadcast(v_grouped, clusters, factors) | |
ctx.save_for_backward(clusters, factors) | |
return V | |
@staticmethod | |
def backward(ctx, grad_v): | |
clusters, factors = ctx.saved_tensors | |
grad_v_grouped = aggregate(grad_v, clusters, factors) | |
return grad_v_grouped, None, None | |
class ImprovedClusteredCausalAttention(Module): | |
""" | |
    Improved clustered causal attention approximation by recomputing the
    attention for each query with the top-k keys of the corresponding cluster.
Given the queries, keys, and values as Q, K, and V respectively, we | |
first cluster the queries in "C" groups and compute the "C" query centroids | |
Q_c. | |
    We now use the centroids Q_c to identify the top-k keys with the highest
    dot products.
Subsequently, for each query we compute the sparse dot product with | |
the corresponding top-k keys to improve the attention approximation. | |
    The key difference from improved clustered attention is that, because of
    the causal mask, we only use the top-k keys and do not compute attention
    on the bottom-k keys.
Arguments | |
--------- | |
clusters: How many clusters to group the queries into | |
        iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32) | |
hash_bias: If true, hamming distance proportional to L2 distance | |
If false, hamming distance proportional to cosine distance | |
(default: True) | |
        topk: Number of top-k keys to use for the improved approximation
              (default: 32)
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, clusters, iterations=10, bits=32, | |
hash_bias=True, topk=32, softmax_temp=None, dropout_rate=0.1): | |
super(ImprovedClusteredCausalAttention, self).__init__() | |
self.clusters = clusters | |
self.iterations = iterations | |
self.bits = bits | |
self.hash_bias = hash_bias | |
self.topk = topk | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def _create_query_groups(self, Q, query_lengths): | |
N, H, L, E = Q.shape | |
# Compute the hashes for all the queries | |
planes = Q.new_empty((self.bits, E+1)) | |
normal_(planes) | |
if not self.hash_bias: | |
planes[:, -1] = 0 | |
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) | |
# Cluster the hashes and return the cluster index per query | |
clusters, counts = cluster( | |
hashes, | |
query_lengths._lengths.int(), | |
clusters=self.clusters, | |
iterations=self.iterations, | |
bits=self.bits | |
) | |
return clusters, counts | |
def _topk_attention(self, Q, K, V, | |
clusters, counts, | |
topk, topk_values, | |
softmax_temp, | |
query_lengths): | |
"""Return the attention with just the topk heads.""" | |
# Extract some indices | |
N, H, L, E = Q.shape | |
_, _, S, _ = K.shape | |
_, _, C, k = topk.shape | |
# We need to pass the output tensor to initialize to 0 | |
QK = clustered_sparse_dot_product( | |
Q, K, topk, | |
clusters, counts, | |
query_lengths._lengths.int() | |
) | |
# We need to mask the topk dot products if topk > input_length | |
QK = QK.masked_fill( | |
torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k), | |
float("-inf") | |
) | |
# We need to mask out the future | |
assert topk.is_contiguous() | |
topk_broadcast = broadcast( | |
topk.float(), | |
clusters, | |
torch.ones_like(counts, dtype=torch.float32) | |
) | |
QK = QK.masked_fill( | |
topk_broadcast.long() > torch.arange(L, device=QK.device).view(1, 1, L, 1), | |
float("-1e7") | |
) | |
A = torch.softmax(softmax_temp * QK, dim=-1) | |
A = self.dropout(A) | |
assert A.is_contiguous() | |
V_new = clustered_sparse_weighted_average(A, V, topk, clusters) | |
return V_new | |
def _broadcast_values(self, V, clusters, counts): | |
"""Broadcast the values back to the correct positions but make sure | |
that the gradient flows properly.""" | |
V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts) | |
return V_new | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Apply the key padding mask and make sure the attn_mask is a | |
# lower triangular causal mask | |
if not attn_mask.lower_triangular: | |
raise RuntimeError(("ImprovedClusteredCausalAttention only supports full " | |
"lower triangular masks")) | |
queries = queries.permute(0,2,1,3).contiguous() | |
keys = keys.permute(0,2,1,3).contiguous() | |
values = values.permute(0,2,1,3).contiguous() | |
N, H, L, E = queries.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Cluster the queries into groups | |
clusters, counts = self._create_query_groups(queries, query_lengths) | |
Q_grouped = _GroupQueries.apply(queries, clusters, counts) | |
# Compute the attention | |
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) | |
QK = QK + key_lengths.additive_matrix[:, None, None, :] | |
topk_values, topk = torch.topk(QK, self.topk, sorted=False, dim=-1) | |
assert topk.is_contiguous() | |
# Compute the attention with only the top keys | |
V_topk = self._topk_attention( | |
queries, keys, values, | |
clusters, counts, | |
topk, topk_values, | |
softmax_temp, | |
query_lengths | |
) | |
V_new = V_topk | |
return V_new.permute(0, 2, 1, 3).contiguous() | |
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement unmasked linear attention.""" | |
import torch | |
from torch.nn import Module | |
def elu_feature_map(x): | |
return torch.nn.functional.elu(x) + 1 | |
class LinearAttention(Module): | |
"""Implement unmasked attention using dot product of feature maps in | |
O(N D^2) complexity. | |
Given the queries, keys and values as Q, K, V instead of computing | |
V' = softmax(Q.mm(K.t()), dim=-1).mm(V), | |
we make use of a feature map function Φ(.) and perform the following | |
computation | |
V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V). | |
The above can be computed in O(N D^2) complexity where D is the | |
dimensionality of Q, K and V and N is the sequence length. Depending on the | |
    feature map, however, the expressiveness of the attention might be limited.
Arguments | |
--------- | |
feature_map: callable, a callable that applies the feature map to the | |
last dimension of a tensor (default: elu(x)+1) | |
eps: float, a small number to ensure the numerical stability of the | |
denominator (default: 1e-6) | |
""" | |
def __init__(self, feature_map=None, eps=1e-6): | |
super(LinearAttention, self).__init__() | |
self.feature_map = feature_map or elu_feature_map | |
self.eps = eps | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Apply the feature map to the queries and keys | |
Q = self.feature_map(queries) | |
K = self.feature_map(keys) | |
# Apply the key padding mask and make sure that the attn_mask is | |
# all_ones | |
if not attn_mask.all_ones: | |
raise RuntimeError(("LinearAttention does not support arbitrary " | |
"attention masks")) | |
K = K * key_lengths.float_matrix[:, :, None, None] | |
# Compute the KV matrix, namely the dot product of keys and values so | |
# that we never explicitly compute the attention matrix and thus | |
# decrease the complexity | |
KV = torch.einsum("nshd,nshm->nhmd", K, values) | |
# Compute the normalizer | |
Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) | |
# Finally compute and return the new values | |
V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) | |
return V.contiguous() | |
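

if __name__ == "__main__":
    # Equivalence sketch: the einsum chain above should equal the explicit
    # (and O(L S)) attention matrix Φ(Q) Φ(K)^T, row-normalized and applied
    # to V. FullMask and LengthMask are assumed to be importable from
    # fast_transformers.masking.
    from fast_transformers.masking import FullMask, LengthMask

    torch.manual_seed(0)
    N, L, S, H, E = 2, 6, 9, 2, 4
    q = torch.randn(N, L, H, E)
    k = torch.randn(N, S, H, E)
    v = torch.randn(N, S, H, E)
    att = LinearAttention()
    out = att(q, k, v,
              FullMask(L, S),
              LengthMask(torch.full((N,), L, dtype=torch.int64)),
              LengthMask(torch.full((N,), S, dtype=torch.int64)))

    Q, K = elu_feature_map(q), elu_feature_map(k)
    A = torch.einsum("nlhe,nshe->nhls", Q, K)        # Φ(Q) Φ(K)^T
    A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)     # row-wise normalization
    ref = torch.einsum("nhls,nshd->nlhd", A, v)
    print(torch.allclose(out, ref, atol=1e-5))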
# | |
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ | |
# Written by Angelos Katharopoulos <[email protected]>, | |
# Apoorv Vyas <[email protected]> | |
# | |
"""Implement the Reformer attention from the paper | |
"Reformer the efficient transformer".""" | |
from math import sqrt | |
import torch | |
from torch.nn import Dropout, Module | |
from torch.nn.init import normal_ | |
from ..masking import FullMask | |
class ReformerAttention(Module): | |
"""Implement the attention module of the paper "Reformer the efficient | |
transformer" | |
Arguments | |
--------- | |
chunk_size : Chunk size for each block (default: 32) | |
bits : Number of bits for hashing (default: 8) | |
rounds : Number of rounds of attention computation (default: 4) | |
        masked : If true, the query does not attend to itself (default: False)
softmax_temp: The temperature to use for the softmax attention. | |
(default: 1/sqrt(d_keys) where d_keys is computed at | |
runtime) | |
dropout_rate: The dropout rate to apply to the attention (default: 0.1) | |
""" | |
def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False, | |
softmax_temp=None, dropout_rate=0.1): | |
super(ReformerAttention, self).__init__() | |
self.chunk_size = chunk_size | |
self.bits = bits | |
self.rounds = rounds | |
self.masked = masked | |
self.softmax_temp = softmax_temp | |
self.dropout = Dropout(dropout_rate) | |
def _normalize(self, x): | |
norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x)) | |
x_normed = x / norms.unsqueeze(-1) | |
return x_normed | |
def _look_back(self, x): | |
xshape = x.shape | |
return torch.cat([ | |
x.new_zeros((xshape[0], 1) + xshape[2:]), | |
torch.repeat_interleave(x, 2, dim=1)[:,:-1] | |
], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:]) | |
def _reformer_round(self, Q, K, V, mask, softmax_temp): | |
# Hash the queries | |
N, L, H, E = Q.shape | |
planes = Q.new_empty(self.bits, E) | |
normal_(planes) | |
projected = torch.einsum("nlhe,be->nlhb", K, planes) | |
hashes = torch.argmax( | |
torch.cat([projected, -projected], dim=-1), | |
dim=-1 | |
) | |
# Sort the queries in order to group them | |
group = torch.argsort(hashes, dim=1) | |
invert_group = torch.empty_like(group) | |
batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1) | |
sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1) | |
head_indices = torch.arange(H, device=hashes.device).view(1, 1, H) | |
invert_group[batch_indices, group, head_indices] = sequence_indices | |
group = group.view(N, -1, self.chunk_size, H) | |
invert_group = invert_group.view(N, -1, self.chunk_size, H) | |
batch_indices = batch_indices.unsqueeze(1) | |
head_indices = head_indices.unsqueeze(0) | |
# Reorder Q, V and mask | |
Q_grouped = Q[batch_indices, group, head_indices] | |
K_grouped = K[batch_indices, group, head_indices] | |
V_grouped = V[batch_indices, group, head_indices] | |
mask_grouped = mask[ | |
batch_indices.unsqueeze(1), | |
group.unsqueeze(3), | |
self._look_back(group).unsqueeze(2) | |
] | |
mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf") | |
# When everything is masked just unmask everything because it doesn't | |
# matter what the output is at those positions | |
# This is to avoid inf/nans in the new values at masked positions | |
infmask = torch.isinf(mask_grouped) | |
        infmask = torch.all(infmask, dim=3, keepdim=True)
mask_grouped = mask_grouped.masked_fill(infmask, 0.) | |
# Attention | |
K_grouped = self._look_back(K_grouped) | |
QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped) | |
QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3) | |
A = torch.softmax(softmax_temp * QQ, dim=-1) | |
A = self.dropout(A) | |
# Values | |
V_grouped = self._look_back(V_grouped) | |
V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped) | |
V_new = V_new.contiguous().view(N, -1, H, E) | |
V_new = V_new[batch_indices, invert_group, head_indices] | |
V_new = V_new.contiguous().view(N, L, H, E) | |
return V_new | |
def forward(self, queries, keys, values, attn_mask, query_lengths, | |
key_lengths): | |
# Extract the dimensions of query, key, value | |
N, L, H, E = queries.shape | |
softmax_temp = self.softmax_temp or 1./sqrt(E) | |
# Create the mask | |
mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L) | |
if self.masked: | |
mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9) | |
if not attn_mask.all_ones: | |
mask = mask + attn_mask.additive_matrix.unsqueeze(0) | |
# Get normalized Queries as Keys | |
K = self._normalize(queries) | |
# Zero the masked out keys | |
K = K * key_lengths.float_matrix.view(N, L, 1, 1) | |
V_new = 0 | |
factor = 1/self.rounds | |
for i in range(self.rounds): | |
V_new = V_new + \ | |
factor * self._reformer_round(queries, K, values, mask, softmax_temp) | |
return V_new |
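

if __name__ == "__main__":
    # Usage sketch (shapes only): a forward pass with the default
    # hyper-parameters. The sequence length is chosen as a multiple of
    # chunk_size (32); the masking import path is an assumption based on the
    # fast_transformers package layout.
    from fast_transformers.masking import FullMask, LengthMask

    N, L, H, E = 2, 64, 4, 16
    att = ReformerAttention()
    x = torch.randn(N, L, H, E)
    attn_mask = FullMask(L, L)
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

    out = att(x, x, x, attn_mask, lengths, lengths)
    print(out.shape)  # torch.Size([2, 64, 4, 16])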