Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created November 24, 2022 16:17
Show Gist options
  • Save grey-area/50360631f4347035ec0ac0e0ebbeb5ee to your computer and use it in GitHub Desktop.
import torch
from torch import nn
from tqdm import tqdm
def subsequent_mask(size):
    """Build an additive causal attention mask of shape (size, size).

    Entries strictly above the main diagonal are -inf (future positions
    are blocked); the diagonal and everything below are 0 (allowed).
    Suitable for the ``src_mask`` argument of PyTorch attention layers.
    """
    blocked = torch.full((size, size), float('-inf'))
    # Keep -inf only strictly above the diagonal; zero out the rest.
    return torch.triu(blocked, diagonal=1)
if __name__ == "__main__":
    d_model = 512
    encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8)

    # Seed sequence: one zero token, shape (seq_len=1, batch=1, d_model)
    # (batch_first=False is the layer's default, so dim 0 is sequence).
    x = torch.zeros(1, 1, d_model)

    # Autoregressive loop: re-encode the whole prefix each step and
    # append only the final position's output.
    for step in tqdm(range(500)):
        seq_len = x.size(0)
        causal_mask = subsequent_mask(seq_len)
        # At each step the full prefix of length seq_len attends under the
        # causal mask, even though only the last position's output is kept.
        encoded = encoder_layer(x, src_mask=causal_mask)
        newest = encoded[-1:]
        x = torch.cat([x, newest], dim=0)
        # Total computation is cubic in sequence length instead of
        # quadratic, since each step redoes all earlier attention work.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment