FlashAttention
| """FlashAttention: reference implementation of the core logic""" | |
| import math | |
| import torch | |
| from einops import einsum | |
| from jaxtyping import Float, Int | |
| from torch import nn, Tensor | |
| class FlashAttention(nn.Module): | |
| """FlashAttention forward pass with sequence tiling and streaming softmax | |
| Parameters | |
| ---------- | |
| query_block_size : int, default=128 | |
| Maximum number of query tokens processed per tile (height of a Q-tile). | |
| kv_block_size : int, default=128 | |
| Maximum number of key/value tokens processed per tile (width of a KV-tile). | |
| Notes | |
| ----- | |
| The implementation is mathematically correct and numerically stable. | |
| It is written for education, not for production throughput. | |
| """ | |
| def __init__(self, query_block_size: int = 128, kv_block_size: int = 128): | |
| super().__init__() | |
| self.query_block_size = int(query_block_size) | |
| self.kv_block_size = int(kv_block_size) | |
| def forward( | |
| self, | |
| q: Float[Tensor, "batch num_heads seq_len_q head_dim"], | |
| k: Float[Tensor, "batch num_heads seq_len_kv head_dim"], | |
| v: Float[Tensor, "batch num_heads seq_len_kv value_dim"], | |
| causal: bool = False, | |
| ) -> Float[Tensor, "batch num_heads seq_len_q value_dim"]: | |
| """FlashAttention forward pass (exact, tiled, streaming softmax). | |
| Notes | |
| ----- | |
| The computation is tiled over the sequence dimension: | |
| - Outer loop: query tiles of size at most `query_block_size`. | |
| - Inner loop: KV tiles of size at most `kv_block_size`. | |
| For each query tile, the algorithm streams over all KV tiles using the | |
| streaming-softmax recurrence to accumulate the exact numerator and | |
| denominator, then normalizes at the end. | |
| """ | |
        batch, num_heads, seq_len_q, head_dim = q.shape
        assert (num_heads_kv := k.shape[1]) == v.shape[1]
        if num_heads_kv < num_heads:
            # Broadcast KV heads to match query heads (grouped-query / multi-query attention).
            assert num_heads % num_heads_kv == 0, "num_heads must be a multiple of num_heads_kv"
            k = torch.repeat_interleave(k, repeats=num_heads // num_heads_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=num_heads // num_heads_kv, dim=1)
        seq_len_kv, value_dim = v.shape[-2:]
        scale = 1.0 / math.sqrt(head_dim)
        out: Float[Tensor, "batch num_heads seq_len_q value_dim"] = torch.empty(
            (batch, num_heads, seq_len_q, value_dim), dtype=v.dtype, device=v.device
        )
        if causal:
            positions: Int[Tensor, "seq_len_kv"] = torch.arange(seq_len_kv, device=q.device)
        for q_block_start in range(0, seq_len_q, self.query_block_size):
            # slice a query tile
            q_block_end = min(q_block_start + self.query_block_size, seq_len_q)
            q_block: Float[Tensor, "batch num_heads seq_len_q_block head_dim"]
            q_block = q[..., q_block_start:q_block_end, :]
            # region initialize streaming softmax states for a query tile
            batch_q_block, num_heads, seq_len_q_block, _ = q_block.shape
            # Running block-wise maximum of logits.
            running_max: Float[Tensor, "batch_q_block num_heads seq_len_q_block"]
            running_max = torch.full(
                (batch_q_block, num_heads, seq_len_q_block),
                float("-inf"),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # Running block-wise sum of exp(logits - running_max).
            running_denominator: Float[Tensor, "batch_q_block num_heads seq_len_q_block"]
            running_denominator = torch.zeros(
                (batch_q_block, num_heads, seq_len_q_block),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # Running block-wise sum of exp(logits - running_max) @ V.
            running_numerator: Float[Tensor, "batch_q_block num_heads seq_len_q_block value_dim"]
            running_numerator = torch.zeros(
                (batch_q_block, num_heads, seq_len_q_block, value_dim),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # endregion
            for kv_block_start in range(0, seq_len_kv, self.kv_block_size):
                kv_block_end = min(kv_block_start + self.kv_block_size, seq_len_kv)
                k_block: Float[Tensor, "batch num_heads seq_len_kv_block head_dim"]
                k_block = k[..., kv_block_start:kv_block_end, :]
                v_block: Float[Tensor, "batch num_heads seq_len_kv_block value_dim"]
                v_block = v[..., kv_block_start:kv_block_end, :]
                # Split the 1/sqrt(head_dim) scale evenly across Q and K: the product
                # is still scaled by `scale`, while intermediate magnitudes stay small.
                logits_block: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"]
                logits_block = einsum(
                    q_block * math.sqrt(scale),
                    k_block * math.sqrt(scale),
                    "b h q d, b h k d -> b h q k",
                )
                if causal:
                    # Queries are right-aligned with keys (the last query attends to the
                    # full KV sequence); mask_qk[q, k] is True where the key lies in the
                    # query's future, i.e. strictly above the shifted diagonal.
                    kv_block_positions: Int[Tensor, "kv_block_size"] = positions[kv_block_start:kv_block_end]
                    q_block_positions: Int[Tensor, "query_block_size"] = positions[
                        q_block_start - seq_len_q + seq_len_kv : q_block_end - seq_len_q + seq_len_kv
                    ]
                    mask_qk: Tensor = q_block_positions.reshape(-1, 1) < kv_block_positions.reshape(1, -1)
                    logits_block = logits_block.masked_fill(mask_qk[None, None, ...], float("-inf"))
                running_max, running_denominator, running_numerator = self._streaming_softmax_update_step(
                    running_max=running_max,
                    running_denominator=running_denominator,
                    running_numerator=running_numerator,
                    logits_block=logits_block,
                    value_block=v_block,
                )
            out_block: Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
            out_block = running_numerator / running_denominator.unsqueeze(-1)
            out[..., q_block_start:q_block_end, :] = out_block
        return out

    @staticmethod
    def _streaming_softmax_update_step(
        running_max: Float[Tensor, "batch num_heads seq_len_q_block"],
        running_denominator: Float[Tensor, "batch num_heads seq_len_q_block"],
        running_numerator: Float[Tensor, "batch num_heads seq_len_q_block value_dim"],
        logits_block: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"],
        value_block: Float[Tensor, "batch num_heads seq_len_kv_block value_dim"],
    ) -> tuple[
        Float[Tensor, "batch num_heads seq_len_q_block"],
        Float[Tensor, "batch num_heads seq_len_q_block"],
        Float[Tensor, "batch num_heads seq_len_q_block value_dim"],
    ]:
        """Update streaming-softmax states over a (Q-tile × KV-tile) pair.

        Returns
        -------
        new_running_max : Float[Tensor, "batch num_heads seq_len_q_block"]
        new_running_denominator : Float[Tensor, "batch num_heads seq_len_q_block"]
        new_running_numerator : Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
        """
        block_wise_max: Float[Tensor, "batch num_heads seq_len_q_block"] = logits_block.amax(dim=-1)
        new_running_max: Float[Tensor, "batch num_heads seq_len_q_block"] = torch.maximum(
            running_max, block_wise_max
        )
        rescale: Float[Tensor, "batch num_heads seq_len_q_block"] = (running_max - new_running_max).exp()
        block_wise_exp_logits: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"] = (
            logits_block - new_running_max.unsqueeze(-1)
        ).exp()
        new_running_denominator: Float[Tensor, "batch num_heads seq_len_q_block"] = (
            running_denominator * rescale + block_wise_exp_logits.sum(dim=-1)
        )
        new_running_numerator: Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
        new_running_numerator = running_numerator * rescale.unsqueeze(-1) + einsum(
            block_wise_exp_logits,
            value_block,
            "b h q v, b h v dv -> b h q dv",
        )
        return new_running_max, new_running_denominator, new_running_numerator
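
A minimal sanity check, not part of the original gist: assuming the FlashAttention class above is defined in the same session and PyTorch >= 2.0 is installed, the tiled output can be compared against torch.nn.functional.scaled_dot_product_attention on random inputs. The shapes and tolerance below are illustrative.

import torch
from torch.nn.functional import scaled_dot_product_attention

torch.manual_seed(0)
# batch=2, heads=8, 300 tokens, head_dim=64; 300 is deliberately not a multiple
# of the block sizes, so the ragged last Q- and KV-tiles are exercised too.
q = torch.randn(2, 8, 300, 64)
k = torch.randn(2, 8, 300, 64)
v = torch.randn(2, 8, 300, 64)

flash = FlashAttention(query_block_size=128, kv_block_size=128)
for causal in (False, True):
    tiled = flash(q, k, v, causal=causal)
    reference = scaled_dot_product_attention(q, k, v, is_causal=causal)
    assert torch.allclose(tiled, reference, atol=1e-4), f"mismatch for causal={causal}"

With equal query and key lengths, the gist's right-aligned causal mask coincides with is_causal=True, so the two outputs should agree up to floating-point tolerance.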