@createthis
Last active October 12, 2025 14:13
Analysis of fp8_lighting_indexer.py by DeepSeek V3.1-Terminus when given DSA context

This code implements the DeepSeek Sparse Attention (DSA) lightning indexer. The indexer uses FP8 arithmetic to compute the index scores that DSA relies on to choose which key tokens each query attends to. I'll explain it section by section. The code uses TileLang (a DSL for GPU kernels) and PyTorch.

1. Imports and Utility Functions

# ruff: noqa
import itertools
import tilelang
from tilelang import language as T
import torch
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8
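  • # ruff: noqa disables Ruff lint checks for the entire file.
  • Imports: itertools for generating tuning configurations, tilelang and its language module T for writing GPU kernels, torch for tensors, and two helpers from a local utils module (not shown). The helpers generate random per-query KV ranges and cast to FP8 with per-dimension scales.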
def display_error_message(msg):
    print(f"\033[31mWARNING: {msg}\033[0m")
  • Prints warning messages in red using ANSI escape codes.
def compute_correlation(a, b, label="tensor"):
    a, b = a.data.double(), b.data.double()
    norm_sum = (a * a + b * b).sum()
    if norm_sum == 0:
        display_error_message(f"{label} all zero")
        return 1
    correlation = 2 * (a * b).sum() / norm_sum
    return correlation
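  • Computes the similarity score $2\,(a \cdot b) / (\|a\|^2 + \|b\|^2)$ in double precision. The score equals 1 only when a == b. Unlike Pearson correlation, it also drops when the tensors differ only in scale. If both tensors are all zero, the function prints a warning and returns 1.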
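The same formula is easy to try on plain Python lists. The function name `similarity` below is illustrative and not part of the file. The example shows why the metric is stricter than Pearson correlation: a vector and twice that vector are perfectly correlated, yet they score only 0.8 here.

```python
def similarity(a, b):
    """2 * <a, b> / (|a|^2 + |b|^2), mirroring compute_correlation on plain lists."""
    norm_sum = sum(x * x + y * y for x, y in zip(a, b))
    if norm_sum == 0:
        return 1.0  # both all zero: treated as a perfect match
    return 2 * sum(x * y for x, y in zip(a, b)) / norm_sum


print(similarity([1.0, 2.0], [1.0, 2.0]))  # identical -> 1.0
print(similarity([1.0, 2.0], [2.0, 4.0]))  # same direction, 2x scale -> 0.8
```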
def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True):
    a_finite = torch.isfinite(a)
    b_finite = torch.isfinite(b)
    if not torch.all(a_finite == b_finite):
        display_error_message(f"{tensor_name} Error: isfinite mask mismatch")
        if should_raise:
            assert False
    if not torch.isclose(
            a.masked_fill(a_finite, 0),
            b.masked_fill(b_finite, 0),
            rtol=0,
            atol=0,
            equal_nan=True,
    ).all():
        display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
        if should_raise:
            assert False
    a = a.masked_fill(~a_finite, 0)
    b = b.masked_fill(~b_finite, 0)
    correlation = compute_correlation(a, b, tensor_name)
    difference = 1.0 - correlation
    if not (0 <= difference <= tolerance):
        display_error_message(f"{tensor_name} Error: {difference}")
        if should_raise:
            assert False
    return difference
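  • Checks that two tensors match in three steps:
    1. Their isfinite masks must agree.
    2. The non-finite entries must be identical: NaN matches NaN, and ±inf must have the same sign. Here a.masked_fill(a_finite, 0) zeroes the finite entries so that only the non-finite ones are compared.
    3. After the non-finite entries are zeroed, 1 − compute_correlation must lie in [0, tolerance].
  Each failure prints a red warning and, if should_raise is true, hits assert False. The function returns the difference.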
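Here is a sketch of the three checks on plain Python lists. It is an illustrative re-implementation, and the name `tensor_difference` is hypothetical. Unlike the original, it raises `ValueError` instead of printing and asserting. Note that matching `-inf` entries (like the masked positions in the logits) are accepted and then excluded from the similarity.

```python
import math


def tensor_difference(a, b):
    """Return 1 - similarity after checking non-finite entries, like validate_tensor_match."""
    fin_a = [math.isfinite(x) for x in a]
    fin_b = [math.isfinite(y) for y in b]
    # 1) both inputs must be non-finite in the same positions
    if fin_a != fin_b:
        raise ValueError("isfinite mask mismatch")
    # 2) the non-finite values must agree (NaN matches NaN, inf needs the same sign)
    for x, y, ok in zip(a, b, fin_a):
        if not ok and x != y and not (math.isnan(x) and math.isnan(y)):
            raise ValueError("nonfinite value mismatch")
    # 3) zero the non-finite entries and compare the rest
    a0 = [x if ok else 0.0 for x, ok in zip(a, fin_a)]
    b0 = [y if ok else 0.0 for y, ok in zip(b, fin_b)]
    norm_sum = sum(x * x + y * y for x, y in zip(a0, b0))
    if norm_sum == 0:
        return 0.0
    return 1.0 - 2 * sum(x * y for x, y in zip(a0, b0)) / norm_sum


print(tensor_difference([1.0, float("-inf")], [1.0, float("-inf")]))  # 0.0
```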

2. Configuration Generation

def get_configs():
    iter_params = dict(
        block_N=[32, 64, 128],
        num_stages=[0, 1, 2],
        threads=[128, 256],
        block_Q=[1, 2, 4],
    )
    return [{
        k: v for k, v in zip(iter_params, values)
    } for values in itertools.product(*iter_params.values())]
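  • Uses itertools.product to generate every combination of block_N, num_stages, threads, and block_Q (3 × 3 × 2 × 3 = 54 configurations). This is presumably an autotuning search space; nothing else in the code shown calls it.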
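The same search space can be rebuilt outside TileLang to check its size and ordering. In this sketch, `dict(zip(...))` is equivalent to the dict comprehension in get_configs.

```python
import itertools

iter_params = dict(
    block_N=[32, 64, 128],
    num_stages=[0, 1, 2],
    threads=[128, 256],
    block_Q=[1, 2, 4],
)
# one dict per point of the Cartesian product, keys in iter_params order
configs = [dict(zip(iter_params, values))
           for values in itertools.product(*iter_params.values())]

print(len(configs))  # 54
print(configs[0])    # {'block_N': 32, 'num_stages': 0, 'threads': 128, 'block_Q': 1}
```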

3. Tensor Supply Management

class SupplyProg:
    def __init__(self):
        self.tensors_dict = {}
    def get_key(self, shape, dtype) -> str:
        return f"{shape}-{dtype}"
    def supply_prog(self, params):
        shapes = [p.shape for p in params]
        dtypes = [p.dtype for p in params]
        tensor_list = []
        for shape, dtype in zip(shapes, dtypes):
            key = self.get_key(shape, dtype)
            if key not in self.tensors_dict:
                self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda")
                tensor_list.append(self.tensors_dict[key])
            else:
                tensor_list.append(self.tensors_dict[key])
        return tensor_list

supply_prog = SupplyProg()
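  • Caches one random tensor per (shape, dtype) key, so repeated calls (such as repeated profiling runs) reuse the same inputs instead of allocating new ones. The else branch duplicates the append and could be folded into it. supply_prog is a global instance, presumably meant to be handed to TileLang's profiler or autotuner as an input supplier; nothing in the code shown uses it. As written it calls torch.randn for every parameter. torch.randn rejects integer dtypes, so it would fail on int32 arguments such as CuSeqLenKS and CuSeqLenKE.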
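The caching pattern can be sketched without PyTorch. `TensorCache` and `make_zeros` are hypothetical names. The factory is passed in as a parameter, which would let a caller choose a dtype-appropriate generator instead of always using randn. A single append covers both the cache-hit and cache-miss cases.

```python
class TensorCache:
    """Hypothetical simplification of SupplyProg: one cached value per (shape, dtype) key."""

    def __init__(self, factory):
        self.factory = factory  # callable(shape, dtype) -> new value, e.g. a random tensor
        self.cache = {}

    def supply(self, specs):
        out = []
        for shape, dtype in specs:
            key = f"{shape}-{dtype}"  # same key format as SupplyProg.get_key
            if key not in self.cache:
                self.cache[key] = self.factory(shape, dtype)
            out.append(self.cache[key])  # one append for both hit and miss
        return out


calls = []


def make_zeros(shape, dtype):
    calls.append((shape, dtype))  # record each real allocation
    return [0.0] * shape[0]


cache = TensorCache(make_zeros)
x, y, z = cache.supply([((4,), "float32"), ((4,), "float32"), ((2,), "int32")])
print(x is y, len(calls))  # True 2
```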

4. Main Kernel: mqa_attn_return_logits

This TileLang kernel computes the index scores $I_{t,s}$ from the DSA formula.

@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },)
def mqa_attn_return_logits(
    heads,
    index_dim,
    block_N=256,
    num_stages=3,
    threads=512,
    block_Q=None,
):
    if block_Q is None:
        block_Q = 128 // heads
    dtype = "float8_e4m3"
    accum_dtype = "float"
    index_dtype = "int32"
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")
    index_q_shape = [seq_len * heads, index_dim]
    index_k_shape = [seq_len_kv, index_dim]
    index_k_scale_shape = [seq_len_kv]
    logits_shape = [seq_len, seq_len_kv]
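  • Decorated with @tilelang.jit, so calling the function returns a compiled GPU kernel. The pass config enables fast math. Q and K are stored in FP8 (float8_e4m3), while the scales, weights, scores, and logits use float32.
  • seq_len and seq_len_kv are symbolic, so one compiled kernel serves any sequence length. block_Q defaults to 128 // heads, which gives each tile block_Q * heads = 128 query rows when heads divides 128.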
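The tile and grid arithmetic can be checked in a few lines of Python. `launch_geometry` is a hypothetical helper. H = 32 is the test's default head count, and 48 is an arbitrary value chosen to show a head count that does not divide 128.

```python
def launch_geometry(seq_len, heads, block_Q=None):
    """block_Q default, query rows per tile, and grid size of mqa_attn_return_logits."""
    if block_Q is None:
        block_Q = 128 // heads          # default used when block_Q is not given
    rows_per_tile = block_Q * heads     # first dimension of index_q_shared
    grid = -(-seq_len // block_Q)       # T.ceildiv(seq_len, block_Q) thread blocks
    return block_Q, rows_per_tile, grid


print(launch_geometry(4096, 32))  # (4, 128, 1024)
print(launch_geometry(4096, 48))  # (2, 96, 2048): 48 does not divide 128
```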
    @T.prim_func
    def mqa_attn_return_logits_kernel(
            IndexQ: T.Tensor(index_q_shape, dtype),
            IndexK: T.Tensor(index_k_shape, dtype),
            IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),
            Logits: T.Tensor(logits_shape, accum_dtype),
            Weights: T.Tensor([seq_len, heads], accum_dtype),
            CuSeqLenKS: T.Tensor([seq_len], index_dtype),
            CuSeqLenKE: T.Tensor([seq_len], index_dtype),
    ):
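  • Kernel arguments:
    - IndexQ: the indexer queries, one row per (token, head) pair.
    - IndexK: the indexer keys. There is one key per position, shared by all heads, which is what the "mqa" in the name refers to.
    - IndexKScale: the per-key FP8 dequantization scales.
    - Logits: the output.
    - Weights: the per-head weights.
    - CuSeqLenKS and CuSeqLenKE: for each query token, the start and exclusive end of its valid key range.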
        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
            index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
            index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
            index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
            s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
            logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
            weights = T.alloc_fragment([block_Q, heads], accum_dtype)
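  • Launches T.ceildiv(seq_len, block_Q) thread blocks, each handling block_Q consecutive query tokens. The Q and K tiles (the GEMM operands) go in shared memory. The key scales, raw scores s, per-head products s_reshaped, reduced logits, and weights live in register fragments.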
            seq_len_i = bx * block_Q
            cu_k_s_min = T.alloc_local([1], index_dtype)
            cu_k_e_max = T.alloc_local([1], index_dtype)
            T.no_set_max_nreg()
            cu_k_s_min[0] = 2147483647
            cu_k_e_max[0] = -2147483648
            for bq_i in T.serial(block_Q):
                cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
            for bq_i in T.serial(block_Q):
                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
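  • Computes the union of the valid key ranges of the block's block_Q queries: the smallest start and the largest end, each clamped to seq_len_kv. The loop below visits only keys in [cu_k_s_min, cu_k_e_max), so blocks whose queries see short ranges do less work.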
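This sketch reproduces the union-range computation for one thread block in plain Python. `kv_range_for_block` and the sample ranges are hypothetical. Python's min and max over the block play the role of the INT32-extreme initial values in the kernel.

```python
def kv_range_for_block(ks, ke, seq_len_i, block_Q, seq_len_kv, block_N):
    """Union [lo, hi) of the block's per-query key ranges, and the number of block_N tiles."""
    lo = min(min(ks[seq_len_i + i], seq_len_kv) for i in range(block_Q))
    hi = max(min(ke[seq_len_i + i], seq_len_kv) for i in range(block_Q))
    n_tiles = -(-(hi - lo) // block_N)  # T.ceildiv(hi - lo, block_N)
    return lo, hi, n_tiles


# four queries in one block, ten keys, tiles of four keys
print(kv_range_for_block([0, 2, 5, 5], [3, 6, 9, 12], 0, 4, 10, 4))  # (0, 10, 3)
```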
            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
            T.copy(Weights[seq_len_i, 0], weights)
  • Loads the block's block_Q * heads query rows into shared memory and its [block_Q, heads] weights into a register fragment.
            for nbn_i in T.Pipelined(
                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
                T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
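  • Iterates over the union range in tiles of block_N keys. Each iteration copies one tile of keys into shared memory and its scales into a fragment. T.Pipelined with num_stages lets the loads for upcoming tiles overlap with computation on the current one.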
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    s,
                    transpose_B=True,
                    clear_accum=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
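  • Multiplies the key tile by the transposed query tile: s = K_tile · Q_tileᵀ, with shape [block_N, block_Q * heads]. Column bq * heads + h holds the dot product $\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s}$ for token $t$ = seq_len_i + bq and head $j$ = h. clear_accum=True resets the accumulator for every tile.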
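The GEMM output layout, and the indexing that s_reshaped applies to it, can be reproduced on tiny lists. `tile_scores` and `head_view` are hypothetical names, and the numbers are arbitrary. The query rows are token-major, matching how IndexQ is laid out.

```python
def tile_scores(k_tile, q_tile):
    """s = K_tile @ Q_tile^T: s[n][r] is key n dotted with query row r."""
    return [[sum(a * b for a, b in zip(k_row, q_row)) for q_row in q_tile]
            for k_row in k_tile]


def head_view(s, block_Q, heads):
    """Reindex s[n][bq * heads + h] as view[n][bq][h], as s_reshaped does."""
    return [[[row[bq * heads + h] for h in range(heads)] for bq in range(block_Q)]
            for row in s]


k_tile = [[1, 0], [0, 1]]                   # block_N = 2 keys, index_dim = 2
q_tile = [[1, 2], [3, 4], [5, 6], [7, 8]]   # block_Q = 2 tokens x heads = 2, token-major
s = tile_scores(k_tile, q_tile)
print(s)                       # [[1, 3, 5, 7], [2, 4, 6, 8]]
print(head_view(s, 2, 2)[1])   # key 1: [[2, 4], [6, 8]]
```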
                for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                    s_reshaped[bn_i, bq_i, h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
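  • Applies ReLU (T.max(..., 0)), multiplies by the head weight $w^{I}_{t,j}$, then multiplies by the key's FP8 dequantization scale. Assuming the scale is positive (as an absmax-style scale is), applying it after the ReLU gives the same result as dequantizing the key first. This implements $w^{I}_{t,j} \cdot \mathrm{ReLU}(\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s})$ for each head.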
                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
  • Sums over the heads axis (the last dimension of s_reshaped) to get $\sum_{j=1}^{H^{I}} \cdots$. This produces the index score $I_{t,s}$ for each of the block_N × block_Q key/query pairs in the tile.
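Here is the per-pair computation (ReLU, head weight, key scale, sum over heads) as a minimal pure-Python sketch. `index_score` is a hypothetical helper, and the vectors are made up. It factors the key scale out of the sum, which is algebraically equal to the kernel's per-term multiplication.

```python
def index_score(q_heads, k, weights, k_scale=1.0):
    """I_{t,s} = sum_j w_{t,j} * ReLU(q_{t,j} . k_s), times the key's dequantization scale."""
    total = 0.0
    for q_j, w_j in zip(q_heads, weights):
        dot = sum(a * b for a, b in zip(q_j, k))  # one column of s for this (token, head)
        total += max(dot, 0.0) * w_j              # ReLU, then head weight
    return total * k_scale                        # same as scaling each term


q_heads = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # H^I = 3 heads, index_dim = 2
print(index_score(q_heads, [2.0, 3.0], [0.5, 2.0, 4.0]))       # 7.0
print(index_score(q_heads, [2.0, 3.0], [0.5, 2.0, 4.0], 0.5))  # 3.5
```

The third head's dot product is negative (−5), so ReLU removes its contribution even though its weight is the largest.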
                for bq_i, bn_i in T.Parallel(block_Q, block_N):
                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]
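  • Writes the tile back to Logits[query, key], transposing the [block_N, block_Q] fragment. Only columns covered by the tiles of the block's union range are written. Within that range, a query's scores for keys outside its own [start, end) are still real values at this point. Columns the block never visits keep whatever the output buffer already held.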

5. Clean Logits Kernel: clean_logits_

@tilelang.jit
def clean_logits_(
    threads: int = 512,
    block_K: int = 4096,
):
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")
    dtype = "float"
    indices_dtype = "int32"
    @T.prim_func
    def clean_logits_kernel(
            Logits: T.Tensor([seq_len, seq_len_kv], dtype),
            CuSeqLenKS: T.Tensor([seq_len], indices_dtype),
            CuSeqLenKE: T.Tensor([seq_len], indices_dtype),
    ):
        with T.Kernel(seq_len, threads=threads) as bx:
            tx = T.thread_binding(0, threads, thread="threadIdx.x")
            cu_k_s = T.alloc_local([1], indices_dtype)
            cu_k_e = T.alloc_local([1], indices_dtype)
            cu_k_s[0] = CuSeqLenKS[bx]
            cu_k_e[0] = CuSeqLenKE[bx]
            for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
                for k_i in T.serial(block_K // threads):
                    idx = n_i * block_K + k_i * threads + tx
                    if idx < cu_k_s[0] or idx >= cu_k_e[0]:
                        Logits[bx, idx] = -T.infinity(dtype)
    return clean_logits_kernel
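  • Launches one thread block per query row. The threads stride across the row in chunks of block_K, setting every logit outside [cu_k_s, cu_k_e) to -inf so each query keeps only the scores of its own valid keys. The interface allocates the output with torch.empty, and the main kernel writes only each block's union range. This pass therefore also overwrites the uninitialized entries outside that range.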
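The loop nest can be simulated sequentially for a single row to see which columns each (n_i, k_i, tx) triple touches. `clean_row` is a hypothetical helper, and threads=4, block_K=8 are toy values rather than the kernel defaults. The `idx >= seq_len_kv` guard exists only to keep the Python list in bounds; the kernel shown has no explicit check of that kind.

```python
import math


def clean_row(row, ks, ke, threads=4, block_K=8):
    """Sequential sketch of clean_logits_kernel for one query row (one thread block)."""
    seq_len_kv = len(row)
    for n_i in range(-(-seq_len_kv // block_K)):  # T.ceildiv(seq_len_kv, block_K) chunks
        for k_i in range(block_K // threads):      # slots per thread within a chunk
            for tx in range(threads):              # on the GPU these threads run in parallel
                idx = n_i * block_K + k_i * threads + tx
                if idx >= seq_len_kv:
                    continue  # list bounds guard for this sketch only
                if idx < ks or idx >= ke:
                    row[idx] = -math.inf
    return row


print(clean_row([1.0] * 8, ks=2, ke=5))
# [-inf, -inf, 1.0, 1.0, 1.0, -inf, -inf, -inf]
```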

6. Interface Function

def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True):
    seq_len, heads, index_dim = q.shape
    seq_len_kv = kv.shape[0]
    clean_logits_kernel = clean_logits_()
    mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
    logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
    mqa_attn_return_logits_kernel(
        q.view(seq_len * heads, index_dim),
        kv,
        kv_scales,
        logits,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
    )
    if clean_logits:
        clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
    return logits
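  • Obtains both compiled kernels for the given heads and index_dim. It then flattens q from [seq_len, heads, index_dim] to [seq_len * heads, index_dim], so row t * heads + h holds head h of token t. Next it allocates the float32 output with torch.empty, launches the main kernel, and, unless clean_logits=False, masks each row to its valid range. With clean_logits=False, entries outside each query block's union range are left uninitialized.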
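The flattening is what lets the kernel's single copy from IndexQ[seq_len_i * heads, 0] pick up whole tokens. The two helpers below are hypothetical illustrations of that row-major index arithmetic, using the test's H = 32 and the default block_Q = 4.

```python
def flat_row(t, h, heads):
    """Row of q.view(seq_len * heads, index_dim) that holds token t, head h (row-major)."""
    return t * heads + h


def block_token_rows(bx, block_Q, heads):
    """Rows loaded by T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) for block bx."""
    start = bx * block_Q * heads  # seq_len_i * heads with seq_len_i = bx * block_Q
    return range(start, start + block_Q * heads)


rows = block_token_rows(1, 4, 32)
print(rows, sorted({r // 32 for r in rows}))  # range(128, 256) [4, 5, 6, 7]
```

Block 1 therefore reads exactly the 128 rows belonging to tokens 4 to 7, all heads included.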

7. Reference Implementation

def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
                       cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
    k = kv
    q = q.float()
    k = k.float()
    seq_len_kv = kv.shape[0]
    mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
    mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi
    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))
    cost = mask.sum()
    return logits, cost
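  • A pure PyTorch reference used for validation. It computes per-head scores with einsum('mhd,nd->hmn'), giving [heads, seq_len, seq_len_kv]. It then applies ReLU, multiplies by the per-head weights, sums over heads, and sets positions outside [ks, ke) to -inf. It runs on the unquantized inputs upcast to float32 and takes no FP8 scales, so comparing it with the kernel also measures FP8 quantization error. cost is not a FLOP count. It is the number of valid (query, key) pairs, mask.sum(), which the test later converts into FLOPs.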
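Because the mask is one contiguous interval per query, cost has a closed form. The sketch below checks it against an element-by-element translation of mask_lo & mask_hi. Both function names and the sample ranges are hypothetical.

```python
def valid_pairs(ks, ke, seq_len_kv):
    """Closed-form count of valid pairs, equal to mask.sum() in ref_fp8_mqa_logits."""
    return sum(max(0, min(e, seq_len_kv) - max(s, 0)) for s, e in zip(ks, ke))


def valid_pairs_brute(ks, ke, seq_len_kv):
    """Direct translation of mask_lo & mask_hi, counted element by element."""
    return sum(1 for s, e in zip(ks, ke) for n in range(seq_len_kv) if s <= n < e)


ks, ke = [0, 2, 5], [3, 6, 12]
print(valid_pairs(ks, ke, 10), valid_pairs_brute(ks, ke, 10))  # 12 12
```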

8. Test Function

def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
    weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
    p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)  # Unused in test
    ks, ke = generate_random_cu_seqlens(
        per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
    logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
    q_fp8 = q.to(torch.float8_e4m3fn)
    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
    logits_tl = mqa_attn_return_logits_interface(
        q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
    diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
    print(f"diff: {diff}")
    from tilelang.profiler import do_bench
    def logits_fn():
        return mqa_attn_return_logits_interface(
            q=q_fp8,
            kv=kv_fp8,
            kv_scales=kv_scales,
            weights=weights,
            cu_seqlen_ks=ks,
            cu_seqlen_ke=ke)
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        logits_fn()
    print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))
    logits_ms = do_bench(logits_fn, warmup=100, rep=100)
    logits_flops = 2 * cost_ref * H * D
    logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
    print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
    print(f"cost_ref: {cost_ref}")
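  • Tests the implementation in three stages:
    1. Setup: builds random bf16 inputs and per-query KV ranges (via generate_random_cu_seqlens) and computes the reference logits.
    2. Validation: quantizes q to FP8 by a direct cast and kv to FP8 with per-row scales, runs the TileLang path, and compares the result with validate_tensor_match. Because should_raise=False, a gap above the very strict 1e-14 tolerance only prints a warning.
    3. Benchmarking: prints a CUDA profiler table, times logits_fn with do_bench, and reports effective TFLOPS from 2 · cost_ref · H · D FLOPs.
  The tensor p and the HKV parameter are unused.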
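The TFLOPS formula counts one multiply-add (2 FLOPs) per element of H dot products of length D for each valid pair. The sketch below applies it to made-up inputs, not measurements. `effective_tflops` is a hypothetical helper. The formula ignores work the kernel does on masked pairs inside a block's union range, so it understates the raw arithmetic the GPU performs.

```python
def effective_tflops(cost, heads, dim, ms):
    """TFLOPS as computed in the test: 2 * cost * H * D FLOPs over ms milliseconds."""
    flops = 2 * cost * heads * dim
    return flops / (ms * 1e-3) / 1e12


# hypothetical inputs: 12M valid pairs, H = 32, D = 64, 1.0 ms
print(round(effective_tflops(12_000_000, 32, 64, 1.0), 3))  # 49.152
```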

9. Main Block

if __name__ == "__main__":
    test_fp8_lighting_indexer()
  • Runs the test when the script is executed directly.

Summary

This code implements the DSA lightning indexer with FP8 inputs and TileLang GPU kernels. The main kernel tiles queries and keys, stages them in shared memory, pipelines the key loads, and skips keys outside each query block's valid range. A second kernel masks the remaining out-of-range entries. The PyTorch reference checks correctness, and the test benchmarks throughput. The core math matches Equation 1 from the context, $I_{t,s} = \sum_{j=1}^{H^{I}} w^{I}_{t,j} \cdot \mathrm{ReLU}(\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_{s})$, with each key's FP8 dequantization scale applied as an extra per-key factor.
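The way the kernel folds the key scale into Equation 1 can be written out as follows. This assumes, as the kernel's use of IndexKScale suggests, that each key row is stored as a positive scale $\sigma_s$ times an FP8 vector $\hat{\mathbf{k}}^{I}_{s}$; the utils helper that produces them is not shown. The query is cast to FP8 directly, so $\mathbf{q}^{I}_{t,j}$ here denotes the rounded query.

```latex
\begin{aligned}
\mathbf{k}^{I}_{s} &\approx \sigma_s \, \hat{\mathbf{k}}^{I}_{s}, \qquad \sigma_s > 0 \\
I_{t,s} &\approx \sum_{j=1}^{H^{I}} w^{I}_{t,j} \cdot \mathrm{ReLU}\big(\mathbf{q}^{I}_{t,j} \cdot \sigma_s \hat{\mathbf{k}}^{I}_{s}\big)
         = \sigma_s \sum_{j=1}^{H^{I}} w^{I}_{t,j} \cdot \mathrm{ReLU}\big(\mathbf{q}^{I}_{t,j} \cdot \hat{\mathbf{k}}^{I}_{s}\big)
\end{aligned}
```

The last step uses $\mathrm{ReLU}(\sigma x) = \sigma\,\mathrm{ReLU}(x)$ for $\sigma > 0$. This is why the kernel can run the GEMM on raw FP8 keys and multiply by the scale afterwards.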
