@3outeille
Created August 20, 2025 13:43
import torch
import triton
import triton.language as tl


def assert_is_matrix(x):
    if x.ndim != 2:
        raise ValueError(f'Expected 2-tensor but got {x.ndim}-tensor')


def assert_is_vector(x):
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.')


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Each program instance handles one (expert, entry) slot of the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_b = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input, exit early.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    #
    # NOTE: We need to zero the output in both directions.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)
        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
        offsets += BLOCK_X
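
# Worked example of the indexing above (illustrative numbers, not from the
# original code): with bins = [3, 5], expert_capacity = 4, and TOP_K = 1,
# expert 0 owns indices[0:3] and fills output rows 0..2 of its capacity
# block (row 3 stays zero-padding), while expert 1 owns indices[3:5] and
# fills rows 4..5 (rows 6..7 stay padding), since
# index_b = expert_idx * expert_capacity + entry_idx.
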
def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert bins.shape[0] > 0, "bins must not be empty"
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

    # A zero-sized grid dimension is a launch error, so check up front.
    assert expert_capacity > 0, "expert_capacity must be > 0 to launch triton kernel"

    # One program per (expert, capacity slot).
    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
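
# A plain-PyTorch reference for the gather path (a sketch, not part of the
# original gist): it mirrors _binned_copy with A_TO_B=True, where capacity
# slot i of expert e reads row indices[bins[e-1] + i] // top_k of x,
# optionally scaled by weights[indices[bins[e-1] + i]].
def binned_gather_reference(x, indices, weights, bins, expert_capacity, top_k):
    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        # The kernel launches expert_capacity entries per expert and exits
        # early past the bin size, so overflow tokens are silently dropped.
        n = min(end - start, expert_capacity)
        idx = indices[start:start + n].long()
        chunk = x[idx // top_k]
        if weights is not None:
            # Match the kernel: scale in fp32, then cast back to x's dtype.
            chunk = (chunk.float() * weights[idx].float()[:, None]).to(x.dtype)
        out[e, :n] = chunk
        start = end
    return out
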
if __name__ == "__main__":
    # Create tensors whose shapes exercise the CUDA grid-size limits:
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
    num_tokens = 15038
    hidden_size = 2880
    num_experts = 1          # grid dim 0 (x): at most 2^31 - 1
    expert_capacity = 65535  # grid dim 1 (y): at most 65535
    top_k = 1

    x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16, device='cuda')

    # For this test, we can use a simple range for indices.
    indices = torch.arange(num_tokens, dtype=torch.int32, device='cuda')

    # Build per-expert token counts that sum to num_tokens; with a single
    # expert, every token falls into bin 0.
    counts = torch.randint(500, 1500, (num_experts,), device='cuda')
    counts = (counts.float() / counts.sum() * num_tokens).to(torch.int32)
    diff = num_tokens - counts.sum()
    counts[0] += diff
    bins = torch.cumsum(counts, dim=0, dtype=torch.int32)

    print("Created tensors with shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    print(f"indices: {indices.shape}, dtype: {indices.dtype}")
    print(f"bins: {bins.shape}, dtype: {bins.dtype}")
    print(f"expert_capacity: {expert_capacity}")

    # Run the function.
    output = binned_gather(x, indices, None, bins, expert_capacity, top_k)
    print("\nbinned_gather executed successfully.")
    print(f"Output shape: {output.shape}")