Example of two-level aggregation for LogSumExp in Triton-lang (only forward pass), created for investigation of https://github.com/volcengine/verl/issues/2899
# Extracted and simplified the two-level aggregation approach (first, parallel aggregation in blocks, then a final sequential aggregation) from https://github.com/volcengine/verl/blob/main/verl/utils/kernel/kernels.py
# Examples of single-level, sequential, online aggregation approaches:
# - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/cross_entropy.py
# - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py
# logsumexp_torch below is eager PyTorch pseudo-code which emulates what the Triton kernels do, except that BLOCK_SIZE_M equals num_tokens
# tl.program_id(axis=0).to(tl.int64) is used following https://arxiv.org/abs/2410.10989 and https://github.com/linkedin/Liger-Kernel/blob/05b43a14913ced3776aa3fc50020089b8c0d63c1/src/liger_kernel/ops/cross_entropy.py#L77-L79
# sample_verl.pt is derived from the inputs (logits = torch.matmul(hidden, weights)) uploaded by @WindowsXP-Beta in https://github.com/volcengine/verl/issues/2656#issuecomment-3131136498
# created for investigation of https://github.com/volcengine/verl/issues/2899
import torch
import triton
import triton.language as tl
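
# Not part of the original kernels: a minimal eager sketch of the streaming-logsumexp merge rule that both kernels
# below rely on. A partial result over a chunk of logits x is the pair (m, s) with m = max(x) and s = sum(exp(x - m));
# two partials merge by rescaling to the larger maximum, and the final answer is lse = m + log(s).
def merge_lse_partials(m1, s1, m2, s2):
    # illustrative helper only, not used by the kernels below
    m_new = torch.maximum(m1, m2)
    s_new = torch.exp(m1 - m_new) * s1 + torch.exp(m2 - m_new) * s2
    return m_new, s_new
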
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8)],
    key=["num_tokens", "vocab_size"],
)
@triton.jit
def efficient_entropy_kernel_general_mainloop(
    num_tokens,
    vocab_size,
    vocab_per_split,
    logits_ptr,
    stride_logits_m: tl.int64,
    stride_logits_n: tl.int64,
    max_ptr,
    stride_max_m: tl.int64,
    stride_max_n: tl.int64,
    accu_ptr,
    stride_accu_m: tl.int64,
    stride_accu_n: tl.int64,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # each program handles one (token block, vocab split) pair
    pid = tl.program_id(axis=0).to(tl.int64)
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    # running maximum and running sum of exp(logits - running maximum), per token row
    _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32)
    _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)

    for n in range(0, num_pid_n):
        offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        logits_ptrs = logits_ptr + offs_am[:, None] * stride_logits_m + offs_bn[None, :] * stride_logits_n
        # guard the token dimension: num_tokens need not be a multiple of BLOCK_SIZE_M
        # (with the default vocab_per_split=128 and vocab_size % 128 == 0, the vocab dimension stays in range)
        logits = tl.load(logits_ptrs, mask=offs_am[:, None] < num_tokens, other=-float("inf"))

        # online (streaming) update: rescale the accumulator whenever the running maximum grows
        _max_old = _max
        m_pid_n = tl.max(logits, axis=1)
        _max = tl.maximum(_max_old, m_pid_n)
        exp_logits = tl.exp(logits - _max[:, None])
        coeff = tl.exp(_max_old - _max)
        _accu = coeff * _accu + tl.sum(exp_logits, axis=1)

    # store one (max, accu) partial per (token, split)
    offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_max_n = pid_n
    maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m
    tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))
    accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m
    tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits))
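
# The mainloop above leaves, for every token row, one (running max, rescaled exp-sum) pair per vocab split in
# _max / _accu. The epilogue below walks over those num_splits partials for a block of token rows and merges them
# with the same rescaling trick to produce the final logsumexp per token.
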
@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"])
@triton.jit
def efficient_entropy_triton_kernel_epilogue(
    max_ptr,
    stride_max_m: tl.int64,
    stride_max_n: tl.int64,
    num_tokens,
    num_splits,
    accu_ptr,
    stride_accu_m: tl.int64,
    stride_accu_n: tl.int64,
    global_logsumexp_ptr,
    stride_global_logsumexp: tl.int64,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # each program reduces all splits for one block of token rows
    pid_m = tl.program_id(axis=0).to(tl.int64)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)

    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n
        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)
        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n
        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)

        # merge a block of per-split partials into the running (global_max, global_accu) pair
        _max_old = global_max
        _local_max = tl.max(_max, axis=1)
        global_max = tl.maximum(global_max, _local_max)
        _scale = tl.exp(_max - global_max[:, None])
        _coeff = tl.exp(_max_old - global_max)
        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)

    global_logsumexp_ptrs = global_logsumexp_ptr + offs_m * stride_global_logsumexp
    global_logsumexp = global_max + tl.log(global_accu)
    tl.store(global_logsumexp_ptrs, global_logsumexp, mask=offs_m < num_tokens)
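
# For reference (an equivalence stated here, not taken from the verl kernels): since every per-split partial satisfies
# lse_split = _max + log(_accu), the epilogue computes the same value as the eager one-liner
#   lse = torch.logsumexp(_max + torch.log(_accu), dim=-1)
# just block-wise and without materializing the intermediate logarithms.
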
def logsumexp_triton(
    logits: torch.Tensor,
    dim = -1,
    vocab_per_split = 128,
) -> torch.Tensor:
    num_tokens, vocab_size = logits.shape
    assert vocab_size % 128 == 0
    assert dim == -1
    device = logits.device
    assert vocab_per_split % 128 == 0

    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    _logsumexp = torch.empty((num_tokens,), device=device, dtype=torch.float32)
    _max = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)

    # first level: one program per (token block, vocab split), writing per-split partials into _max / _accu
    def mainloop_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,)

    efficient_entropy_kernel_general_mainloop[mainloop_grid](
        num_tokens,
        vocab_size,
        vocab_per_split,
        logits,
        logits.stride(0),
        logits.stride(1),
        _max,
        _max.stride(0),
        _max.stride(1),
        _accu,
        _accu.stride(0),
        _accu.stride(1),
    )

    # second level: reduction over the per-split (max, accu) partials
    def epilogue_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),)

    efficient_entropy_triton_kernel_epilogue[epilogue_grid](
        _max,
        _max.stride(0),
        _max.stride(1),
        num_tokens,
        num_splits,
        _accu,
        _accu.stride(0),
        _accu.stride(1),
        _logsumexp,
        _logsumexp.stride(0),
    )
    return _logsumexp
    # return (_logsumexp, _max, _accu)
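
# Minimal usage sketch (hypothetical shapes, chosen only to satisfy the divisibility asserts above; needs a CUDA device):
#   logits = torch.randn(4096, 129024, device='cuda', dtype=torch.float32)
#   lse = logsumexp_triton(logits)  # float32 tensor of shape (4096,)
#   print((lse - torch.logsumexp(logits, dim=-1)).abs().max())
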
def logsumexp_torch(logits2d, dim = -1, vocab_per_split = 128):
    num_tokens, vocab_size = logits2d.shape
    assert vocab_per_split % 128 == 0
    assert vocab_size % 128 == 0
    assert dim == -1
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    BLOCK_SIZE_M = num_tokens
    device = logits2d.device

    _max = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=device, dtype=torch.float32)

    # first level (mirrors the mainloop kernel): online aggregation within every split of vocab_per_split logits
    BLOCK_SIZE_N = 128
    for split_idx, split in enumerate(logits2d.split(vocab_per_split, dim = -1)):
        __max = torch.full((BLOCK_SIZE_M,), -float("inf"), dtype=torch.float32, device=device)
        __accu = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
        for logits in split.split(BLOCK_SIZE_N, dim = -1):
            _max_old = __max
            m_pid_n = torch.amax(logits, dim=1)
            __max = torch.maximum(_max_old, m_pid_n)
            exp_logits = torch.exp(logits - __max[:, None])
            coeff = torch.exp(_max_old - __max)
            __accu = coeff * __accu + torch.sum(exp_logits, dim=1)
        _max[:, split_idx] = __max
        _accu[:, split_idx] = __accu

    # second level (mirrors the epilogue kernel): sequential merge of the per-split partials
    BLOCK_SIZE_N = 64
    global_max = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
    global_accu = torch.zeros((BLOCK_SIZE_M,), dtype=torch.float32, device=device)
    for __max, __accu in zip(_max.split(BLOCK_SIZE_N, dim = -1), _accu.split(BLOCK_SIZE_N, dim = -1)):
        _max_old = global_max
        _local_max = torch.amax(__max, dim=1)
        global_max = torch.maximum(global_max, _local_max)
        _scale = torch.exp(__max - global_max[:, None])
        _coeff = torch.exp(_max_old - global_max)
        global_accu = _coeff * global_accu + torch.sum(_scale * __accu, dim=1)
    return torch.log(global_accu) + global_max
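
# Sanity-check sketch for the eager emulation alone (hypothetical input; runs on CPU, unlike the Triton path):
#   x = torch.randn(7, 1024, dtype=torch.float32)
#   print((logsumexp_torch(x) - torch.logsumexp(x, dim=-1)).abs().max())
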
if __name__ == '__main__':
    logits = torch.load('sample_verl.pt').squeeze(0)
    # logits = torch.rand(1, 2, 129920, dtype = torch.float32, device = 'cuda').squeeze(0)  # vocab size divisible by 128, but not by 256 or 1024
    assert logits.ndim == 2

    a = torch.logsumexp(logits, dim = -1)
    b = logsumexp_torch(logits, dim = -1)
    print(a, b)
    print((a - b).abs())

    c = logsumexp_triton(logits, dim = -1)
    print(c, (a - c).abs())