# Example of two-level aggregation for LogSumExp in Triton-lang (only forward pass), created for investigation of https://github.com/volcengine/verl/issues/2899
# Extracted and simplified the two-level aggregation approach (first, parallel aggregation in blocks, then final sequential aggregation) from https://github.com/volcengine/verl/blob/main/verl/utils/kernel/kernels.py
# Examples of single-level sequential, online aggregation approaches:
# - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/cross_entropy.py
# - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py
# logsumexp_torch has some eager pseudocode in PyTorch which emulates what Triton does, except that BLOCK_SIZE_M equals num_tokens
# tl.program_id(axis=0).to(tl.int64) avoids int32 overflow in pointer-offset arithmetic for large grids, see https://arxiv.org/abs/2410.10989 and https://github.com/linkedin/Liger-Kernel/blob/05b43a14913ced3776aa3fc50020089b8c0d63c1/src/liger_kernel/ops/cross_entropy.py#L77-L79
# sample_verl.pt is derived from the inputs (logits = torch.matmul(hidden, weights)) uploaded by @WindowsXP-Beta in https://github.com/volcengine/verl/issues/2656#issuecomment-3131136498
import torch
import triton
import triton.language as tl
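
# A minimal reference sketch (added for illustration; this helper is not part of the original
# kernels) of the online aggregation rule both kernels below rely on: keep a running
# (max, sum-of-exp) pair and, when a new chunk arrives, rescale the old sum by
# exp(old_max - new_max), so that the invariant logsumexp == running_max + log(running_accu)
# holds after every step.
def online_logsumexp_reference(chunks):
    import math
    running_max, running_accu = -float("inf"), 0.0
    for x in chunks:  # x: a 1D float tensor of logits
        new_max = max(running_max, x.max().item())
        running_accu = running_accu * math.exp(running_max - new_max) + (x - new_max).exp().sum().item()
        running_max = new_max
    return running_max + math.log(running_accu)
# e.g. for x = torch.randn(1024), online_logsumexp_reference(x.split(128)) agrees with
# torch.logsumexp(x, dim=-1) up to floating-point error
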
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8)],
    key=["num_tokens", "vocab_size"],
)
@triton.jit
def efficient_entropy_kernel_general_mainloop(
    num_tokens,
    vocab_size,
    vocab_per_split,
    logits_ptr,
    stride_logits_m: tl.int64,
    stride_logits_n: tl.int64,
    max_ptr,
    stride_max_m: tl.int64,
    stride_max_n: tl.int64,
    accu_ptr,
    stride_accu_m: tl.int64,
    stride_accu_n: tl.int64,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid = tl.program_id(axis=0).to(tl.int64)
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)
    # each program owns one (row-block, vocab-split) pair
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    # running (max, sum-of-exp) pair per row, maintained online over this split
    _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32)
    _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for n in range(0, num_pid_n):
        offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        logits_ptrs = logits_ptr + offs_am[:, None] * stride_logits_m + offs_bn[None, :] * stride_logits_n
        # mask out-of-range rows/columns; -inf padding contributes exp(-inf) == 0 to the sum
        logits = tl.load(logits_ptrs, mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), other=-float("inf"))
        _max_old = _max
        m_pid_n = tl.max(logits, axis=1)
        _max = tl.maximum(_max_old, m_pid_n)
        exp_logits = tl.exp(logits - _max[:, None])
        # rescale the old accumulator so the invariant _accu == sum(exp(x - _max)) keeps holding
        coeff = tl.exp(_max_old - _max)
        _accu = coeff * _accu + tl.sum(exp_logits, axis=1)

    # store the per-(row, split) partial results for the epilogue kernel
    offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_max_n = pid_n
    maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m
    tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))
    accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m
    tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))
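
# Contract between the two kernels: after the mainloop, _max[m, j] holds the maximum of row m
# over vocab split j, and _accu[m, j] holds sum(exp(logits[m, split j] - _max[m, j])), so
# logsumexp(row m) == log(sum_j(exp(_max[m, j]) * _accu[m, j])). The epilogue below performs
# this final reduction over splits sequentially, again with online rescaling.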
@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"])
@triton.jit
def efficient_entropy_triton_kernel_epilogue(
    max_ptr,
    stride_max_m: tl.int64,
    stride_max_n: tl.int64,
    num_tokens,
    num_splits,
    accu_ptr,
    stride_accu_m: tl.int64,
    stride_accu_n: tl.int64,
    global_logsumexp_ptr,
    stride_global_logsumexp: tl.int64,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid_m = tl.program_id(axis=0).to(tl.int64)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n
        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)
        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n
        # masked positions load _accu == 0.0 and therefore contribute nothing to the sum below
        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)
        _max_old = global_max
        _local_max = tl.max(_max, axis=1)
        global_max = tl.maximum(global_max, _local_max)
        _scale = tl.exp(_max - global_max[:, None])
        _coeff = tl.exp(_max_old - global_max)
        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)
    global_logsumexp_ptrs = global_logsumexp_ptr + offs_m * stride_global_logsumexp
    global_logsumexp = global_max + tl.log(global_accu)
    tl.store(global_logsumexp_ptrs, global_logsumexp, mask=offs_m < num_tokens)
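
# Note on the epilogue initializing global_max to 0 rather than -inf: the identity
#   logsumexp == G + log(sum_j(exp(m_j - G) * s_j)),   m_j = max of split j, s_j = sum(exp(x_j - m_j)),
# holds for any stabilizer G >= max_j(m_j); the epilogue effectively uses G = max(0, max_j(m_j)),
# so every exponent m_j - G stays <= 0 and cannot overflow.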
def logsumexp_triton(
    logits: torch.Tensor,
    dim = -1,
    vocab_per_split = 128
) -> torch.Tensor:
    num_tokens, vocab_size = logits.shape
    assert vocab_size % 128 == 0
    assert dim == -1
    device = logits.device
    assert vocab_per_split % 128 == 0
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    _logsumexp = torch.empty((num_tokens,), device=device, dtype=torch.float32)
    _max = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)

    # first level: one program per (row-block, split) pair computes partial (max, sum-of-exp)
    def mainloop_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,)
    efficient_entropy_kernel_general_mainloop[mainloop_grid](
        num_tokens,
        vocab_size,
        vocab_per_split,
        logits,
        logits.stride(0),
        logits.stride(1),
        _max,
        _max.stride(0),
        _max.stride(1),
        _accu,
        _accu.stride(0),
        _accu.stride(1),
    )

    # second level: sequential reduction over the per-split partial results
    def epilogue_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),)
    efficient_entropy_triton_kernel_epilogue[epilogue_grid](
        _max,
        _max.stride(0),
        _max.stride(1),
        num_tokens,
        num_splits,
        _accu,
        _accu.stride(0),
        _accu.stride(1),
        _logsumexp,
        _logsumexp.stride(0),
    )
    return _logsumexp
    #return (_logsumexp, _max, _accu)
def logsumexp_torch(logits2d, dim = -1, vocab_per_split = 128):
    num_tokens, vocab_size = logits2d.shape
    assert vocab_per_split % 128 == 0
    assert vocab_size % 128 == 0
    assert dim == -1
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    BLOCK_SIZE_M = num_tokens
    device = logits2d.device
    _max = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)

    # emulates the mainloop kernel: per split, online (max, sum-of-exp) over BLOCK_SIZE_N-wide chunks
    BLOCK_SIZE_N = 128
    for split_idx, split in enumerate(logits2d.split(vocab_per_split, dim = -1)):
        __max = torch.full((BLOCK_SIZE_M,), -float("inf"), dtype=torch.float32, device=device)
        __accu = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
        for logits in split.split(BLOCK_SIZE_N, dim = -1):
            _max_old = __max
            m_pid_n = torch.amax(logits, dim=1)
            __max = torch.maximum(_max_old, m_pid_n)
            exp_logits = torch.exp(logits - __max[:, None])
            coeff = torch.exp(_max_old - __max)
            __accu = coeff * __accu + torch.sum(exp_logits, dim=1)
        _max[:, split_idx] = __max
        _accu[:, split_idx] = __accu

    # emulates the epilogue kernel: online reduction over the per-split partial results
    BLOCK_SIZE_N = 64
    global_max = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
    global_accu = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
    for __max, __accu in zip(_max.split(BLOCK_SIZE_N, dim = -1), _accu.split(BLOCK_SIZE_N, dim = -1)):
        _max_old = global_max
        _local_max = torch.amax(__max, dim=1)
        global_max = torch.maximum(global_max, _local_max)
        _scale = torch.exp(__max - global_max[:, None])
        _coeff = torch.exp(_max_old - global_max)
        global_accu = _coeff * global_accu + torch.sum(_scale * __accu, dim=1)
    return torch.log(global_accu) + global_max
if __name__ == '__main__':
    logits = torch.load('sample_verl.pt').squeeze(0)
    #logits = torch.rand(1, 2, 129920, dtype = torch.float32, device = 'cuda').squeeze(0) # 129920 is divisible by 128, but not by 256 or 1024
    assert logits.ndim == 2

    a = torch.logsumexp(logits, dim = -1)
    b = logsumexp_torch(logits, dim = -1)
    print(a, b)
    print((a - b).abs())

    c = logsumexp_triton(logits, dim = -1)
    print(c, (a - c).abs())
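
    # An added sanity check (a suggestion, not part of the original gist): fp32 accumulation order
    # differs across the three implementations, so compare with a loose tolerance instead of exact
    # equality; check_dtype=False in case the loaded logits are not fp32.
    torch.testing.assert_close(b, a, rtol=1e-4, atol=1e-4, check_dtype=False)
    torch.testing.assert_close(c, a, rtol=1e-4, atol=1e-4, check_dtype=False)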