Skip to content

Instantly share code, notes, and snippets.

@knwng
Created July 2, 2025 22:54
Show Gist options
  • Save knwng/6aaf6d03bbae2709fdff61401b27223f to your computer and use it in GitHub Desktop.
Save knwng/6aaf6d03bbae2709fdff61401b27223f to your computer and use it in GitHub Desktop.
import torch
import triton
import triton.language as tl
# torch device targeted by Triton's currently active driver. All random operands
# in the reproducer below are allocated on it. This is evaluated at import time,
# so presumably importing this file fails without a usable Triton backend.
DEVICE = triton.runtime.driver.active.get_active_torch_device()
# This kernel is a trimmed version of triton/python/tutorials/03-matrix-multiplication.py
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, #
    stride_bk, stride_bn, #
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
    GROUP_SIZE_M: tl.constexpr, #
):
    """Compute one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B via tl.dot_scaled.

    A is addressed as (M, K) through (stride_am, stride_ak), B as (K, N)
    through (stride_bk, stride_bn), and C as (M, N) through
    (stride_cm, stride_cn). The launch grid is 1-D. The flat program id is
    mapped to a (pid_m, pid_n) output tile using the tutorial's grouped
    ordering, which walks GROUP_SIZE_M tile-rows at a time. (The tutorial
    motivates this ordering by L2 cache reuse.)

    A tiles are passed to tl.dot_scaled with format "e5m2" and B tiles with
    format "e4m3", both with a None scale. Partial products accumulate in a
    float32 accumulator, which is then stored through c_ptr.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped ordering: consecutive program ids cover GROUP_SIZE_M tile-rows
    # and all num_pid_n tile-columns before moving to the next group.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Rows/cols past the matrix edge wrap around (% M, % N) so loads stay in
    # bounds. The results for those lanes are discarded by c_mask at the store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_ak = tl.arange(0, BLOCK_SIZE_K)
    offs_bk = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail of the last step. Zero-filled lanes (other=0.0)
        # contribute nothing to the dot product.
        a = tl.load(a_ptrs, mask=offs_ak[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_bk[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # NOTE(review): None scales presumably mean "no block scaling"; confirm
        # against the tl.dot_scaled docs for the Triton version in use.
        accumulator = tl.dot_scaled(a, None, "e5m2", b, None, "e4m3", accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # NOTE(review): c is the float32 accumulator. tl.store presumably converts
    # it to c_ptr's element type; confirm that its rounding matches torch's
    # .to(torch.float8_e5m2) when results are compared bit-exactly.
    c = accumulator
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Only in-range rows/cols are written; this drops the wrapped-around lanes.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b, *, block_size_m=16, block_size_n=16, block_size_k=128, group_size_m=1):
    """Compute ``a @ b`` with ``matmul_kernel`` and return it as float8_e5m2.

    ``matmul_kernel`` reads ``a`` with the "e5m2" format and ``b`` with the
    "e4m3" format. The operands are therefore presumably float8_e5m2 and
    float8_e4m3fn tensors on the Triton device.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        block_size_m, block_size_n, block_size_k: tile sizes forwarded to the
            kernel as BLOCK_SIZE_M/N/K. The defaults are the values this
            reproducer was written with; a 64/64/128 configuration was also
            tried.
        group_size_m: GROUP_SIZE_M for the kernel's grouped tile ordering.

    Returns:
        A new (M, N) float8_e5m2 tensor on ``a.device``.

    Raises:
        ValueError: if either operand is not 2-D or the inner dimensions differ.
    """
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(f"matmul expects 2-D operands, got a.ndim={a.ndim}, b.ndim={b.ndim}")
    M, K = a.shape
    K_b, N = b.shape
    # Without this check the kernel would silently read B with the wrong K bound.
    if K != K_b:
        raise ValueError(
            f"incompatible dimensions: a is {tuple(a.shape)}, b is {tuple(b.shape)}"
        )
    c = torch.empty((M, N), device=a.device, dtype=torch.float8_e5m2)
    # 1-D launch: one program per output tile. The kernel itself maps the flat
    # program id to a (pid_m, pid_n) tile.
    grid = (triton.cdiv(M, block_size_m) * triton.cdiv(N, block_size_n),)
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        BLOCK_SIZE_M=block_size_m, BLOCK_SIZE_N=block_size_n, BLOCK_SIZE_K=block_size_k,  #
        GROUP_SIZE_M=group_size_m,  #
    )
    return c
def matmul_torch(a, b):
return torch.matmul(a, b).to(torch.float8_e5m2)
def alloc_rand(shape, device, dtype, requires_grad=True):
if dtype.itemsize == 1:
tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16))
return tmp.to(dtype).requires_grad_(requires_grad)
return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
# Reproducer: run matmul_kernel (tl.dot_scaled, e5m2 x e4m3) on random fp8
# operands and compare it exactly against torch.matmul. Both results are
# rounded to float8_e5m2 before the comparison.
torch.manual_seed(0)  # fixed seed so a failing case can be replayed exactly
FACTOR_B = 10
M = 256
N = 256
K = 256
MAX_REPORTED_MISMATCHES = 10

# a: powers of two in [2**-7, 2**-4], drawn directly as float8_e5m2.
a = alloc_rand((M, K), DEVICE, torch.float8_e5m2)
# b: standard-normal float16 scaled by FACTOR_B, then rounded to e4m3.
b = alloc_rand((K, N), DEVICE, torch.float16) * FACTOR_B
b = b.to(torch.float8_e4m3fn)

tri = matmul(a, b)
# fp8 -> fp32 upcasting is exact, so the reference sees the same operand values
# as the kernel. Only accumulation and the final e5m2 rounding can differ.
a_ref = a.clone().float()
b_ref = b.clone().float()
ref = matmul_torch(a_ref, b_ref)

tri = tri.to(torch.float32)
ref = ref.to(torch.float32)

# Report a bounded sample of exact mismatches before asserting. A failure then
# names the offending (row, col) positions, not just assert_close's summary.
mismatches = torch.nonzero(tri != ref).tolist()
for i, j in mismatches[:MAX_REPORTED_MISMATCHES]:
    print(f'({i}, {j}): {tri[i, j].item()} vs {ref[i, j].item()}')
if len(mismatches) > MAX_REPORTED_MISMATCHES:
    print(f'... and {len(mismatches) - MAX_REPORTED_MISMATCHES} more mismatches')
torch.testing.assert_close(tri, ref)
print('✅Pass')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment