This code implements the DeepSeek Sparse Attention (DSA) lightning indexer, which computes index scores for efficient attention using FP8 precision. I'll explain it line by line, breaking it into logical sections. The code uses TileLang (a DSL for GPU kernels) and PyTorch for high-performance computation.
# ruff: noqa
import itertools
import tilelang
from tilelang import language as T
import torch
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8

- `ruff: noqa` suppresses linting warnings for the file.
- Imports: `itertools` for configuration generation, `tilelang` for GPU kernels, `torch` for tensor operations, and custom utils for sequence-length generation and FP8 casting.
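The FP8 helper is project-specific; as a rough, hypothetical sketch of what a per-row cast with one scale per key could look like (the real `per_custom_dims_cast_to_fp8` may differ):

```python
import torch

def per_token_cast_to_fp8_sketch(x: torch.Tensor):
    # Hypothetical illustration only: one scale per row so values fit the e4m3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max              # ~448 for e4m3
    scales = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4) / fp8_max
    x_fp8 = (x / scales).to(torch.float8_e4m3fn)                 # quantized values
    return x_fp8, scales.squeeze(-1).float()                     # one scale per row
```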
def display_error_message(msg):
print(f"\033[31mWARNING: {msg}\033[0m")- Prints warning messages in red using ANSI escape codes.
def compute_correlation(a, b, label="tensor"):
a, b = a.data.double(), b.data.double()
norm_sum = (a * a + b * b).sum()
if norm_sum == 0:
display_error_message(f"{label} all zero")
return 1
correlation = 2 * (a * b).sum() / norm_sum
    return correlation

- Computes a normalized correlation between two tensors after converting them to double precision: `2 * (a * b).sum() / (a * a + b * b).sum()`, which equals 1 when the tensors are identical. Returns 1 if both tensors are all zero (treated as a perfect match).
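For intuition, a quick illustrative check of the metric:

```python
import torch

a = torch.randn(1000, dtype=torch.float64)
b = a + 1e-6 * torch.randn(1000, dtype=torch.float64)
# Near-identical tensors give a correlation close to 1; 1 - correlation is the error used below.
print(float(compute_correlation(a, b)))
```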
def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True):
a_finite = torch.isfinite(a)
b_finite = torch.isfinite(b)
if not torch.all(a_finite == b_finite):
display_error_message(f"{tensor_name} Error: isfinite mask mismatch")
if should_raise:
assert False
if not torch.isclose(
a.masked_fill(a_finite, 0),
b.masked_fill(b_finite, 0),
rtol=0,
atol=0,
equal_nan=True,
).all():
display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
if should_raise:
assert False
a = a.masked_fill(~a_finite, 0)
b = b.masked_fill(~b_finite, 0)
correlation = compute_correlation(a, b, tensor_name)
difference = 1.0 - correlation
if not (0 <= difference <= tolerance):
display_error_message(f"{tensor_name} Error: {difference}")
if should_raise:
assert False
    return difference

- Validates that two tensors match within a tolerance, handling non-finite values (NaN, inf): the finite masks must agree, the non-finite entries must be identical, and `1 - correlation` over the finite entries must fall within `tolerance`. Asserts on mismatch when `should_raise` is set; otherwise it only prints a warning and returns the difference.
def get_configs():
iter_params = dict(
block_N=[32, 64, 128],
num_stages=[0, 1, 2],
threads=[128, 256],
block_Q=[1, 2, 4],
)
return [{
k: v for k, v in zip(iter_params, values)
    } for values in itertools.product(*iter_params.values())]

- Generates all combinations of kernel tuning parameters (block sizes, pipeline stages, thread counts, queries per block) via `itertools.product`.
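For example, the generated search space has 3 × 3 × 2 × 3 = 54 entries, and the first one pairs the first value of each parameter:

```python
configs = get_configs()
print(len(configs))   # 54
print(configs[0])     # {'block_N': 32, 'num_stages': 0, 'threads': 128, 'block_Q': 1}
```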
class SupplyProg:
def __init__(self):
self.tensors_dict = {}
def get_key(self, shape, dtype) -> str:
return f"{shape}-{dtype}"
def supply_prog(self, params):
shapes = [p.shape for p in params]
dtypes = [p.dtype for p in params]
tensor_list = []
for shape, dtype in zip(shapes, dtypes):
key = self.get_key(shape, dtype)
if key not in self.tensors_dict:
self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda")
tensor_list.append(self.tensors_dict[key])
else:
tensor_list.append(self.tensors_dict[key])
return tensor_list
supply_prog = SupplyProg()

- Caches tensors by shape and dtype to avoid reallocating them during testing and autotuning. `supply_prog` is a global instance.
The following TileLang kernel computes the index scores (the logits) for each query-key pair.
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},)
def mqa_attn_return_logits(
heads,
index_dim,
block_N=256,
num_stages=3,
threads=512,
block_Q=None,
):
if block_Q is None:
block_Q = 128 // heads
dtype = "float8_e4m3"
accum_dtype = "float"
index_dtype = "int32"
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
index_q_shape = [seq_len * heads, index_dim]
index_k_shape = [seq_len_kv, index_dim]
index_k_scale_shape = [seq_len_kv]
    logits_shape = [seq_len, seq_len_kv]

- Decorated with `@tilelang.jit` to compile to GPU code, with fast math enabled. Uses FP8 (`float8_e4m3`) for storage and float32 for accumulation.
- Sequence lengths are symbolic, so one compiled kernel handles any `seq_len` / `seq_len_kv`. `block_Q` defaults to `128 // heads`.
@T.prim_func
def mqa_attn_return_logits_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),
Logits: T.Tensor(logits_shape, accum_dtype),
Weights: T.Tensor([seq_len, heads], accum_dtype),
CuSeqLenKS: T.Tensor([seq_len], index_dtype),
CuSeqLenKE: T.Tensor([seq_len], index_dtype),
    ):

- Kernel inputs: query index vectors (`IndexQ`), key index vectors (`IndexK`), per-key scales (`IndexKScale`), output logits (`Logits`), per-head weights (`Weights`), and per-query start/end key indices (`CuSeqLenKS`, `CuSeqLenKE`).
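For reference, the shapes these arguments take (using S = `seq_len`, SKV = `seq_len_kv`, H = `heads`, D = `index_dim`):

```python
# IndexQ:      [S * H, D]  float8_e4m3  query index vectors, heads flattened into rows
# IndexK:      [SKV, D]    float8_e4m3  key index vectors
# IndexKScale: [SKV]       float32      per-key dequantization scale
# Logits:      [S, SKV]    float32      output index scores
# Weights:     [S, H]      float32      per-query, per-head weights
# CuSeqLenKS:  [S]         int32        first valid key index for each query
# CuSeqLenKE:  [S]         int32        one past the last valid key index for each query
```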
with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
            weights = T.alloc_fragment([block_Q, heads], accum_dtype)

- Launches a kernel grid of `ceildiv(seq_len, block_Q)` blocks. Allocates shared-memory tiles and register fragments for efficient data access.
seq_len_i = bx * block_Q
cu_k_s_min = T.alloc_local([1], index_dtype)
cu_k_e_max = T.alloc_local([1], index_dtype)
T.no_set_max_nreg()
cu_k_s_min[0] = 2147483647
cu_k_e_max[0] = -2147483648
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))

- Computes the minimum start and maximum end key index across the queries in this block, so a single KV range covers all of them.
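In plain PyTorch, the per-block range computed here would look roughly like this (illustrative sketch; the helper name is made up):

```python
import torch

def block_kv_range(cu_ks: torch.Tensor, cu_ke: torch.Tensor,
                   block_start: int, block_q: int, seq_len_kv: int):
    # One KV range per query block: smallest start and largest end among its queries,
    # both clamped to the number of keys.
    s = cu_ks[block_start:block_start + block_q].clamp(max=seq_len_kv).min()
    e = cu_ke[block_start:block_start + block_q].clamp(max=seq_len_kv).max()
    return int(s), int(e)
```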
T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
            T.copy(Weights[seq_len_i, 0], weights)

- Copies this block's query index vectors into shared memory and its head weights into a register fragment.
for nbn_i in T.Pipelined(
T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)

- Iterates over the key range in `block_N`-sized chunks with software pipelining, copying each chunk of key vectors and scales on-chip.
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
                )

- Performs a GEMM between the key tile and the (transposed) query tile, storing the result in `s`. This computes the dot products $\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s}$.
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                    s_reshaped[bn_i, bq_i, h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]

- Applies ReLU (`T.max(..., 0)`), multiplies by the head weights $w^{I}_{t,j}$, and multiplies by the per-key scale to undo the FP8 quantization. This implements $\mathrm{ReLU}(\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s}) \cdot w^{I}_{t,j}$.
                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

- Sums over the head dimension, $\sum_{j=1}^{H^{I}}$, producing the index score $I_{t,s}$ for each (query, key) pair in the tile.
for bq_i, bn_i in T.Parallel(block_Q, block_N):
                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]

- Writes the computed logits back to global memory.
@tilelang.jit
def clean_logits_(
threads: int = 512,
block_K: int = 4096,
):
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
dtype = "float"
indices_dtype = "int32"
@T.prim_func
def clean_logits_kernel(
Logits: T.Tensor([seq_len, seq_len_kv], dtype),
CuSeqLenKS: T.Tensor([seq_len], indices_dtype),
CuSeqLenKE: T.Tensor([seq_len], indices_dtype),
):
with T.Kernel(seq_len, threads=threads) as bx:
tx = T.thread_binding(0, threads, thread="threadIdx.x")
cu_k_s = T.alloc_local([1], indices_dtype)
cu_k_e = T.alloc_local([1], indices_dtype)
cu_k_s[0] = CuSeqLenKS[bx]
cu_k_e[0] = CuSeqLenKE[bx]
for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
for k_i in T.serial(block_K // threads):
idx = n_i * block_K + k_i * threads + tx
if idx < cu_k_s[0] or idx >= cu_k_e[0]:
Logits[bx, idx] = -T.infinity(dtype)
    return clean_logits_kernel

- Masks logits outside each query's valid key range `[cu_k_s, cu_k_e)` by setting them to `-inf`, so that only the relevant keys can be selected.
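A dense PyTorch equivalent of this masking pass (a sketch for clarity, mirroring the reference implementation further down):

```python
import torch

def clean_logits_reference(logits, cu_seqlen_ks, cu_seqlen_ke):
    # Positions outside [ks, ke) for each query row are set to -inf.
    pos = torch.arange(logits.shape[1], device=logits.device)[None, :]
    valid = (pos >= cu_seqlen_ks[:, None]) & (pos < cu_seqlen_ke[:, None])
    return logits.masked_fill(~valid, float("-inf"))
```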
def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True):
seq_len, heads, index_dim = q.shape
seq_len_kv = kv.shape[0]
clean_logits_kernel = clean_logits_()
mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
mqa_attn_return_logits_kernel(
q.view(seq_len * heads, index_dim),
kv,
kv_scales,
logits,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
)
if clean_logits:
clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
    return logits

- Wraps the kernels for easy use: compiles them, flattens the query heads, allocates the output logits, launches the main kernel, and optionally applies the masking pass.
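A typical call, with shapes following the test below:

```python
# q_fp8:     [S, H, D]  float8_e4m3fn
# kv_fp8:    [SKV, D]   float8_e4m3fn
# kv_scales: [SKV]      float32
# weights:   [S, H]     float32
# ks, ke:    [S]        int32
logits = mqa_attn_return_logits_interface(
    q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights,
    cu_seqlen_ks=ks, cu_seqlen_ke=ke)   # -> float32 [S, SKV]
```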
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
k = kv
q = q.float()
k = k.float()
seq_len_kv = kv.shape[0]
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
score = torch.einsum('mhd,nd->hmn', q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float('-inf'))
cost = mask.sum()
    return logits, cost

- Pure PyTorch reference for validation: computes scores via `einsum`, applies ReLU and the head weights, sums over heads, and masks invalid positions with `-inf`. `cost` counts the valid (query, key) pairs and is used later to estimate FLOPs.
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) # Unused in test
ks, ke = generate_random_cu_seqlens(
per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits_tl = mqa_attn_return_logits_interface(
q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
print(f"diff: {diff}")
from tilelang.profiler import do_bench
def logits_fn():
return mqa_attn_return_logits_interface(
q=q_fp8,
kv=kv_fp8,
kv_scales=kv_scales,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
logits_fn()
print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))
logits_ms = do_bench(logits_fn, warmup=100, rep=100)
logits_flops = 2 * cost_ref * H * D
logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
print(f"cost_ref: {cost_ref}")- Tests the implementation: generates random data, computes reference and TileLang logits, validates correlation, profiles performance, and calculates TFLOPs.
if __name__ == "__main__":
    test_fp8_lighting_indexer()

- Runs the test when the script is executed directly.
This code efficiently implements the DSA lightning indexer using FP8 and TileLang for GPU acceleration. The kernels are optimized with tiling, pipelining, and shared memory. The reference implementation ensures correctness, and the test benchmarks performance. The core math matches Equation 1 from the context:
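$$
I_{t,s} = \sum_{j=1}^{H^{I}} w^{I}_{t,j} \cdot \mathrm{ReLU}\left(\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s}\right)
$$

where $\mathbf{q}^{I}_{t,j}$ and $w^{I}_{t,j}$ belong to query token $t$ and indexer head $j$, and $\mathbf{k}^{I}_{s}$ to key token $s$: exactly the quantities produced by the GEMM, ReLU-and-weight, and head-sum steps above.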