Created
August 20, 2025 13:43
-
-
Save 3outeille/66410f907901a5126c78a50d48f2f5ab to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import triton | |
| import triton.language as tl | |
def assert_is_matrix(x):
    """Raise ``ValueError`` unless ``x`` is two-dimensional.

    Only ``x.ndim`` is inspected, so any object exposing that attribute
    is accepted.

    Raises:
        ValueError: if ``x.ndim`` is not 2.
    """
    if x.ndim != 2:
        raise ValueError(f'Expected 2-tensor but got {x.ndim}-tensor')
def assert_is_vector(x):
    """Raise ``ValueError`` unless ``x`` is one-dimensional.

    Only ``x.ndim`` is inspected, so any object exposing that attribute
    is accepted.

    Raises:
        ValueError: if ``x.ndim`` is not 1.
    """
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
def assert_equal(a, b):
    """Raise ``ValueError`` if ``a`` and ``b`` compare unequal.

    Used to check that two dimension sizes agree; the comparison is a
    plain ``!=`` so any comparable values work.

    Raises:
        ValueError: if ``a != b``.
    """
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.')
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Copy one row between buffer ``a`` and the binned buffer ``b``.

    Each program handles one (expert, entry) pair: program axis 0 is the
    expert, program axis 1 is the entry (slot) inside that expert's bin.
    Rows are ``NUM_COLUMNS`` elements wide and are copied ``BLOCK_X``
    elements at a time.

    Args:
        a: Pointer to the un-binned buffer, addressed by rows taken
            from ``indices``.
        b: Pointer to the binned buffer, addressed by
            ``expert_idx * expert_capacity + entry_idx``.
        num_experts: Not read by the kernel body.
        expert_capacity: Number of row slots reserved per expert in ``b``.
        indices: Pointer to the row indices, grouped by expert.
        weights: Pointer to per-index scales; only read when ``SCALE``.
        bins: Pointer to the end offset of each expert's range in
            ``indices`` (the start is the previous expert's end, or 0).
        NUM_COLUMNS: Row width in elements.
        TOP_K: Divisor applied to the index when ``A_TO_B`` so the same
            row of ``a`` can be copied repeatedly.
        BLOCK_X: Elements copied per loop iteration (autotuned).
        A_TO_B: Copy direction; ``True`` reads ``a`` and writes ``b``.
        SCALE: Whether to multiply each row by ``weights[index]``.
    """
    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output. Widen to 64-bit first:
    # program ids and small integer arguments are 32-bit, and the row
    # offset below (index * NUM_COLUMNS) wraps around once the buffer
    # holds more than 2^31 elements.
    index_b = expert_idx.to(tl.int64) * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    # Same 64-bit widening as for index_b, before scaling by the row width.
    row_a = offset.to(tl.int64)
    a += tl.multiple_of(row_a * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    #
    # NOTE: We need to zero the output in both directions.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    # Walk the row in BLOCK_X-wide chunks; the mask trims the last chunk
    # when NUM_COLUMNS is not a multiple of BLOCK_X.
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        # Scale in float32, then cast back to the output element type.
        x = x.to(tl.float32) * scale.to(tl.float32)
        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
        offsets += BLOCK_X
def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    """Gather rows of ``x`` into per-expert bins of fixed capacity.

    Launches ``_binned_copy`` on a ``(num_experts, expert_capacity)`` grid
    with ``A_TO_B=True``, so each program copies at most one row of ``x``
    into the output.

    Args:
        x: 2-D tensor whose rows are gathered.
        indices: 1-D tensor with ``x.shape[0] * top_k`` entries.
        weights: Optional tensor with ``x.shape[0] * top_k`` entries in
            its first dimension; when given, rows are scaled by it.
        bins: Non-empty 1-D tensor with one entry per expert. The kernel
            reads ``bins[e - 1]`` and ``bins[e]`` as the bounds of expert
            ``e``'s range in ``indices``.
        expert_capacity: Number of row slots per expert; must be > 0.
        top_k: Number of ``indices`` entries per row of ``x``.

    Returns:
        A zero-initialised tensor of shape
        ``(num_experts, expert_capacity, x.shape[1])`` with the dtype and
        device of ``x``; slots the kernel does not write stay zero.

    Raises:
        ValueError: if a shape check fails, ``bins`` is empty, or
            ``expert_capacity`` is not positive.
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    # Raise explicitly rather than `assert`, which is stripped under -O.
    if bins.shape[0] <= 0:
        raise ValueError("bins must not be empty")
    assert_equal(indices.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # Reject an empty launch grid before allocating the output buffer.
    if expert_capacity <= 0:
        raise ValueError("expert_capacity must be > 0 to launch triton kernel")

    num_experts = bins.shape[0]
    out = torch.zeros(
        (num_experts, expert_capacity, x.shape[1]),
        dtype=x.dtype,
        device=x.device,
    )

    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
if __name__ == "__main__":
    # Every tensor below is created on the CUDA device, so fail with a
    # clear message instead of an opaque allocation error.
    if not torch.cuda.is_available():
        raise SystemExit("This script requires a CUDA device.")

    # Problem sizes. binned_gather launches a (num_experts, expert_capacity)
    # grid, so the per-axis limits noted below apply; see
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
    num_tokens = 15038
    hidden_size = 2880
    num_experts = 1  # <= 2^31-1
    expert_capacity = 65535  # <= 65535
    top_k = 1

    x = torch.randn(
        (num_tokens, hidden_size), dtype=torch.bfloat16, device='cuda')

    # For this test, we can use a simple range for indices.
    indices = torch.arange(num_tokens, dtype=torch.int32, device='cuda')

    # Draw random per-expert token counts, rescale them so they sum to
    # roughly num_tokens, then give the rounding remainder to expert 0 so
    # the total is exact.
    counts = torch.randint(500, 1500, (num_experts,), device='cuda')
    counts = (counts.float() / counts.sum() * num_tokens).to(torch.int32)
    remainder = num_tokens - counts.sum()
    counts[0] += remainder

    # Cumulative counts give the end offset of each expert's range.
    bins = torch.cumsum(counts, dim=0, dtype=torch.int32)

    print("Created tensors with shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    print(f"indices: {indices.shape}, dtype: {indices.dtype}")
    print(f"bins: {bins.shape}, dtype: {bins.dtype}")
    print(f"expert_capacity: {expert_capacity}")

    # Run the function
    output = binned_gather(x, indices, None, bins, expert_capacity, top_k)
    print("\nbinned_gather executed successfully.")
    print(f"Output shape: {output.shape}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment