@LiutongZhou
Created November 21, 2025 21:18
FlashAttention
"""FlashAttention: reference implementation of the core logic"""
import math
import torch
from einops import einsum
from jaxtyping import Float, Int
from torch import nn, Tensor


class FlashAttention(nn.Module):
"""FlashAttention forward pass with sequence tiling and streaming softmax
Parameters
----------
query_block_size : int, default=128
Maximum number of query tokens processed per tile (height of a Q-tile).
kv_block_size : int, default=128
Maximum number of key/value tokens processed per tile (width of a KV-tile).
Notes
-----
The implementation is mathematically correct and numerically stable.
It is written for education, not for production throughput.
"""

    def __init__(self, query_block_size: int = 128, kv_block_size: int = 128):
        super().__init__()
        self.query_block_size = int(query_block_size)
        self.kv_block_size = int(kv_block_size)

    def forward(
        self,
        q: Float[Tensor, "batch num_heads seq_len_q head_dim"],
        k: Float[Tensor, "batch num_heads seq_len_kv head_dim"],
        v: Float[Tensor, "batch num_heads seq_len_kv value_dim"],
        causal: bool = False,
    ) -> Float[Tensor, "batch num_heads seq_len_q value_dim"]:
        """FlashAttention forward pass (exact, tiled, streaming softmax).

        Notes
        -----
        The computation is tiled over the sequence dimension:

        - Outer loop: query tiles of size at most `query_block_size`.
        - Inner loop: KV tiles of size at most `kv_block_size`.

        For each query tile, the algorithm streams over all KV tiles using the
        streaming-softmax recurrence to accumulate the exact numerator and
        denominator, then normalizes at the end.
        """
        batch, num_heads, seq_len_q, head_dim = q.shape
        assert (num_heads_kv := k.shape[1]) == v.shape[1]
        if num_heads_kv < num_heads:
            # handle grouped-query / multi-query attention: expand KV heads to match query heads
            assert num_heads % num_heads_kv == 0, "num_heads must be a multiple of num_heads_kv"
            k = torch.repeat_interleave(k, repeats=num_heads // num_heads_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=num_heads // num_heads_kv, dim=1)
        seq_len_kv, value_dim = v.shape[-2:]
        scale = 1.0 / math.sqrt(head_dim)
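        # The 1/sqrt(head_dim) factor is split as sqrt(scale) on Q and sqrt(scale) on K
        # in the tile products below; the product of the two halves recovers the full
        # scale while keeping the magnitude of each operand smaller.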
        out: Float[Tensor, "batch num_heads seq_len_q value_dim"] = torch.empty(
            (batch, num_heads, seq_len_q, value_dim), dtype=v.dtype, device=v.device
        )
        if causal:
            positions: Int[Tensor, "seq_len_kv"] = torch.arange(seq_len_kv, device=q.device)
        for q_block_start in range(0, seq_len_q, self.query_block_size):
            # slice a query tile
            q_block_end = min(q_block_start + self.query_block_size, seq_len_q)
            q_block: Float[Tensor, "batch num_heads seq_len_q_block head_dim"]
            q_block = q[..., q_block_start:q_block_end, :]
            # region initialize streaming softmax states for a query tile
            batch_q_block, num_heads, seq_len_q_block, _ = q_block.shape
            # Running block-wise maximum of logits.
            running_max: Float[Tensor, "batch_q_block num_heads seq_len_q_block"]
            running_max = torch.full(
                (batch_q_block, num_heads, seq_len_q_block),
                float("-inf"),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # Running block-wise sum of exp(logits - running_max).
            running_denominator: Float[Tensor, "batch_q_block num_heads seq_len_q_block"]
            running_denominator = torch.zeros(
                (batch_q_block, num_heads, seq_len_q_block),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # Running block-wise sum of exp(logits - running_max) @ V.
            running_numerator: Float[Tensor, "batch_q_block num_heads seq_len_q_block value_dim"]
            running_numerator = torch.zeros(
                (batch_q_block, num_heads, seq_len_q_block, value_dim),
                dtype=q_block.dtype,
                device=q_block.device,
            )
            # endregion
            for kv_block_start in range(0, seq_len_kv, self.kv_block_size):
                kv_block_end = min(kv_block_start + self.kv_block_size, seq_len_kv)
                k_block: Float[Tensor, "batch num_heads seq_len_kv_block head_dim"]
                k_block = k[..., kv_block_start:kv_block_end, :]
                v_block: Float[Tensor, "batch num_heads seq_len_kv_block value_dim"]
                v_block = v[..., kv_block_start:kv_block_end, :]
                logits_block: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"]
                logits_block = einsum(
                    q_block * math.sqrt(scale),
                    k_block * math.sqrt(scale),
                    "b h q d, b h k d -> b h q k",
                )
                if causal:
                    # mask[q, k] = True when query position < k position -> upper triangular True
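                    # Query positions are right-aligned to the end of the KV sequence
                    # (offset seq_len_kv - seq_len_q), so the last query row can attend to
                    # the entire KV range, e.g. when seq_len_q < seq_len_kv with a KV cache.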
                    kv_block_positions: Int[Tensor, "kv_block_size"] = positions[kv_block_start:kv_block_end]
                    q_block_positions: Int[Tensor, "query_block_size"] = positions[
                        q_block_start - seq_len_q + seq_len_kv : q_block_end - seq_len_q + seq_len_kv
                    ]
                    mask_qk: Tensor = q_block_positions.reshape(-1, 1) < kv_block_positions.reshape(1, -1)
                    logits_block = logits_block.masked_fill(mask_qk[None, None, ...], float("-inf"))
                running_max, running_denominator, running_numerator = self._streaming_softmax_update_step(
                    running_max=running_max,
                    running_denominator=running_denominator,
                    running_numerator=running_numerator,
                    logits_block=logits_block,
                    value_block=v_block,
                )
            out_block: Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
            out_block = running_numerator / running_denominator.unsqueeze(-1)
            out[..., q_block_start:q_block_end, :] = out_block
        return out

    @staticmethod
    def _streaming_softmax_update_step(
        running_max: Float[Tensor, "batch num_heads seq_len_q_block"],
        running_denominator: Float[Tensor, "batch num_heads seq_len_q_block"],
        running_numerator: Float[Tensor, "batch num_heads seq_len_q_block value_dim"],
        logits_block: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"],
        value_block: Float[Tensor, "batch num_heads seq_len_kv_block value_dim"],
    ) -> tuple[
        Float[Tensor, "batch num_heads seq_len_q_block"],
        Float[Tensor, "batch num_heads seq_len_q_block"],
        Float[Tensor, "batch num_heads seq_len_q_block value_dim"],
    ]:
        """Update streaming-softmax states over a (Q-tile × KV-tile) pair.

        Returns
        -------
        new_running_max : Float[Tensor, "batch num_heads seq_len_q_block"]
        new_running_denominator : Float[Tensor, "batch num_heads seq_len_q_block"]
        new_running_numerator : Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
        """
        block_wise_max: Float[Tensor, "batch num_heads seq_len_q_block"] = logits_block.amax(dim=-1)
        new_running_max: Float[Tensor, "batch num_heads seq_len_q_block"] = torch.maximum(
            running_max, block_wise_max
        )
        rescale: Float[Tensor, "batch num_heads seq_len_q_block"] = (running_max - new_running_max).exp()
        block_wise_exp_logits: Float[Tensor, "batch num_heads seq_len_q_block seq_len_kv_block"] = (
            logits_block - new_running_max.unsqueeze(-1)
        ).exp()
        new_running_denominator: Float[Tensor, "batch num_heads seq_len_q_block"] = (
            running_denominator * rescale + block_wise_exp_logits.sum(dim=-1)
        )
        new_running_numerator: Float[Tensor, "batch num_heads seq_len_q_block value_dim"]
        new_running_numerator = running_numerator * rescale.unsqueeze(-1) + einsum(
            block_wise_exp_logits,
            value_block,
            "b h q v, b h v dv -> b h q dv",
        )
        return new_running_max, new_running_denominator, new_running_numerator
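

if __name__ == "__main__":
    # Minimal usage sketch: compare the tiled implementation against a naive
    # softmax-attention reference on small random inputs. Shapes, block sizes,
    # and tolerances below are illustrative choices, not part of the module API.
    torch.manual_seed(0)
    batch, num_heads, seq_len, head_dim = 2, 4, 300, 64
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    v = torch.randn(batch, num_heads, seq_len, head_dim)

    flash_attention = FlashAttention(query_block_size=128, kv_block_size=128)
    out_tiled = flash_attention(q, k, v, causal=True)

    # Naive reference: materialize the full attention matrix in one shot.
    logits = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    logits = logits.masked_fill(causal_mask, float("-inf"))
    out_reference = logits.softmax(dim=-1) @ v

    torch.testing.assert_close(out_tiled, out_reference, rtol=1e-4, atol=1e-5)
    print("Tiled FlashAttention output matches the naive reference.")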