@createthis
Last active October 12, 2025 16:28
Analysis of topk_selector.py by DeepSeek V3.1-Terminus when given DSA context

This code implements a high-performance Top-K selection algorithm using TileLang for GPU acceleration. I'll explain it section by section, focusing on the radix-based selection approach.

1. Imports and Configuration

import torch
import tilelang
import tilelang.language as T
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
}
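  • Imports the libraries and sets TL_DISABLE_THREAD_STORAGE_SYNC. As the name suggests, this turns off TileLang's automatic insertion of shared-memory synchronization barriers, so the kernel places its own T.sync_threads() calls.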

2. Floating-Point to Unsigned Integer Conversion Functions

def convert_to_uint16(x):
    hval = T.Cast("float16", x)
    bits_uint = T.reinterpret("uint16", hval)
    bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
    return bits_uint >> 8
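  • Casts the float32 input to float16 and reinterprets its 16 bits as an unsigned integer.
  • Maps those bits to an order-preserving key. For negative inputs every bit is flipped (the & 0xFFFF keeps the result 16 bits wide). For non-negative inputs the sign bit is set. After this mapping, larger floats compare larger as unsigned integers.
  • Returns the top 8 bits of that key (the sign bit, the 5 exponent bits and the 2 leading mantissa bits). This gives a bin number from 0 to 255 for the 256-way stage-1 histogram.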
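A CPU sketch makes the key transform concrete. The Python below mirrors convert_to_uint16 using the standard struct module, and the name fp16_bin is hypothetical. The GPU cast saturates to infinity, but struct's 'e' (half-precision) format raises OverflowError outside the float16 range, so the sketch assumes in-range inputs.

```python
import struct

def fp16_bin(x: float) -> int:
    """8-bit stage-1 bin for x, mirroring convert_to_uint16 (CPU sketch)."""
    # Round to IEEE half precision and read back the raw 16-bit pattern.
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    if x < 0:
        bits = ~bits & 0xFFFF   # negatives: flip every bit
    else:
        bits = bits | 0x8000    # non-negatives: set the sign bit
    return bits >> 8            # sign + 5 exponent bits + top 2 mantissa bits

values = [-3.5, -0.25, 0.0, 0.1, 1.0, 2.0, 100.0]   # ascending
bins = [fp16_bin(v) for v in values]
print(bins)
```

Values fed in ascending order come out as ascending bins, which is the property the radix passes rely on.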
def convert_to_uint32(x):
    bits_uint = T.reinterpret("uint32", x)
    bits_uint = T.if_then_else(
        x < 0,
        ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
        bits_uint | T.Cast("uint32", (0x80000000)),
    )
    return bits_uint
  • Applies the same order-preserving transform to the full 32-bit pattern of the float32 value, with no float16 rounding.
  • Stage 2 slices this key into four 8-bit digits to refine the threshold bin left by stage 1.
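Below is a matching CPU sketch of convert_to_uint32, together with the digit extraction that stage 2 applies to its result. The helper names f32_key and digit are hypothetical, and struct stands in for the GPU reinterpret.

```python
import struct

def f32_key(x: float) -> int:
    """Order-preserving 32-bit key for a float32 value (mirrors convert_to_uint32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # raw float32 bit pattern
    if x < 0:
        return ~bits & 0xFFFFFFFF   # negatives: flip every bit
    return bits | 0x80000000        # non-negatives: set the sign bit

def digit(key: int, rnd: int) -> int:
    """8-bit digit examined in stage-2 round rnd (0 = most significant byte)."""
    return (key >> (24 - rnd * 8)) & 0xFF

xs = [-2.0, -1e-3, 0.0, 1e-3, 1.5, 3.0]
keys = [f32_key(x) for x in xs]
print([hex(k) for k in keys])                       # ascending, like xs
print([digit(f32_key(1.5), r) for r in range(4)])   # [191, 192, 0, 0]
```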

3. Main Top-K Kernel

@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
    batch = T.symbolic("batch")
    seq_len = T.symbolic("seq_len")
    RADIX = 1 << 8
    BLOCK_SIZE = 1024
    SMEM_INPUT_SIZE = 4096  # assume the threshold bucket size after first pass is less than 4K
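  • JIT-compiled TileLang kernel factory. topk is fixed at compile time, while batch and seq_len are symbolic (dynamic) dimensions.
  • Constants: RADIX = 256 histogram bins (one 8-bit digit per pass), BLOCK_SIZE = 1024 threads per block, and SMEM_INPUT_SIZE = 4096 candidate-index slots per shared-memory buffer. As the code comment says, this assumes the threshold bin found by stage 1 fits in those 4096 slots. The kernel does not check for overflow.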
    @T.prim_func
    def tl_topk_kernel(
        input: T.Tensor[(batch, seq_len), in_dtype],
        index: T.Tensor[(batch, topk), out_dtype],
        starts: T.Tensor[(batch), out_dtype],
        ends: T.Tensor[(batch), out_dtype],
    ):
  • Kernel arguments: the input scores, the output index tensor (topk slots per row), and per-row starts/ends. The starts/ends restrict selection to the half-open range [start, end) of each row.
        with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
            tx = T.get_thread_binding()
            s_threshold_bin_id = T.alloc_shared([1], "int32")
            s_histogram = T.alloc_shared([RADIX + 1], "int32")
            s_num_input = T.alloc_shared([2], "int32")
            s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")
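  • Launches one 1024-thread block per batch row. bx is the row index and tx is the thread index within the block.
  • Allocates shared memory for:
    the threshold bin ID;
    a histogram with RADIX + 1 entries (the extra entry lets s_histogram[tx + 1] be read for tx = 255);
    one candidate counter per buffer;
    two candidate-index buffers, which stage 2 uses as a double buffer.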
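For a sense of scale, the shared-memory footprint of these allocations can be tallied directly. This is back-of-the-envelope arithmetic. It assumes 4-byte int32 entries and ignores any alignment padding TileLang may insert. The two candidate buffers account for almost all of the total.

```python
RADIX = 1 << 8
SMEM_INPUT_SIZE = 4096
INT32_BYTES = 4

# Element counts of the shared buffers allocated in tl_topk_kernel (all int32).
shared_buffers = {
    "s_threshold_bin_id": 1,
    "s_histogram": RADIX + 1,            # 256 bins + 1 sentinel entry
    "s_num_input": 2,                    # one candidate counter per buffer
    "s_input_idx": 2 * SMEM_INPUT_SIZE,  # double-buffered candidate indices
}
total_bytes = INT32_BYTES * sum(shared_buffers.values())
print(total_bytes, "bytes")  # 33808 bytes, about 33 KiB
```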
            l_threshold_bin_id = T.alloc_var("int32")
            l_new_topk = T.alloc_var("int32")
            l_num_input = T.alloc_var("int32")
            l_bin_id32 = T.alloc_var("int32")
            l_val = T.alloc_var("int32")
            l_start_pos = T.alloc_var("int32")
            l_start_idx = T.alloc_var("int32")
            l_end_idx = T.alloc_var("int32")
            l_out_pos = T.alloc_var("int32")
  • Allocates local variables for computation.
            l_new_topk = topk
            l_start_idx = starts[bx]
            l_end_idx = ends[bx]
  • Initializes local variables with input parameters.

4. Stage 1: 8-bit Radix Pass

            # stage 1: use 8bit to do quick topk
            T.fill(s_histogram, 0)
            T.fill(s_num_input[0], 0)
            T.sync_threads()
  • Clears histogram and input counter, synchronizes threads.
            for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
                input_idx = s * BLOCK_SIZE + tx
                if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
                    inval_int16 = convert_to_uint16(input[bx, input_idx])
                    T.atomic_add(s_histogram[inval_int16], 1)
            T.sync_threads()
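  • Strided loop over the row. In iteration s, thread tx handles element s * BLOCK_SIZE + tx, so the block covers 1024 consecutive elements per iteration.
  • For each element inside [start, end), computes its 8-bit bin with convert_to_uint16 and atomically increments that bin in the shared histogram. Despite the name inval_int16, the bin value lies in 0 to 255.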
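Run sequentially, the histogram loop for one row looks like the sketch below. build_histogram is a hypothetical name, and the inner loop over tx stands in for the 1024 threads that run concurrently on the GPU. key_fn takes the place of convert_to_uint16, so the example can use toy integer keys.

```python
def build_histogram(row, start, end, key_fn, radix=256, block_size=1024):
    """Sequential sketch of the stage-1 histogram loop for one row."""
    hist = [0] * (radix + 1)                      # extra zero slot, like s_histogram
    n_steps = -(-len(row) // block_size)          # ceil division, like T.ceildiv
    for s in range(n_steps):
        for tx in range(block_size):              # concurrent threads on the GPU
            i = s * block_size + tx
            if start <= i < end and i < len(row):
                hist[key_fn(row[i])] += 1         # T.atomic_add on the GPU
    return hist

# Toy usage: each element's key is the element itself; only indices 1..5 count.
row = [3, 3, 7, 0, 255, 7, 7]
hist = build_histogram(row, start=1, end=6, key_fn=lambda v: v, block_size=4)
print(hist[0], hist[3], hist[7], hist[255])
```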
            # cumsum
            if tx < RADIX:
                for i in T.serial(8):
                    offset = 1 << i
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        l_val = s_histogram[tx] + s_histogram[tx + offset]
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        s_histogram[tx] = l_val
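  • Turns the histogram into a suffix sum: afterwards s_histogram[t] holds the number of elements in bins t and above.
  • This is a Hillis-Steele (Kogge-Stone style) scan run from right to left, not Brent-Kung. Each of the 256 participating threads owns exactly one bin. In step i it adds the value 1 << i bins to its right, so 8 steps cover all 256 bins.
  • Each step reads, synchronizes, then writes, so no thread overwrites a value that another thread still needs. T.sync_threads(3, RADIX) appears to be a named barrier (ID 3) that counts only the 256 threads inside the tx < RADIX branch.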
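Simulating the cumsum loop on the CPU shows what it computes. In this sketch (suffix_scan is a hypothetical name), each step first reads every pair and only then writes the sums back, which is what the two barriers per step enforce on the GPU. Position t of the result counts the elements in bins t and above.

```python
def suffix_scan(hist, radix=256):
    """Sequential sketch of the kernel's cumsum loop (radix must be a power of 2).

    hist has radix + 1 entries, the last being the zero sentinel. Afterwards
    out[t] == sum(hist[t:radix]), the number of elements in bins >= t.
    """
    h = list(hist)
    for i in range(radix.bit_length() - 1):       # 8 steps when radix == 256
        offset = 1 << i
        # Read phase: every active thread reads the old values first.
        reads = [h[tx] + h[tx + offset] for tx in range(radix - offset)]
        # Write phase: threads tx < radix - offset store their sums.
        h[: radix - offset] = reads
    return h

print(suffix_scan([1, 2, 3, 4, 0], radix=4))  # [10, 9, 7, 4, 0]
```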
                # find threshold bin id
                T.sync_threads(3, RADIX)
                if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
                    s_threshold_bin_id[0] = tx
            T.sync_threads()
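  • The suffix counts never increase with t, so exactly one bin t satisfies count(bins above t) ≤ k < count(bins t and above), where k = l_new_topk. This requires the row range to hold more than k elements; otherwise no thread writes s_threshold_bin_id.
  • Every element in a bin above t belongs to the top-k. The remaining slots must be filled from bin t, which is why it is called the threshold bin.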
            l_threshold_bin_id = s_threshold_bin_id[0]
            l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
            T.sync_threads()
  • Reads the threshold bin t and sets l_new_topk to the number of top-k slots still open once every element from the bins above t has been taken, i.e. k - s_histogram[t + 1].
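The threshold test and the l_new_topk update can be written as one small function over the scanned counts. find_threshold is a hypothetical name. The loop over tx plays the role of the 256 threads, each testing its own bin. The explicit error is an addition of the sketch, since the kernel itself has no such check.

```python
def find_threshold(suffix, k, radix=256):
    """Return (t, slots_left) for suffix counts produced by the scan (sketch).

    t is the unique bin with suffix[t + 1] <= k < suffix[t]; slots_left is
    the number of top-k slots that must still be filled from bin t.
    """
    for tx in range(radix):                 # one thread per bin on the GPU
        if suffix[tx] > k and suffix[tx + 1] <= k:
            return tx, k - suffix[tx + 1]
    # Not in the kernel: it assumes the row holds more than k elements.
    raise ValueError("no threshold bin: row has at most k elements")

suffix = [11, 6, 3, 1, 0]                   # radix 4, bin counts [5, 3, 2, 1]
print(find_threshold(suffix, 4, radix=4))   # (1, 1): bins 2-3 give 3, one more from bin 1
print(find_threshold(suffix, 3, radix=4))   # (1, 0): bins 2-3 already fill all 3 slots
```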
            # collect all elements with exponent ≥ threshold
            for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
                T.sync_threads()
                input_idx = s * BLOCK_SIZE + tx
                if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
                    bin_id = convert_to_uint16(input[bx, input_idx])
                    l_bin_id32 = T.Cast("int32", bin_id)
                    if l_bin_id32 > l_threshold_bin_id:
                        # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1)
                        pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
                        index[bx, pos] = input_idx
                    elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                        # pos = s_num_input[0]
                        pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
                        s_input_idx[0, pos] = input_idx
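  • Collects elements. An element in a bin b above the threshold claims an output slot with atomic_add(s_histogram[b + 1], 1, return_prev=True).
  • After the scan, s_histogram[b + 1] counts the elements in bins above b, which is exactly where bin b's run of output slots begins. This makes the step a counting-sort scatter that fills slots 0 to topk - l_new_topk - 1.
  • Within a bin, the order depends on atomic ordering, so the output is not sorted.
  • Elements in the threshold bin are appended to candidate buffer 0 for stage 2, but only if slots remain (l_new_topk > 0).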
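The slot assignment is a counting-sort scatter. The sequential sketch below (stage1_scatter is a hypothetical name) reuses the scanned histogram as per-bin write cursors in the same way the kernel does. One Python loop replaces the concurrent atomics, so slots within a bin are assigned in index order here. On the GPU that order is arbitrary.

```python
def stage1_scatter(bins, suffix, t, slots_left):
    """Sequential sketch of stage 1's collection loop for one row.

    bins[i] is element i's 8-bit bin; suffix is the scanned histogram.
    Returns ({output_slot: index}, candidate_indices).
    """
    counters = list(suffix)     # counters[b + 1] starts at #elements in bins > b
    out, candidates = {}, []
    for idx, b in enumerate(bins):
        if b > t:
            out[counters[b + 1]] = idx      # atomic_add(..., return_prev=True)
            counters[b + 1] += 1
        elif b == t and slots_left > 0:
            candidates.append(idx)          # s_input_idx[0, pos] = idx
    return out, candidates

bins = [2, 0, 3, 1, 2, 3, 0]                # radix 4, bin counts [2, 1, 2, 2]
suffix = [7, 5, 4, 2, 0]
out, cand = stage1_scatter(bins, suffix, t=2, slots_left=1)   # k = 3
print(out, cand)                            # {0: 2, 1: 5} [0, 4]
```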

5. Stage 2: Tail Pass with 4 Rounds of 8-bit Radix

            # stage 2: tail pass
            for round in T.serial(4):
                if l_new_topk <= 0:
                    T.loop_break()
                r_idx = round % 2
                l_start_pos = topk - l_new_topk
                T.sync_threads()
                T.fill(s_histogram, 0)
                if tx == 0:
                    s_num_input[r_idx ^ 1] = 0
                T.sync_threads()
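  • Refines the threshold bin with up to 4 further 8-bit radix passes. The loop exits as soon as no slots remain (l_new_topk <= 0).
  • l_start_pos = topk - l_new_topk is the first output slot not yet filled.
  • Candidates are double-buffered. Each round reads buffer r_idx = round % 2 and writes survivors to buffer r_idx ^ 1, whose counter thread 0 resets first.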
                l_num_input = s_num_input[r_idx]
                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast("int32", ((
                            convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
                            (24 - round * 8)) & 0xFF))
                        T.atomic_add(s_histogram[l_bin_id32], 1)
                T.sync_threads()
  • For each candidate, extracts one 8-bit digit of the float32 key from convert_to_uint32, not the float16 key used in stage 1. Rounds 0 to 3 examine bits 24-31, 16-23, 8-15 and 0-7, most significant first.
  • Builds a 256-bin histogram of that digit.
                # cumsum
                if tx < RADIX:
                    for i in T.serial(8):
                        offset = 1 << i
                        T.sync_threads(3, RADIX)
                        if tx < RADIX - offset:
                            l_val = s_histogram[tx] + s_histogram[tx + offset]
                        T.sync_threads(3, RADIX)
                        if tx < RADIX - offset:
                            s_histogram[tx] = l_val
                    # find threshold bin id
                    T.sync_threads(3, RADIX)
                    if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
                        s_threshold_bin_id[0] = tx
                T.sync_threads()
                l_threshold_bin_id = s_threshold_bin_id[0]
                l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
                T.sync_threads()
  • Same cumulative sum and threshold finding as stage 1, but on the current 8-bit chunk.
                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    T.sync_threads()
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast("int32", ((
                            convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
                            (24 - round * 8)) & 0xFF))
                        if l_bin_id32 > l_threshold_bin_id:
                            pos = T.atomic_add(
                                s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                            index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                        elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                            if round == 3:
                                l_out_pos = T.atomic_add(
                                    s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                                if l_out_pos < topk:
                                    index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                            else:
                                pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
                                s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx,
                                                                          s * BLOCK_SIZE + tx]
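  • Distributes candidates. Those in higher bins get output slots offset by l_start_pos. Those in the threshold bin move to the other buffer for the next round.
  • After round 3 all four digits have been compared, so the remaining threshold-bin candidates share an identical 32-bit key (exact ties). They take the remaining slots in whatever order the atomics produce, and the l_out_pos < topk check discards the surplus.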
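Putting the pieces together gives a sequential CPU reference of the two-stage selection. It mirrors the kernel's control flow, so it can serve as a mental model or a test oracle. This is a sketch under several assumptions:
  • All helper names are hypothetical.
  • One Python loop stands in for 1024 threads, so tie order is deterministic here.
  • Inputs must lie within float16 range for struct.
  • The row range must contain more than k elements.
  • The 4096-candidate shared-memory limit is not modelled.

```python
import struct

RADIX = 256

def fp16_bin(x):
    """Stage-1 8-bit bin (mirrors convert_to_uint16; x must fit in float16)."""
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    bits = (~bits & 0xFFFF) if x < 0 else (bits | 0x8000)
    return bits >> 8

def f32_key(x):
    """Order-preserving 32-bit key (mirrors convert_to_uint32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits & 0xFFFFFFFF) if x < 0 else (bits | 0x80000000)

def radix_pass(idxs, bin_of, k_left, start_pos, out, last):
    """One histogram / suffix-scan / threshold / scatter pass over candidates."""
    hist = [0] * (RADIX + 1)
    for i in idxs:
        hist[bin_of(i)] += 1
    suffix = [0] * (RADIX + 1)               # suffix[b] = #candidates in bins >= b
    for b in range(RADIX - 1, -1, -1):
        suffix[b] = suffix[b + 1] + hist[b]
    t = next(b for b in range(RADIX) if suffix[b] > k_left >= suffix[b + 1])
    k_left -= suffix[t + 1]                  # slots still open after bins > t
    counters = list(suffix)
    survivors = []                           # the other half of the double buffer
    for i in idxs:
        b = bin_of(i)
        if b > t or (b == t and k_left > 0 and last):
            pos = start_pos + counters[b + 1]
            counters[b + 1] += 1
            if pos < len(out):               # only the final tie case can overflow
                out[pos] = i
        elif b == t and k_left > 0:
            survivors.append(i)
    return survivors, k_left

def topk_reference(row, start, end, k):
    """Indices of the k largest values in row[start:end], in no particular order."""
    out = [-1] * k
    idxs = list(range(start, min(end, len(row))))
    # Stage 1: coarse pass on the float16-derived bin.
    cand, k_left = radix_pass(idxs, lambda i: fp16_bin(row[i]), k, 0, out, last=False)
    # Stage 2: up to four passes over 8-bit digits of the float32 key.
    for rnd in range(4):
        if k_left <= 0:
            break
        shift = 24 - 8 * rnd
        cand, k_left = radix_pass(cand, lambda i: (f32_key(row[i]) >> shift) & 0xFF,
                                  k_left, k - k_left, out, last=(rnd == 3))
    return out

print(sorted(topk_reference([0.5, -1.0, 3.0, 2.0, 3.0, 0.0], 0, 6, 3)))  # [2, 3, 4]
```

Because a higher bin can never overflow its slots, the single pos < len(out) guard only ever fires for exact ties in the last round, as in the kernel.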

6. Python Interface Function

def tl_topk(input, starts, ends, topk):
    batch, seq_len = input.shape
    indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
    kernel = tl_topk_impl(topk)
    kernel(input, indexes, starts, ends)
    return indexes
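  • Wraps the TileLang kernel for PyTorch. It allocates a zero-filled int32 output of shape (batch, topk), calls tl_topk_impl(topk) to obtain the kernel compiled for this topk, and launches it.
  • With the default dtypes, input must be float32 and starts/ends must be int32, all on the same CUDA device.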

7. Test Function

def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
    torch.manual_seed(1)
    input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
    starts = torch.zeros(batch, dtype=torch.int32).cuda()
    ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len
    indexes = tl_topk(input, starts, ends, topk)
    print(indexes)
    indexes_ref = torch.topk(input, topk, dim=-1)[1]
    print(indexes_ref)
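  • Prints the kernel's indices and the reference indices from torch.topk. They are not expected to match position by position, because the kernel writes indices in no particular order.
    # Calculate intersection of indexes_ref and indexes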
    for i in range(batch):
        ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
        trt_np = indexes[i].cpu().to(torch.int32).numpy()
        set_ref = set(ref_np)
        set_trt = set(trt_np)
        intersection = set_ref & set_trt
        print("selected/all:", len(intersection), "/", len(set_ref), "=",
              len(intersection) / len(set_ref))
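  • Measures correctness per row as the fraction of reference indices that the kernel also selected, which ignores order. A row can legitimately score below 1.0 if several elements tie at the k-th value, since the two implementations may break ties differently.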
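The per-row check reduces to a small order-insensitive overlap function. Factoring it out (here under the hypothetical name row_overlap, in plain Python so it runs without a GPU) makes it easy to assert a minimum overlap in a unit test instead of printing it.

```python
def row_overlap(selected, reference):
    """Fraction of reference indices also present in selected, ignoring order."""
    ref = set(reference)
    return len(set(selected) & ref) / len(ref)

# Same indices in a different order still count as a perfect match.
print(row_overlap([7, 3, 5], [3, 5, 7]))   # 1.0
print(row_overlap([7, 3, 9], [3, 5, 7]))   # 2 of 3 reference indices found
```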
    # Performance test with CUDA events
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # Warmup
    for _ in range(5):
        _ = tl_topk(input, starts, ends, topk)
    torch.cuda.synchronize()
    n_iters = 20
    start_event.record()
    for _ in range(n_iters):
        _ = tl_topk(input, starts, ends, topk)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")
    # Torch topk time
    start_event.record()
    for _ in range(n_iters):
        _ = torch.topk(input, topk, dim=-1)[1]
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")
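  • Benchmarks against torch.topk with CUDA events: 5 warm-up calls of tl_topk, then 20 timed calls of each function. Each timed tl_topk call also includes the torch.zeros allocation and the tl_topk_impl(topk) call, not just the kernel itself.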

8. Main Block

if __name__ == "__main__":
    test_topk_selector()
  • Runs the test when executed directly.

Summary

This code implements Top-K selection on the GPU with a two-stage radix-select approach:

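  1. Stage 1: A single 8-bit radix pass finds the threshold bin from which the last top-k slots must be filled. The pass builds a histogram over the float16-rounded key; it does not sort. Elements in higher bins are selected directly.

  2. Stage 2: Up to 4 further 8-bit radix passes run over the threshold bin's candidates, examining the full 32-bit key of the float32 value most significant byte first, until every slot is filled. Candidates still tied after the last pass fill the remaining slots in arbitrary order.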

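The algorithm suits the DSA (DeepSeek Sparse Attention) setting because:

  • It implements the "fine-grained token selection mechanism" described for DSA: each query row keeps the top-k positions by index score.
  • Its input rows are the index scores $I_{t,s}$ computed by the lightning indexer, and starts/ends let each row restrict the range of positions considered.
  • Radix selection never sorts the row. It makes one histogram pass and one scatter pass over the row, then at most four passes over the candidate set, so the number of passes does not grow with k.
  • Each row is handled by one 1024-thread block with its histogram and candidates in shared memory, so the whole selection is a single kernel launch.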

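The test function compares the kernel's speed with PyTorch's built-in torch.topk and reports the per-row overlap of the selected indices. Two conditions must hold: the threshold bin must fit in the 4096-slot candidate buffer, and each row range must hold more than topk elements. When they do, the selection is exact up to how ties are broken.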
