This code implements a high-performance Top-K selection algorithm using TileLang for GPU acceleration. I'll explain it line by line, focusing on the radix-based selection approach.
import torch
import tilelang
import tilelang.language as T
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
}- Imports necessary libraries and disables thread storage synchronization for performance optimization.
def convert_to_uint16(x):
hval = T.Cast("float16", x)
bits_uint = T.reinterpret("uint16", hval)
bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
return bits_uint >> 8- Converts float32 to uint16 by first casting to float16, then reinterpreting bits.
- Handles sign bit: for negative numbers, flips bits and masks; for positives, sets sign bit.
- Returns the high 8 bits (exponent and part of mantissa) for radix sorting.
def convert_to_uint32(x):
bits_uint = T.reinterpret("uint32", x)
bits_uint = T.if_then_else(
x < 0,
~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
bits_uint | T.Cast("uint32", (0x80000000)),
)
return bits_uint- Similar conversion for float32 to uint32, handling full 32-bit representation.
- Used for finer-grained radix passes in stage 2.
@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
RADIX = 1 << 8
BLOCK_SIZE = 1024
SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K- JIT-compiled TileLang kernel for Top-K selection.
- Defines symbolic dimensions and constants: 256 radix bins, 1024 threads per block, 4096 shared memory size.
@T.prim_func
def tl_topk_kernel(
input: T.Tensor[(batch, seq_len), in_dtype],
index: T.Tensor[(batch, topk), out_dtype],
starts: T.Tensor[(batch), out_dtype],
ends: T.Tensor[(batch), out_dtype],
):- Kernel takes input tensor, output indices for top-k elements, and start/end indices for each batch.
with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
tx = T.get_thread_binding()
s_threshold_bin_id = T.alloc_shared([1], "int32")
s_histogram = T.alloc_shared([RADIX + 1], "int32")
s_num_input = T.alloc_shared([2], "int32")
s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")- Launches kernel with one block per batch element.
- Allocates shared memory for threshold bin ID, histogram, input count, and temporary indices.
l_threshold_bin_id = T.alloc_var("int32")
l_new_topk = T.alloc_var("int32")
l_num_input = T.alloc_var("int32")
l_bin_id32 = T.alloc_var("int32")
l_val = T.alloc_var("int32")
l_start_pos = T.alloc_var("int32")
l_start_idx = T.alloc_var("int32")
l_end_idx = T.alloc_var("int32")
l_out_pos = T.alloc_var("int32")- Allocates local variables for computation.
l_new_topk = topk
l_start_idx = starts[bx]
l_end_idx = ends[bx]- Initializes local variables with input parameters.
# stage 1: use 8bit to do quick topk
T.fill(s_histogram, 0)
T.fill(s_num_input[0], 0)
T.sync_threads()- Clears histogram and input counter, synchronizes threads.
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
T.sync_threads()- Iterates over input elements in blocks.
- For each valid element, converts to uint16 and increments corresponding histogram bin.
# cumsum
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val- Parallel prefix sum (cumulative sum) of histogram using Brent-Kung algorithm.
- Each thread handles multiple bins in a tree reduction pattern.
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()- Finds the bin where cumulative count exceeds top-k threshold.
- This identifies the radix bin containing the k-th largest element.
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()- Updates local variables with threshold bin ID and remaining elements to select.
# collect all elements with exponent ≥ threshold
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
T.sync_threads()
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
bin_id = convert_to_uint16(input[bx, input_idx])
l_bin_id32 = T.Cast("int32", bin_id)
if l_bin_id32 > l_threshold_bin_id:
# need a pos = T.atomic_add(s_histogram[bin_id32+1], 1)
pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
# pos = s_num_input[0]
pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx- Collects elements: those in higher bins go directly to output, elements in threshold bin go to temporary storage for stage 2.
# stage 2: tail pass
for round in T.serial(4):
if l_new_topk <= 0:
T.loop_break()
r_idx = round % 2
l_start_pos = topk - l_new_topk
T.sync_threads()
T.fill(s_histogram, 0)
if tx == 0:
s_num_input[r_idx ^ 1] = 0
T.sync_threads()- Processes remaining elements from threshold bin through 4 rounds of finer radix passes.
- Uses double buffering (r_idx) for input/output arrays.
l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
T.atomic_add(s_histogram[l_bin_id32], 1)
T.sync_threads()- For each element in temporary storage, extracts 8-bit chunks from different positions (rounds 0-3 examine bits 24-31, 16-23, 8-15, 0-7).
- Builds histogram for current 8-bit chunk.
# cumsum
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()- Same cumulative sum and threshold finding as stage 1, but on the current 8-bit chunk.
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
T.sync_threads()
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
if round == 3:
l_out_pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
if l_out_pos < topk:
index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
else:
pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx,
s * BLOCK_SIZE + tx]- Distributes elements: higher bins go to output, threshold bin elements either go to output (final round) or next round's temporary storage.
def tl_topk(input, starts, ends, topk):
batch, seq_len = input.shape
indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
kernel = tl_topk_impl(topk)
kernel(input, indexes, starts, ends)
return indexes- Wraps the TileLang kernel for easy PyTorch integration.
- Allocates output tensor and calls the compiled kernel.
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
batch = 64
seq_len = 32 * 1024
topk = 2048
torch.manual_seed(1)
input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
starts = torch.zeros(batch, dtype=torch.int32).cuda()
ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len
indexes = tl_topk(input, starts, ends, topk)
print(indexes)
indexes_ref = torch.topk(input, topk, dim=-1)[1]
print(indexes_ref)- Tests the implementation against PyTorch's built-in topk for correctness.
# Calculate intersection of out_ref and out_trt
for i in range(batch):
ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
trt_np = indexes[i].cpu().to(torch.int32).numpy()
set_ref = set(ref_np)
set_trt = set(trt_np)
intersection = set_ref & set_trt
print("selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))- Measures accuracy by comparing selected indices with reference implementation.
# Performance test with CUDA events
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(5):
_ = tl_topk(input, starts, ends, topk)
torch.cuda.synchronize()
n_iters = 20
start_event.record()
for _ in range(n_iters):
_ = tl_topk(input, starts, ends, topk)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")
# Torch topk time
start_event.record()
for _ in range(n_iters):
_ = torch.topk(input, topk, dim=-1)[1]
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")- Benchmarks performance against PyTorch's topk implementation.
if __name__ == "__main__":
test_topk_selector()- Runs the test when executed directly.
This code implements a highly optimized Top-K selection algorithm using a two-stage radix-based approach:
-
Stage 1: Uses 8-bit radix sort to quickly identify the approximate threshold bin containing the k-th largest element. Elements in higher bins are directly selected.
-
Stage 2: Processes the threshold bin through 4 rounds of 8-bit radix passes, examining different bit ranges of the floating-point representation for finer granularity.
The algorithm is particularly efficient for the DSA context because:
- It handles the "fine-grained token selection mechanism" mentioned in your context
- Works well with the index scores
$I_{t,s}$ computed by the lightning indexer - The radix-based approach is faster than comparison-based sorting for large k values
- GPU optimization through TileLang provides significant speedup over traditional implementations
The implementation shows competitive performance against PyTorch's built-in topk while maintaining high accuracy in selecting the top-k indices.