@Warvito
Created December 14, 2022 20:18
import torch
from torch import nn
from monai.utils import optional_import
xformers, has_xformers = optional_import("xformers", name="xformers")


class SelfAttentionBlock(nn.Module):
    """Multi-head self-attention block that uses xformers memory-efficient attention when available."""

    def __init__(
        self,
        query_dim: int,
        num_attention_heads: int = 8,
        num_head_channels: int = 64,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = num_head_channels * num_attention_heads
        self.scale = num_head_channels**-0.5
        self.heads = num_attention_heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)
        batch_size, seq_len, dim = x.shape
        head_size = self.heads
        x = x.reshape(batch_size, seq_len, head_size, dim // head_size)
        x = x.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return x

    def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
        batch_size, seq_len, dim = x.shape
        head_size = self.heads
        x = x.reshape(batch_size // head_size, head_size, seq_len, dim)
        x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return x

    def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        # scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attention_probs = attention_scores.softmax(dim=-1)
        # compute attention output
        hidden_states = torch.matmul(attention_probs, value)
        # fold the head dimension back into the channel dimension
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_xformers(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        # xformers requires contiguous inputs and applies the 1/sqrt(head_dim) scaling internally
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        x = self.reshape_batch_dim_to_heads(x)
        x = x.to(query.dtype)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)
        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)
        if has_xformers:
            x = self._memory_efficient_attention_xformers(query, key, value)
        else:
            x = self._attention(query, key, value)
        return self.to_out(x)
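

# Example usage: a minimal sketch, not part of the original gist. The block expects an
# input of shape (batch, sequence_length, query_dim); the sizes below are illustrative.
if __name__ == "__main__":
    block = SelfAttentionBlock(query_dim=64, num_attention_heads=4, num_head_channels=16)
    tokens = torch.randn(2, 128, 64)  # (batch, sequence_length, query_dim)
    out = block(tokens)
    print(out.shape)  # torch.Size([2, 128, 64])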