Created August 8, 2024 15:15
"""Generates a document causal attention mask based on a document ID tensor""" | |
from typing import List, Union | |
import torch | |
from torch import Tensor | |
from torch.nn.attention.flex_attention import _mask_mod_signature, or_masks | |
from attn_gym.masks import causal_mask | |
def _offsets_to_doc_ids_tensor(offsets):
    """Expands cumulative offsets into a per-token document ID tensor."""
    device = offsets.device
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts
    )
def length_to_offsets(lengths: List[int], device: Union[str, torch.device]) -> Tensor:
    """Converts a list of lengths to a tensor of cumulative offsets.

    Args:
        lengths: A list of segment lengths.
        device: Device to place the offsets tensor on.

    Returns:
        A tensor of shape (len(lengths) + 1,) starting at 0.
    """
    offsets = [0]
    offsets.extend(lengths)
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets
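
# Illustrative example (values not from the original gist): three packed documents of
# lengths 2, 4 and 3 give
#   length_to_offsets([2, 4, 3], "cpu")                     -> tensor([0, 2, 6, 9])
#   _offsets_to_doc_ids_tensor(torch.tensor([0, 2, 6, 9]))  -> tensor([0, 0, 1, 1, 1, 1, 2, 2, 2])
# i.e. every packed token is labelled with the index of the document it belongs to.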
def generate_doc_mask_mod(mask_mod: _mask_mod_signature, offsets: Tensor) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        mask_mod: The mask mod to apply within each document.
        offsets: A tensor of shape (num_documents + 1,) containing the cumulative
            counts of document tokens, e.g. for 3 documents of lengths 2, 4, 3,
            offsets = [0, 2, 6, 9].

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form one large sequence. We then
        use masking to ensure that attention scores are only computed between tokens within
        the same document.
    """
    document_id = _offsets_to_doc_ids_tensor(offsets)

    def doc_mask_mod(b, h, q_idx, kv_idx):
        # Only attend within the same document, and translate the global (packed)
        # indices into per-document logical indices before applying the inner mask.
        same_doc = document_id[q_idx] == document_id[kv_idx]
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask

    return doc_mask_mod
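
# For the [2, 4, 3] example above, generate_doc_mask_mod(causal_mask, offsets) yields a
# block-diagonal causal pattern: tokens 0-1 attend causally among themselves, tokens 2-5
# among themselves, and tokens 6-8 among themselves; cross-document positions are masked out.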
def generate_turn_mask_mod(offsets: Tensor, is_user: Tensor) -> _mask_mod_signature:
    """Generates a mask mod that is True whenever the query token sits in a user turn
    and the key/value token is in the same turn.

    Args:
        offsets: Cumulative turn-token counts, shape (num_turns + 1,).
        is_user: Boolean tensor of shape (num_turns,); True marks user turns.

    Intended to be or_masks'd with the document causal mask so that user turns get
    full (bidirectional) attention within the turn while everything else stays causal.
    """
    turn_id = _offsets_to_doc_ids_tensor(offsets)

    def turn_mask_mod(b, h, q_idx, kv_idx):
        same_turn = turn_id[q_idx] == turn_id[kv_idx]
        return same_turn & is_user[turn_id[q_idx]]

    return turn_mask_mod
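
# Illustrative example (not from the original gist): with turn_offsets built from
# lengths [2, 2] and is_user = tensor([True, False]), queries at positions 0-1 (the
# user turn) may attend to both positions 0 and 1, including the "future" token, while
# queries at positions 2-3 (the assistant turn) get nothing from this mask and fall
# back to the document causal mask once the two are combined with or_masks.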
def main(device: str = "cpu"):
    """Visualize the attention scores of the combined document-causal + user-turn mask.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
    """
    from attn_gym import visualize_attention_scores
    import random

    random.seed(0)

    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents
        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1
        return lengths

    max_seq_len, doc_count, turn_count = 40, 2, 10
    B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 8

    # Split the packed sequence into turns, then group consecutive turns into documents.
    turn_lengths = generate_random_lengths(max_seq_len, turn_count)
    doc_turn_counts = generate_random_lengths(turn_count, doc_count)

    doc_lengths = []
    turn_is_user = []
    i = 0
    for doc_id in range(doc_count):
        doc_lengths.append(sum(turn_lengths[i:i + doc_turn_counts[doc_id]]))
        # The first turn of each document is a user turn; the remaining turns
        # alternate user/assistant, starting with user.
        turn_is_user.extend(
            [True]
            + [True, False] * ((doc_turn_counts[doc_id] - 1) // 2)
            + [True] * ((doc_turn_counts[doc_id] - 1) % 2)
        )
        i += doc_turn_counts[doc_id]

    doc_offsets = length_to_offsets(doc_lengths, device)
    turn_offsets = length_to_offsets(turn_lengths, device)

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()

    document_causal_mask = generate_doc_mask_mod(causal_mask, doc_offsets)
    turn_is_user = torch.tensor(turn_is_user, device=device)
    turn_mask = generate_turn_mask_mod(turn_offsets, turn_is_user)
    chat_mask = or_masks(document_causal_mask, turn_mask)

    visualize_attention_scores(
        query,
        key,
        mask_mod=chat_mask,
        device=device,
        name="document_packing_chat_tuning_mask",
    )
if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .[viz]")

    CLI(main)
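
The script above only builds and visualizes the mask mod. As a rough sketch (not part of the original gist), the same mask mods can be used to actually run flex attention. It assumes a recent PyTorch where torch.nn.attention.flex_attention provides create_block_mask and flex_attention, and that it is appended to (or imports the helpers from) the file above; whether the eager CPU path is supported depends on your PyTorch version.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention, or_masks
    from attn_gym.masks import causal_mask

    device = "cpu"
    B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 16

    # Two packed documents (64 + 64 tokens) split into four turns: user, assistant, user, assistant.
    doc_offsets = length_to_offsets([64, 64], device)
    turn_offsets = length_to_offsets([32, 32, 48, 16], device)
    turn_is_user = torch.tensor([True, False, True, False], device=device)

    chat_mask = or_masks(
        generate_doc_mask_mod(causal_mask, doc_offsets),
        generate_turn_mask_mod(turn_offsets, turn_is_user),
    )

    # Materialize the sparse BlockMask once; it can be reused for every batch with this layout.
    block_mask = create_block_mask(
        chat_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device=device
    )

    query = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device=device)
    key = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device=device)
    value = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device=device)
    out = flex_attention(query, key, value, block_mask=block_mask)

Building the BlockMask once per packing layout is the usual pattern, since it depends only on the mask mod and the sequence lengths, not on the actual query/key/value contents.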