@ita9naiwa
scaled dot bench
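# Microbenchmark for Triton's tl.dot_scaled: a block-scaled (microscaling-style)
# GEMM kernel is validated against a float32 reference and timed across
# tile-size configurations. The environment variables below keep torch.compile
# out of the way and force Triton to recompile the kernel for every run.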
import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
import argparse, torch, triton, triton.language as tl
from triton.tools.mxfp import MXFP4Tensor
import time
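# Reference implementation: scales are stored as biased uint8 exponents
# (e8m0-style, bias 127), so 2**(s - 127) recovers the multiplier; each scale
# applies to GROUP_K consecutive elements along K, hence the repeat_interleave.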
def scaleDot_ref(A, B, sA_grouped, sB_grouped, GROUP_K: int):
    sA = 2 ** (sA_grouped.float() - 127.0)
    sB = 2 ** (sB_grouped.float() - 127.0)
    sA_full = sA.repeat_interleave(GROUP_K, dim=1)
    sB_full = sB.repeat_interleave(GROUP_K, dim=1)
    tA = A.to(torch.float32) * sA_full
    tB = B.to(torch.float32) * sB_full.T
    return torch.matmul(tA, tB)
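# Triton kernel: grouped program-id ordering for better L2 reuse, then for each
# K-block it loads A/B tiles plus their per-group scale tiles and accumulates
# with tl.dot_scaled, which applies the scales to the operands.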
@triton.jit
def dot_kernel(
    A_ptr, B_ptr, C_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    sA_ptr, sB_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_meta_a_m: tl.constexpr,
    stride_meta_a_g: tl.constexpr,
    stride_meta_b_n: tl.constexpr,
    stride_meta_b_g: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    GROUP: tl.constexpr,
    atype: tl.constexpr,
    btype: tl.constexpr,
):
    # Grouped program-id mapping: tiles of GROUP rows are scheduled together.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid = tl.program_id(0)
    group_size = GROUP
    num_pid_in_group = group_size * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size
    group_size_m = tl.minimum(num_pid_m - first_pid_m, group_size)
    pid_in_group = pid % (group_size_m * num_pid_n)
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
    BLOCK_K_S: tl.constexpr = BLOCK_K // GROUP_K  # scale groups per K-block
    offs_k_scales = tl.arange(0, BLOCK_K_S)
    for k0 in tl.static_range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        a_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1)
        b_mask = ((offs_k[:, None] < K) & (offs_n[None, :] < N)).to(tl.int1)
        A = tl.load(A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak, mask=a_mask, other=0.)
        B = tl.load(B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn, mask=b_mask, other=0.)
        # Scale tiles for this K-block: one uint8 scale per GROUP_K elements along K.
        k_group_index_base = k0 // GROUP_K
        sA_blk = tl.load(
            sA_ptr + offs_m[:, None] * stride_meta_a_m + (k_group_index_base + offs_k_scales[None, :]) * stride_meta_a_g
        )
        sB_blk = tl.load(
            sB_ptr + offs_n[:, None] * stride_meta_b_n + (k_group_index_base + offs_k_scales[None, :]) * stride_meta_b_g
        )
        acc = tl.dot_scaled(A, sA_blk, atype, B, sB_blk, btype, acc)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc, mask=c_mask)
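# bench(): builds random low-precision operands and uint8 scales, launches the
# kernel once for a correctness check against scaleDot_ref, then times `iters`
# launches (including the C.zero_() between launches) with CUDA events and
# dumps the compiled PTX to dot_kernel.ptx.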
def bench(M, N, K, atype="e4m3", btype="e4m3", BLOCK_M=16, BLOCK_N=8, BLOCK_K=32, GROUP_K=32, GROUP=8, NUM_WARPS=4, warmup=10, iters=10000):
    # assert (M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0), \
    #     f"dims must be multiples of {BLOCK_M}×{BLOCK_N}×{BLOCK_K}"
    # assert K % GROUP_K == 0, "K must be divisible by GROUP_K"
    # assert BLOCK_K % GROUP_K == 0, "BLOCK_K must be divisible by GROUP_K"
    if atype == "e4m3":
        _atype = torch.float8_e4m3fn
    elif atype == "fp16":
        _atype = torch.float16
    elif atype == "bf16":
        _atype = torch.bfloat16
    elif atype == "e5m2":
        _atype = torch.float8_e5m2
    else:
        raise ValueError(f"unsupported atype: {atype}")
    if btype == "e4m3":
        _btype = torch.float8_e4m3fn
    elif btype == "e5m2":
        _btype = torch.float8_e5m2
    else:
        raise ValueError(f"unsupported btype: {btype}")
    torch.manual_seed(42)
    A_float = torch.normal(0, 1, (M, K), device="cuda") / 10
    B_float = torch.normal(0, 1, (K, N), device="cuda") / 10
    A = A_float.to(_atype)
    B = B_float.to(_btype)
    C = torch.zeros(M, N, device="cuda", dtype=torch.float32)
    KG = K // GROUP_K
    # Biased uint8 scales; 127 corresponds to a multiplier of 1.0.
    sB = torch.randint(120, 130, (N, KG), device='cuda', dtype=torch.uint8)
    if _atype in (torch.bfloat16, torch.float16):
        sA = torch.ones((M, KG), device='cuda', dtype=torch.uint8) * 127
    else:
        sA = torch.randint(120, 130, (M, KG), device='cuda', dtype=torch.uint8)
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)

    def launch():
        return dot_kernel[grid](
            A, B, C, M, N, K,
            sA, sB,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            sA.stride(0), sA.stride(1),
            sB.stride(0), sB.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_K=GROUP_K, GROUP=GROUP,
            atype=atype, btype=btype, num_warps=NUM_WARPS, num_stages=4)

    # Correctness check against the float32 reference.
    compiled = launch()
    C1 = scaleDot_ref(A, B, sA, sB, GROUP_K)
    if torch.allclose(C, C1, atol=1e-3, rtol=1e-3):
        correctness = "✓ Pass"
    else:
        abs_diff = torch.abs(C - C1)
        max_diff = torch.max(abs_diff)
        correctness = f"✗ Fail (max diff: {max_diff:.6f})"
    # Warmup, then time `iters` launches with CUDA events.
    for _ in range(warmup):
        C.zero_()
        compiled = launch()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        C.zero_()
        compiled = launch()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_ms / iters
    flops = 2 * M * N * K
    tflops = (flops / (avg_time_ms * 1e-3)) / 1e12
    print(f" {correctness} | Time: {avg_time_ms:.3f} ms | TFLOPS: {tflops:.2f}")
    with open("dot_kernel.ptx", "w") as f:
        f.write(compiled.asm["ptx"])
    return avg_time_ms, tflops
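# Example of a single-configuration run (block sizes here are illustrative, not tuned):
#   avg_ms, tflops = bench(512, 1024, 512, "e4m3", "e4m3",
#                          BLOCK_M=128, BLOCK_N=128, BLOCK_K=128, GROUP_K=32)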
#################################################################################################################
#m16n8k32
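# Fixed problem size; sweep tile sizes and report the fastest configurations.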
if __name__ == "__main__":
    M, K, N = 512, 512, 1024
    a_type, b_type = "e4m3", "e4m3"
    print(f"Matrix dimensions: M={M}, K={K}, N={N}")
    print(f"Data types: A={a_type}, B={b_type}")
    print("Warmup: 10 iterations, Benchmark: 10000 iterations")
    print("-" * 80)
    results = []
    for BLOCK_M in [64, 128]:
        for BLOCK_N in [64, 128, 256]:
            for BLOCK_K in [64, 128, 256]:
                for GROUP_K in [32]:
                    if (BLOCK_K < GROUP_K) or (BLOCK_M > M) or (BLOCK_K > K) or (BLOCK_N > N):
                        continue
                    if BLOCK_K % GROUP_K != 0:
                        continue
                    config_str = f"BLOCK_M={BLOCK_M:3d}, BLOCK_N={BLOCK_N:3d}, BLOCK_K={BLOCK_K:3d}, GROUP_K={GROUP_K}"
                    print(f"{config_str}", end=" ")
                    try:
                        time_ms, tflops = bench(M, N, K, a_type, b_type, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_K,
                                                warmup=10, iters=10000)
                        results.append({
                            'BLOCK_M': BLOCK_M,
                            'BLOCK_N': BLOCK_N,
                            'BLOCK_K': BLOCK_K,
                            'GROUP_K': GROUP_K,
                            'time_ms': time_ms,
                            'tflops': tflops
                        })
                    except Exception as e:
                        print(f" ✗ Error: {e}")
    if results:
        print("\n" + "=" * 80)
        print("Top 5 configurations by performance:")
        print("-" * 80)
        sorted_results = sorted(results, key=lambda x: x['tflops'], reverse=True)
        for i, r in enumerate(sorted_results[:5], 1):
            print(f"{i}. BLOCK_M={r['BLOCK_M']:3d}, BLOCK_N={r['BLOCK_N']:3d}, "
                  f"BLOCK_K={r['BLOCK_K']:3d}, GROUP_K={r['GROUP_K']} "
                  f"-> {r['time_ms']:.3f} ms, {r['tflops']:.2f} TFLOPS")