#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""The base attention layer performs all the query key value projections and
output projections leaving the implementation of the attention to the inner
attention module.
The transformer layers, however, are agnostic of the attention implementation
and any layer that implements the same interface can substitute for the
attention layer.
"""
from torch.nn import Linear, Module
class AttentionLayer(Module):
"""Implement the attention layer. Namely project the inputs to multi-head
queries, keys and values, call the attention implementation and then
reproject the output.
It can be thought of as a decorator (see the decorator design pattern) of an
attention layer.
Arguments
---------
attention: Specific inner attention implementation that just computes a
weighted average of values given a similarity of queries and
keys.
d_model: The input feature dimensionality
n_heads: The number of heads for the multi head attention
d_keys: The dimensionality of the keys/queries
(default: d_model/n_heads)
d_values: The dimensionality of the values (default: d_model/n_heads)
"""
def __init__(self, attention, d_model, n_heads, d_keys=None,
d_values=None):
super(AttentionLayer, self).__init__()
# Fill d_keys and d_values
d_keys = d_keys or (d_model//n_heads)
d_values = d_values or (d_model//n_heads)
self.inner_attention = attention
self.query_projection = Linear(d_model, d_keys * n_heads)
self.key_projection = Linear(d_model, d_keys * n_heads)
self.value_projection = Linear(d_model, d_values * n_heads)
self.out_projection = Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
"""Apply attention to the passed in queries/keys/values after
projecting them to multiple heads.
In the argument description we make use of the following sizes
- N: the batch size
- L: The maximum length of the queries
- S: The maximum length of the keys (the actual length per sequence
is given by the length mask)
- D: The input feature dimensionality passed in the constructor as
'd_model'
Arguments
---------
queries: (N, L, D) The tensor containing the queries
keys: (N, S, D) The tensor containing the keys
values: (N, S, D) The tensor containing the values
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
query_lengths: An implementation of BaseMask that encodes how
many queries each sequence in the batch consists of
key_lengths: An implementation of BaseMask that encodes how
many keys each sequence in the batch consists of
Returns
-------
The new value for each query as a tensor of shape (N, L, D).
"""
# Project the queries/keys/values
queries = self.query_projection(queries)
keys = self.key_projection(keys)
values = self.value_projection(values)
# Reshape them into many heads and compute the attention
N, L, D = queries.shape
_, S, _ = keys.shape
H = self.n_heads
new_values = self.inner_attention(
queries.view(N, L, H, -1),
keys.view(N, S, H, -1),
values.view(N, S, H, -1),
attn_mask,
query_lengths,
key_lengths
).view(N, L, -1)
# Project the output and return
return self.out_projection(new_values)
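# Usage sketch (added for illustration; not part of the original file). It
# shows AttentionLayer wrapping an inner attention module. It assumes the
# published fast-transformers package layout, i.e. that FullAttention and the
# mask helpers FullMask/LengthMask are importable as below.
if __name__ == "__main__":
    import torch
    from fast_transformers.attention.full_attention import FullAttention
    from fast_transformers.masking import FullMask, LengthMask

    N, L, D, H = 2, 10, 64, 4
    layer = AttentionLayer(FullAttention(), d_model=D, n_heads=H)
    x = torch.randn(N, L, D)
    attn_mask = FullMask(L)                             # every query sees every key
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    out = layer(x, x, x, attn_mask, lengths, lengths)
    print(out.shape)                                    # torch.Size([2, 10, 64])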
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement causally masked linear attention."""
import torch
from torch.nn import Module
from fast_transformers.causal_product import causal_dot_product
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
def causal_linear(Q, K, V):
Q = Q.permute(0,2,1,3).contiguous()
K = K.permute(0,2,1,3).contiguous()
V = V.permute(0,2,1,3).contiguous()
V_new = causal_dot_product(Q, K, V)
return V_new.permute(0,2,1,3).contiguous()
class CausalLinearAttention(Module):
"""Implement causally masked attention using dot product of feature maps in
O(N D^2) complexity.
See fast_transformers.attention.linear_attention.LinearAttention for the
general concept of replacing the softmax with feature maps. In addition to
that, we also make use of the fact that causal masking is a triangular mask
which allows us to apply the masking and still compute the attention in O(N
D^2) complexity.
Arguments
---------
feature_map: callable, a callable that applies the feature map to the
last dimension of a tensor (default: elu(x)+1)
eps: float, a small number to ensure the numerical stability of the
denominator (default: 1e-6)
"""
def __init__(self, feature_map=None, eps=1e-6):
super(CausalLinearAttention, self).__init__()
self.feature_map = feature_map or elu_feature_map
self.eps = eps
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Apply the feature map to the queries and keys
Q = self.feature_map(queries)
K = self.feature_map(keys)
# Apply the key padding mask and make sure the attn_mask is a
# lower triangular causal mask
if not attn_mask.lower_triangular:
raise RuntimeError(("CausalLinearAttention only supports full "
"lower triangular masks"))
K = K * key_lengths.float_matrix[:, :, None, None]
# TODO: Shall we divide the Q and K with a relatively large number to
# avoid numerical instabilities in computing the denominator?
# We used to divide each with the max norm of all q and k but
# that seems relatively costly for a simple normalization.
# Compute the normalizers
Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
# Compute the unnormalized result
V = causal_linear(
Q,
K,
values
)
return V * Z[:, :, :, None]
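# Reference sketch (added for illustration; not part of the original file).
# It mimics in plain PyTorch what the causal_dot_product kernel computes:
# first as an explicit masked O(L^2) product, then as the O(L) streaming form
# with a running sum of K_s V_s^T, which is what gives causal linear
# attention its O(N D^2) complexity. Both forms agree up to numerical error.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, H, L, D, M = 1, 2, 6, 4, 4
    Q = torch.rand(N, H, L, D) + 1          # positive, as after elu(x) + 1
    K = torch.rand(N, H, L, D) + 1
    V = torch.randn(N, H, L, M)

    # Explicit form: compute all similarities and zero out the future.
    A = torch.tril(torch.einsum("nhld,nhsd->nhls", Q, K))
    out_masked = torch.einsum("nhls,nhsm->nhlm", A, V)

    # Streaming form: S_l = sum_{s<=l} K_s V_s^T and out_l = Q_l S_l.
    S = torch.zeros(N, H, D, M)
    outs = []
    for l in range(L):
        S = S + torch.einsum("nhd,nhm->nhdm", K[:, :, l], V[:, :, l])
        outs.append(torch.einsum("nhd,nhdm->nhm", Q[:, :, l], S))
    out_streaming = torch.stack(outs, dim=2)

    print(torch.allclose(out_masked, out_streaming, atol=1e-5))  # True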
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement clustered self attention."""
from math import sqrt
import torch
import torch.autograd
from torch.nn import Dropout, Module
from torch.nn.init import normal_
from ..masking import FullMask
from ..aggregate import aggregate, broadcast
from ..clustering.hamming import cluster
from ..hashing import compute_hashes
class _GroupQueries(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, clusters, counts):
factors = 1/counts.float()
q_grouped = aggregate(Q, clusters, factors)
ctx.save_for_backward(clusters, factors)
return q_grouped
@staticmethod
def backward(ctx, grad_q_grouped):
clusters, factors = ctx.saved_tensors
grad_q = broadcast(grad_q_grouped, clusters, factors)
return grad_q, None, None
class _BroadcastValues(torch.autograd.Function):
@staticmethod
def forward(ctx, v_grouped, clusters, counts):
factors = torch.ones_like(counts, dtype=v_grouped.dtype)
V = broadcast(v_grouped, clusters, factors)
ctx.save_for_backward(clusters, factors)
return V
@staticmethod
def backward(ctx, grad_v):
clusters, factors = ctx.saved_tensors
grad_v_grouped = aggregate(grad_v, clusters, factors)
return grad_v_grouped, None, None
class ClusteredAttention(Module):
"""Use LSH and clustering in the resulting Hamming space to group queries
that will have minimal L2 distance from each other.
Given the queries, keys, and values as Q, K, and V respectively, we
first cluster the queries in "C" groups and compute the "C" query centroids
Q_c.
We now use the centroids Q_c to compute the attention using:
V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V).
Now the computed values V'_c are "broadcasted" back to the query members
of the corresponding cluster.
Arguments
---------
clusters: How many clusters to group the queries into
iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32)
hash_bias: If true, hamming distance proportional to L2 distance
If false, hamming distance proportional to cosine distance
(default: True)
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, clusters, iterations=10, bits=32,
hash_bias=True, softmax_temp=None, dropout_rate=0.1):
super(ClusteredAttention, self).__init__()
self.clusters = clusters
self.iterations = iterations
self.bits = bits
self.hash_bias = hash_bias
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def _create_query_groups(self, Q, query_lengths):
N, H, L, E = Q.shape
# Compute the hashes for all the queries
planes = Q.new_empty((self.bits, E+1))
normal_(planes)
if not self.hash_bias:
planes[:, -1] = 0
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
# Cluster the hashes and return the cluster index per query
groups = cluster(
hashes,
query_lengths._lengths.int(),
clusters=self.clusters,
iterations=self.iterations,
bits=self.bits
)
return groups
def _group_queries(self, Q, groups):
"""Aggregate the Qs based on the index of cluster they belong to. Make
sure to allow for gradient propagation backwards from the grouped
queries to each query."""
q_grouped = _GroupQueries.apply(Q, *groups)
return q_grouped
def _broadcast_values(self, V, groups):
"""Broadcast the values back to the correct positions but make sure
that the gradient flows properly."""
V_new = _BroadcastValues.apply(V.contiguous(), *groups)
V_new = V_new.permute(0, 2, 1, 3).contiguous()
return V_new
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Make sure that there is no attention mask
assert attn_mask.all_ones, ("Clustered attention cannot use an "
"arbitrary attention mask.")
queries = queries.permute(0,2,1,3).contiguous()
keys = keys.permute(0,2,1,3).contiguous()
values = values.permute(0,2,1,3).contiguous()
N, H, L, E = queries.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Cluster the queries into groups
groups = self._create_query_groups(queries, query_lengths)
Q_grouped = self._group_queries(queries, groups)
# Compute the attention
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
QK = QK + key_lengths.additive_matrix[:, None, None, :]
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
V = torch.einsum("nhls,nhsd->nhld", A, values)
# Broadcast grouped attention
return self._broadcast_values(V, groups)
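# Conceptual sketch (added for illustration; not part of the original file).
# It reproduces the clustered-attention idea in plain PyTorch, without the
# custom LSH/Hamming k-means and aggregate/broadcast kernels: queries are
# grouped (here with a trivial round-robin assignment instead of clustering),
# each group is replaced by its centroid, attention is computed once per
# centroid, and the result is copied back to every member of the group.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, H, L, E, C = 1, 1, 12, 8, 3
    Q = torch.randn(N, H, L, E)
    K = torch.randn(N, H, L, E)
    V = torch.randn(N, H, L, E)
    groups = torch.arange(L) % C                        # stand-in cluster id per query
    onehot = torch.nn.functional.one_hot(groups, C).float()          # (L, C)
    centroids = torch.einsum("lc,nhle->nhce", onehot, Q) / onehot.sum(0)[:, None]
    A_c = torch.softmax(torch.einsum("nhce,nhse->nhcs", centroids, K) / E**0.5, dim=-1)
    V_c = torch.einsum("nhcs,nhse->nhce", A_c, V)       # one value row per cluster
    V_new = V_c[:, :, groups]                           # broadcast back to the queries
    print(V_new.shape)                                  # torch.Size([1, 1, 12, 8])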
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement a self attention that delegates to full attention or another
attention depending on the input sequence length."""
import torch
from torch.nn import Module
from .full_attention import FullAttention
class ConditionalFullAttention(Module):
""""Delegate to full attention if the input sequence is short.
Arguments
---------
other_attention: Use the passed attention module if the sequence is
longer than 'length_limit'.
length_limit: An integer denoting the maximum sequence length to
consider.
softmax_temp: See fast_transformers.attention.full_attention.
dropout_rate: See fast_transformers.attention.full_attention.
"""
def __init__(self, other_attention, length_limit=512, softmax_temp=None,
dropout_rate=0.1):
super(ConditionalFullAttention, self).__init__()
self.full_attention = FullAttention(softmax_temp, dropout_rate)
self.other_attention = other_attention
self.length_limit = length_limit
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Extract some shapes to compare with the length limit
L = queries.shape[1]
S = values.shape[1]
if L > self.length_limit or S > self.length_limit:
return self.other_attention(queries, keys, values, attn_mask,
query_lengths, key_lengths)
else:
return self.full_attention(queries, keys, values, attn_mask,
query_lengths, key_lengths)
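# Usage sketch (added for illustration; not part of the original file). The
# delegate attention is assumed to be the LinearAttention module shipped with
# the package (also shown later in this gist); with the default limit of 512,
# the 128-token batch below is routed to the softmax FullAttention path.
if __name__ == "__main__":
    from fast_transformers.attention.linear_attention import LinearAttention
    from fast_transformers.masking import FullMask, LengthMask

    attn = ConditionalFullAttention(LinearAttention(), length_limit=512)
    N, L, H, E = 2, 128, 4, 32
    q = torch.randn(N, L, H, E)
    attn_mask = FullMask(L)
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    out = attn(q, q, q, attn_mask, lengths, lengths)    # L <= 512 -> full attention
    print(out.shape)                                    # torch.Size([2, 128, 4, 32])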
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement the oracle top-k attention. The top-k keys are exact ones.
MultiHeadAttention module. Note that this module is to be used in conjuction
with the AttentionLayer in order to work."""
from math import sqrt
import torch
from torch.nn import Dropout, Module
class ExactTopKAttention(Module):
"""Implement the oracle top-k softmax attention.
Arguments
---------
topk: The top k keys to attend to (default: 32)
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, topk=32, softmax_temp=None, dropout_rate=0.1):
super(ExactTopKAttention, self).__init__()
self.topk = topk
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Extract some shapes and compute the temperature
N, L, H, E = queries.shape
_, S, _, D = values.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
topk = min(self.topk, L)
if not attn_mask.all_ones:
QK = QK + attn_mask.additive_matrix
QK = QK + key_lengths.additive_matrix[:, None, None]
topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1)
mask = QK.new_ones(QK.shape) * float("-inf")
mask[
torch.arange(N, device=QK.device).view(N, 1, 1, 1),
torch.arange(H, device=QK.device).view(1, H, 1, 1),
torch.arange(L, device=QK.device).view(1, 1, L, 1),
topk_idx,
] = 0.
QK = QK + mask
# Compute the attention and the weighted average
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
V = torch.einsum("nhls,nshd->nlhd", A, values)
# Make sure that what we return is contiguous
return V.contiguous()
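# Sketch of the top-k masking trick used above (added for illustration; not
# part of the original file): an additive mask that is -inf everywhere except
# at the k largest scores per row, so the softmax puts all of its mass on the
# exact top-k keys. scatter_ is used here instead of the arange-based
# advanced indexing in forward(); the effect is the same.
if __name__ == "__main__":
    torch.manual_seed(0)
    scores = torch.randn(1, 2, 5, 7)                    # (N, H, L, S)
    k = 3
    topk_idx = torch.topk(scores, k, dim=-1).indices
    mask = torch.full_like(scores, float("-inf"))
    mask.scatter_(-1, topk_idx, 0.0)                    # keep only the k best keys
    A = torch.softmax(scores + mask, dim=-1)
    print((A > 0).sum(-1))                              # each row attends to exactly 3 keys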
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement the full attention similar to the one implemented by PyTorch's
MultiHeadAttention module. Note that this module is to be used in conjunction
with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
to work."""
from math import sqrt
import torch
from torch.nn import Dropout, Module
class FullAttention(Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, softmax_temp=None, dropout_rate=0.1):
super(FullAttention, self).__init__()
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
"""Implements the multihead softmax attention.
Arguments
---------
queries: (N, L, H, E) The tensor containing the queries
keys: (N, S, H, E) The tensor containing the keys
values: (N, S, H, D) The tensor containing the values
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
query_lengths: An implementation of BaseMask that encodes how
many queries each sequence in the batch consists of
key_lengths: An implementation of BaseMask that encodes how
many keys each sequence in the batch consists of
"""
# Extract some shapes and compute the temperature
N, L, H, E = queries.shape
_, S, _, D = values.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
if not attn_mask.all_ones:
QK = QK + attn_mask.additive_matrix
QK = QK + key_lengths.additive_matrix[:, None, None]
# Compute the attention and the weighted average
A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
V = torch.einsum("nhls,nshd->nlhd", A, values)
# Make sure that what we return is contiguous
return V.contiguous()
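# Usage sketch (added for illustration; not part of the original file). It
# assumes the mask helpers of the published package
# (fast_transformers.masking.TriangularCausalMask / LengthMask). The module
# is normally driven through AttentionLayer, which produces the (N, L, H, E)
# head-split tensors expected here.
if __name__ == "__main__":
    from fast_transformers.masking import TriangularCausalMask, LengthMask

    N, L, H, E = 2, 6, 2, 8
    attn = FullAttention()
    q = torch.randn(N, L, H, E)
    attn_mask = TriangularCausalMask(L)                 # each query sees keys s <= l
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    out = attn(q, q, q, attn_mask, lengths, lengths)
    print(out.shape)                                    # torch.Size([2, 6, 2, 8])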
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement improved clustered self attention."""
from math import sqrt
import torch
import torch.autograd
from torch.nn import Dropout, Module
from torch.nn.init import normal_
from ..masking import FullMask
from ..aggregate import aggregate, broadcast
from ..clustering.hamming import cluster
from ..hashing import compute_hashes
from ..sparse_product import sparse_dot_product, sparse_weighted_average
from ..sparse_product import clustered_sparse_dot_product, \
clustered_sparse_weighted_average
class _GroupQueries(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, clusters, counts):
factors = 1/counts.float()
q_grouped = aggregate(Q, clusters, factors)
ctx.save_for_backward(clusters, factors)
return q_grouped
@staticmethod
def backward(ctx, grad_q_grouped):
clusters, factors = ctx.saved_tensors
grad_q = broadcast(grad_q_grouped, clusters, factors)
return grad_q, None, None
class _BroadcastValues(torch.autograd.Function):
@staticmethod
def forward(ctx, v_grouped, clusters, counts):
factors = torch.ones_like(counts, dtype=v_grouped.dtype)
V = broadcast(v_grouped, clusters, factors)
ctx.save_for_backward(clusters, factors)
return V
@staticmethod
def backward(ctx, grad_v):
clusters, factors = ctx.saved_tensors
grad_v_grouped = aggregate(grad_v, clusters, factors)
return grad_v_grouped, None, None
class ImprovedClusteredAttention(Module):
"""
Improved clustered attention approximation by recomputing the attention
for each query with the top-k keys for the corresponding cluster.
Given the queries, keys, and values as Q, K, and V respectively, we
first cluster the queries in "C" groups and compute the "C" query centroids
Q_c.
We now use the centroids Q_c to identify the top-k keys with the highest
dot products.
Subsequently, for each query we compute the sparse dot product with
the corresponding top-k keys to improve the attention approximation.
Arguments
---------
clusters: How many clusters to group the queries into
iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32)
hash_bias: If true, hamming distance proportional to L2 distance
If false, hamming distance proportional to cosine distance
(default: True)
topk: Number of top-k keys used for the improved approximation (default: 32)
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, clusters, iterations=10, bits=32,
hash_bias=True, topk=32, softmax_temp=None, dropout_rate=0.1):
super(ImprovedClusteredAttention, self).__init__()
self.clusters = clusters
self.iterations = iterations
self.bits = bits
self.hash_bias = hash_bias
self.topk = topk
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def _create_query_groups(self, Q, query_lengths):
N, H, L, E = Q.shape
# Compute the hashes for all the queries
planes = Q.new_empty((self.bits, E+1))
normal_(planes)
if not self.hash_bias:
planes[:, -1] = 0
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
# Cluster the hashes and return the cluster index per query
clusters, counts = cluster(
hashes,
query_lengths._lengths.int(),
clusters=self.clusters,
iterations=self.iterations,
bits=self.bits
)
return clusters, counts
def _topk_attention(self, Q, K, V,
clusters, counts,
topk, topk_values,
A_bottomk, softmax_temp,
query_lengths):
"""Return the attention with just the topk heads."""
# Extract some indices
N, H, L, E = Q.shape
_, _, S, _ = K.shape
_, _, C, k = topk.shape
# We need to pass the output tensor to initialize to 0
QK = clustered_sparse_dot_product(
Q, K, topk,
clusters, counts,
query_lengths._lengths.int()
)
# We need to mask the topk dot products if topk > input_length
QK = QK.masked_fill(
torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k),
float("-inf")
)
A = torch.softmax(softmax_temp * QK, dim=-1)
assert A_bottomk.is_contiguous()
A_bottomk = broadcast(
A_bottomk.unsqueeze(3),
clusters,
torch.ones_like(counts, dtype=torch.float32)
)
A = A * (1.0 - A_bottomk)
A = self.dropout(A)
assert A.is_contiguous()
V_new = clustered_sparse_weighted_average(A, V, topk, clusters)
return V_new
def _broadcast_values(self, V, clusters, counts):
"""Broadcast the values back to the correct positions but make sure
that the gradient flows properly."""
V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts)
return V_new
def _bottomk_attention(self, QK, V, clusters, counts, topk, softmax_temp):
"""Return the attention with just the bottomk keys."""
N, H, C, S = QK.shape
A = torch.softmax(softmax_temp * QK, dim=-1)
mask = QK.new_ones(QK.shape)
mask[
torch.arange(N, device=QK.device).view(N, 1, 1, 1),
torch.arange(H, device=QK.device).view(1, H, 1, 1),
torch.arange(C, device=QK.device).view(1, 1, C, 1),
topk,
] = 0
A = A * mask
A_bottomk = A.sum(-1)
A = self.dropout(A)
# Compute the values
V_new = torch.einsum("nhls,nhse->nhle", A, V)
# Broadcast the values back depending on the groups
V_new = self._broadcast_values(V_new, clusters, counts)
return V_new, A_bottomk
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Make sure that there is no attention mask
assert attn_mask.all_ones, ("Improved-clustered attention cannot "
"use an arbitrary attention mask.")
queries = queries.permute(0,2,1,3).contiguous()
keys = keys.permute(0,2,1,3).contiguous()
values = values.permute(0,2,1,3).contiguous()
N, H, L, E = queries.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Cluster the queries into groups
clusters, counts = self._create_query_groups(queries, query_lengths)
Q_grouped = _GroupQueries.apply(queries, clusters, counts)
# Compute the attention
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
QK = QK + key_lengths.additive_matrix[:, None, None, :]
topk_values, topk = torch.topk(QK, self.topk, sorted=False, dim=-1)
assert topk.is_contiguous()
# Now compute the attention with only the bottom keys
V_bottomk, A_bottomk = self._bottomk_attention(
QK, values,
clusters, counts,
topk,
softmax_temp
)
# Now compute the attention with only the top keys
V_topk = self._topk_attention(
queries, keys, values,
clusters, counts,
topk, topk_values,
A_bottomk,
softmax_temp,
query_lengths
)
V_new = V_topk + V_bottomk
return V_new.permute(0, 2, 1, 3).contiguous()
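# Conceptual sketch (added for illustration; not part of the original file).
# It shows only the "improved" top-k step in plain PyTorch: centroid-key
# scores pick the top-k keys per cluster, and every query then recomputes
# exact attention over its cluster's top-k keys. The LSH/Hamming clustering,
# the bottom-k correction and the clustered_sparse_* kernels are omitted, and
# a trivial round-robin grouping stands in for the real clusters.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, H, L, E, C, k = 1, 1, 12, 8, 3, 4
    Q = torch.randn(N, H, L, E)
    K = torch.randn(N, H, L, E)
    V = torch.randn(N, H, L, E)
    groups = torch.arange(L) % C                        # stand-in cluster ids
    onehot = torch.nn.functional.one_hot(groups, C).float()
    Qc = torch.einsum("lc,nhle->nhce", onehot, Q) / onehot.sum(0)[:, None]
    QKc = torch.einsum("nhce,nhse->nhcs", Qc, K)        # centroid-key scores
    topk = QKc.topk(k, dim=-1).indices                  # (N, H, C, k) keys per cluster
    topk_q = topk[:, :, groups]                         # (N, H, L, k) keys per query
    idx = topk_q.unsqueeze(-1).expand(N, H, L, k, E)
    K_sel = K.unsqueeze(2).expand(N, H, L, L, E).gather(3, idx)
    V_sel = V.unsqueeze(2).expand(N, H, L, L, E).gather(3, idx)
    A = torch.softmax(torch.einsum("nhle,nhlke->nhlk", Q, K_sel) / E**0.5, dim=-1)
    out = torch.einsum("nhlk,nhlke->nhle", A, V_sel)    # exact attention on the top-k keys
    print(out.shape)                                    # torch.Size([1, 1, 12, 8])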
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement improved clustered causal self attention."""
from math import sqrt
import torch
import torch.autograd
from torch.nn import Dropout, Module
from torch.nn.init import normal_
from ..masking import FullMask
from ..aggregate import aggregate, broadcast
from ..clustering.hamming import cluster
from ..hashing import compute_hashes
from ..sparse_product import sparse_dot_product, sparse_weighted_average
from ..sparse_product import clustered_sparse_dot_product, \
clustered_sparse_weighted_average
class _GroupQueries(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, clusters, counts):
factors = 1/counts.float()
q_grouped = aggregate(Q, clusters, factors)
ctx.save_for_backward(clusters, factors)
return q_grouped
@staticmethod
def backward(ctx, grad_q_grouped):
clusters, factors = ctx.saved_tensors
grad_q = broadcast(grad_q_grouped, clusters, factors)
return grad_q, None, None
class _BroadcastValues(torch.autograd.Function):
@staticmethod
def forward(ctx, v_grouped, clusters, counts):
factors = torch.ones_like(counts, dtype=v_grouped.dtype)
V = broadcast(v_grouped, clusters, factors)
ctx.save_for_backward(clusters, factors)
return V
@staticmethod
def backward(ctx, grad_v):
clusters, factors = ctx.saved_tensors
grad_v_grouped = aggregate(grad_v, clusters, factors)
return grad_v_grouped, None, None
class ImprovedClusteredCausalAttention(Module):
"""
Improved clustered causal attention approximation by recomputing the attention
for each query with the top-k keys for the corresponding cluster.
Given the queries, keys, and values as Q, K, and V respectively, we
first cluster the queries in "C" groups and compute the "C" query centroids
Q_c.
We now use the centroids Q_c to identify the top-k keys with the highest
dot products.
Subsequently, for each query we compute the sparse dot product with
the corresponding top-k keys to improve the attention approximation.
The key difference from improved clustered attention is that, under the
causal mask, we only use the top-k keys and do not compute attention on
the bottom-k keys.
Arguments
---------
clusters: How many clusters to group the queries into
iterations: The number of Lloyd iterations to perform (default: 10)
bits: How many bits to use for the hash (default: 32)
hash_bias: If true, hamming distance proportional to L2 distance
If false, hamming distance proportional to cosine distance
(default: True)
topk: Number of top-k keys used for the improved approximation (default: 32)
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, clusters, iterations=10, bits=32,
hash_bias=True, topk=32, softmax_temp=None, dropout_rate=0.1):
super(ImprovedClusteredCausalAttention, self).__init__()
self.clusters = clusters
self.iterations = iterations
self.bits = bits
self.hash_bias = hash_bias
self.topk = topk
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def _create_query_groups(self, Q, query_lengths):
N, H, L, E = Q.shape
# Compute the hashes for all the queries
planes = Q.new_empty((self.bits, E+1))
normal_(planes)
if not self.hash_bias:
planes[:, -1] = 0
hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
# Cluster the hashes and return the cluster index per query
clusters, counts = cluster(
hashes,
query_lengths._lengths.int(),
clusters=self.clusters,
iterations=self.iterations,
bits=self.bits
)
return clusters, counts
def _topk_attention(self, Q, K, V,
clusters, counts,
topk, topk_values,
softmax_temp,
query_lengths):
"""Return the attention with just the topk heads."""
# Extract some indices
N, H, L, E = Q.shape
_, _, S, _ = K.shape
_, _, C, k = topk.shape
# We need to pass the output tensor to initialize to 0
QK = clustered_sparse_dot_product(
Q, K, topk,
clusters, counts,
query_lengths._lengths.int()
)
# We need to mask the topk dot products if topk > input_length
QK = QK.masked_fill(
torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k),
float("-inf")
)
# We need to mask out the future
assert topk.is_contiguous()
topk_broadcast = broadcast(
topk.float(),
clusters,
torch.ones_like(counts, dtype=torch.float32)
)
QK = QK.masked_fill(
topk_broadcast.long() > torch.arange(L, device=QK.device).view(1, 1, L, 1),
float("-1e7")
)
A = torch.softmax(softmax_temp * QK, dim=-1)
A = self.dropout(A)
assert A.is_contiguous()
V_new = clustered_sparse_weighted_average(A, V, topk, clusters)
return V_new
def _broadcast_values(self, V, clusters, counts):
"""Broadcast the values back to the correct positions but make sure
that the gradient flows properly."""
V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts)
return V_new
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Apply the key padding mask and make sure the attn_mask is a
# lower triangular causal mask
if not attn_mask.lower_triangular:
raise RuntimeError(("ImprovedClusteredCausalAttention only supports full "
"lower triangular masks"))
queries = queries.permute(0,2,1,3).contiguous()
keys = keys.permute(0,2,1,3).contiguous()
values = values.permute(0,2,1,3).contiguous()
N, H, L, E = queries.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Cluster the queries into groups
clusters, counts = self._create_query_groups(queries, query_lengths)
Q_grouped = _GroupQueries.apply(queries, clusters, counts)
# Compute the attention
QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
QK = QK + key_lengths.additive_matrix[:, None, None, :]
topk_values, topk = torch.topk(QK, self.topk, sorted=False, dim=-1)
assert topk.is_contiguous()
# Compute the attention with only the top keys
V_topk = self._topk_attention(
queries, keys, values,
clusters, counts,
topk, topk_values,
softmax_temp,
query_lengths
)
V_new = V_topk
return V_new.permute(0, 2, 1, 3).contiguous()
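# Sketch of the extra causal step (added for illustration; not part of the
# original file). Once the top-k keys of a query's cluster have been
# selected, any key that lies in the query's future must get (effectively)
# zero weight before the softmax. As in forward() above, a large negative
# constant is used rather than -inf so that a query whose top-k keys are all
# in the future does not produce NaNs.
if __name__ == "__main__":
    torch.manual_seed(0)
    L, k = 8, 3
    scores = torch.randn(1, 1, L, k)                    # per-query scores of its top-k keys
    key_pos = torch.randint(0, L, (1, 1, L, k))         # positions of those keys
    query_pos = torch.arange(L).view(1, 1, L, 1)
    scores = scores.masked_fill(key_pos > query_pos, -1e7)
    A = torch.softmax(scores, dim=-1)                   # future keys get ~0 weight
    print(A.shape)                                      # torch.Size([1, 1, 8, 3])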
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement unmasked linear attention."""
import torch
from torch.nn import Module
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
class LinearAttention(Module):
"""Implement unmasked attention using dot product of feature maps in
O(N D^2) complexity.
Given the queries, keys and values as Q, K, V instead of computing
V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
we make use of a feature map function Φ(.) and perform the following
computation
V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
The above can be computed in O(N D^2) complexity where D is the
dimensionality of Q, K and V and N is the sequence length. Depending on the
feature map, however, the complexity of the attention might be limited.
Arguments
---------
feature_map: callable, a callable that applies the feature map to the
last dimension of a tensor (default: elu(x)+1)
eps: float, a small number to ensure the numerical stability of the
denominator (default: 1e-6)
"""
def __init__(self, feature_map=None, eps=1e-6):
super(LinearAttention, self).__init__()
self.feature_map = feature_map or elu_feature_map
self.eps = eps
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Apply the feature map to the queries and keys
Q = self.feature_map(queries)
K = self.feature_map(keys)
# Apply the key padding mask and make sure that the attn_mask is
# all_ones
if not attn_mask.all_ones:
raise RuntimeError(("LinearAttention does not support arbitrary "
"attention masks"))
K = K * key_lengths.float_matrix[:, :, None, None]
# Compute the KV matrix, namely the dot product of keys and values so
# that we never explicitly compute the attention matrix and thus
# decrease the complexity
KV = torch.einsum("nshd,nshm->nhmd", K, values)
# Compute the normalizer
Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)
# Finally compute and return the new values
V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)
return V.contiguous()
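# Equivalence sketch (added for illustration; not part of the original file).
# With the feature map already applied (elu(x) + 1, so everything stays
# positive), the factorized einsum pipeline used in forward() gives the same
# result as the explicit O(L^2) attention matrix it avoids building.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, L, H, D = 1, 8, 2, 4
    Q = torch.nn.functional.elu(torch.randn(N, L, H, D)) + 1
    K = torch.nn.functional.elu(torch.randn(N, L, H, D)) + 1
    V = torch.randn(N, L, H, D)

    # Explicit O(L^2) attention with the feature-mapped similarities.
    A = torch.einsum("nlhd,nshd->nhls", Q, K)
    A = A / A.sum(dim=-1, keepdim=True)
    V_explicit = torch.einsum("nhls,nshd->nlhd", A, V)

    # Factorized O(L) computation, exactly as in LinearAttention.forward.
    KV = torch.einsum("nshd,nshm->nhmd", K, V)
    Z = 1 / torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))
    V_linear = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

    print(torch.allclose(V_explicit, V_linear, atol=1e-5))  # True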
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#
"""Implement the Reformer attention from the paper
"Reformer the efficient transformer"."""
from math import sqrt
import torch
from torch.nn import Dropout, Module
from torch.nn.init import normal_
from ..masking import FullMask
class ReformerAttention(Module):
"""Implement the attention module of the paper "Reformer the efficient
transformer"
Arguments
---------
chunk_size : Chunk size for each block (default: 32)
bits : Number of bits for hashing (default: 8)
rounds : Number of rounds of attention computation (default: 4)
masked : If true, the query does not attend to itself (default: False)
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
dropout_rate: The dropout rate to apply to the attention (default: 0.1)
"""
def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False,
softmax_temp=None, dropout_rate=0.1):
super(ReformerAttention, self).__init__()
self.chunk_size = chunk_size
self.bits = bits
self.rounds = rounds
self.masked = masked
self.softmax_temp = softmax_temp
self.dropout = Dropout(dropout_rate)
def _normalize(self, x):
norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x))
x_normed = x / norms.unsqueeze(-1)
return x_normed
def _look_back(self, x):
xshape = x.shape
return torch.cat([
x.new_zeros((xshape[0], 1) + xshape[2:]),
torch.repeat_interleave(x, 2, dim=1)[:,:-1]
], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:])
def _reformer_round(self, Q, K, V, mask, softmax_temp):
# Hash the queries
N, L, H, E = Q.shape
planes = Q.new_empty(self.bits, E)
normal_(planes)
projected = torch.einsum("nlhe,be->nlhb", K, planes)
hashes = torch.argmax(
torch.cat([projected, -projected], dim=-1),
dim=-1
)
# Sort the queries in order to group them
group = torch.argsort(hashes, dim=1)
invert_group = torch.empty_like(group)
batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1)
sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1)
head_indices = torch.arange(H, device=hashes.device).view(1, 1, H)
invert_group[batch_indices, group, head_indices] = sequence_indices
group = group.view(N, -1, self.chunk_size, H)
invert_group = invert_group.view(N, -1, self.chunk_size, H)
batch_indices = batch_indices.unsqueeze(1)
head_indices = head_indices.unsqueeze(0)
# Reorder Q, V and mask
Q_grouped = Q[batch_indices, group, head_indices]
K_grouped = K[batch_indices, group, head_indices]
V_grouped = V[batch_indices, group, head_indices]
mask_grouped = mask[
batch_indices.unsqueeze(1),
group.unsqueeze(3),
self._look_back(group).unsqueeze(2)
]
mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf")
# When everything is masked just unmask everything because it doesn't
# matter what the output is at those positions
# This is to avoid inf/nans in the new values at masked positions
infmask = torch.isinf(mask_grouped)
infmask = torch.all(infmask, dim=3, keepdims=True)
mask_grouped = mask_grouped.masked_fill(infmask, 0.)
# Attention
K_grouped = self._look_back(K_grouped)
QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped)
QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3)
A = torch.softmax(softmax_temp * QQ, dim=-1)
A = self.dropout(A)
# Values
V_grouped = self._look_back(V_grouped)
V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped)
V_new = V_new.contiguous().view(N, -1, H, E)
V_new = V_new[batch_indices, invert_group, head_indices]
V_new = V_new.contiguous().view(N, L, H, E)
return V_new
def forward(self, queries, keys, values, attn_mask, query_lengths,
key_lengths):
# Extract the dimensions of query, key, value
N, L, H, E = queries.shape
softmax_temp = self.softmax_temp or 1./sqrt(E)
# Create the mask
mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L)
if self.masked:
mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9)
if not attn_mask.all_ones:
mask = mask + attn_mask.additive_matrix.unsqueeze(0)
# Get normalized Queries as Keys
K = self._normalize(queries)
# Zero the masked out keys
K = K * key_lengths.float_matrix.view(N, L, 1, 1)
V_new = 0
factor = 1/self.rounds
for i in range(self.rounds):
V_new = V_new + \
factor * self._reformer_round(queries, K, values, mask, softmax_temp)
return V_new
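# Usage sketch (added for illustration; not part of the original file). It
# assumes the mask helpers of the published package
# (fast_transformers.masking.FullMask / LengthMask) and picks a sequence
# length that is a multiple of chunk_size, since each round above regroups
# the sequence into chunks of that size.
if __name__ == "__main__":
    from fast_transformers.masking import FullMask, LengthMask

    N, L, H, E = 2, 64, 4, 32
    attn = ReformerAttention(chunk_size=32, rounds=2)
    q = torch.randn(N, L, H, E)
    attn_mask = FullMask(L)
    lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))
    out = attn(q, q, q, attn_mask, lengths, lengths)
    print(out.shape)                                    # torch.Size([2, 64, 4, 32])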