scaled dot bench
import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
import torch, triton, triton.language as tl
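# Float32 reference: decode the uint8 e8m0 group scales (biased exponents,
# 127 -> scale 1.0), broadcast each scale across its GROUP_K columns along K,
# and compute the dequantized matmul in float32.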
def scaleDot_ref(A, B, sA_grouped, sB_grouped, GROUP_K: int):
    sA = 2 ** (sA_grouped.float() - 127.0)
    sB = 2 ** (sB_grouped.float() - 127.0)
    sA_full = sA.repeat_interleave(GROUP_K, dim=1)
    sB_full = sB.repeat_interleave(GROUP_K, dim=1)
    tA = A.to(torch.float32) * sA_full
    tB = B.to(torch.float32) * sB_full.T
    return torch.matmul(tA, tB)
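# Block-scaled GEMM kernel: each program computes one BLOCK_M x BLOCK_N tile of C,
# walking K in BLOCK_K steps and accumulating with tl.dot_scaled, which applies the
# per-GROUP_K e8m0 scales to the A and B operands. Program IDs are swizzled in
# groups of GROUP along M for better L2 reuse.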
@triton.jit
def dot_kernel(
    A_ptr, B_ptr, C_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    sA_ptr, sB_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_meta_a_m: tl.constexpr,
    stride_meta_a_g: tl.constexpr,
    stride_meta_b_n: tl.constexpr,
    stride_meta_b_g: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    GROUP: tl.constexpr,
    atype: tl.constexpr,
    btype: tl.constexpr,
):
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid = tl.program_id(0)
    group_size = GROUP
    num_pid_in_group = group_size * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size
    group_size_m = tl.minimum(num_pid_m - first_pid_m, group_size)
    # Use the offset within the full group so a partial trailing group
    # (num_pid_m % GROUP != 0) still maps each program to a unique tile.
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
    BLOCK_K_S: tl.constexpr = BLOCK_K // GROUP_K
    offs_k_scales = tl.arange(0, BLOCK_K_S)
    for k0 in tl.static_range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        a_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1)
        b_mask = ((offs_k[:, None] < K) & (offs_n[None, :] < N)).to(tl.int1)
        A = tl.load(A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak, mask=a_mask, other=0.)
        B = tl.load(B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn, mask=b_mask, other=0.)
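        # Each K step covers BLOCK_K // GROUP_K scale groups. These loads are
        # unmasked, so M, N, and K // GROUP_K are assumed to be multiples of the
        # corresponding block sizes.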
        k_group_index_base = k0 // GROUP_K
        sA_blk = tl.load(
            sA_ptr + offs_m[:, None] * stride_meta_a_m + (k_group_index_base + offs_k_scales[None, :]) * stride_meta_a_g
        )
        sB_blk = tl.load(
            sB_ptr + offs_n[:, None] * stride_meta_b_n + (k_group_index_base + offs_k_scales[None, :]) * stride_meta_b_g
        )
        acc = tl.dot_scaled(A, sA_blk, atype, B, sB_blk, btype, acc)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc, mask=c_mask)
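# Benchmark one (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_K) configuration: build fp8 (or
# fp16/bf16) operands with per-group scales, check the kernel against the float32
# reference, then time repeated launches with CUDA events and report TFLOPS.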
def bench(M, N, K, atype="e4m3", btype="e4m3", BLOCK_M=16, BLOCK_N=8, BLOCK_K=32, GROUP_K=32, GROUP=8, NUM_WARPS=4, warmup=10, iters=10000):
    # assert (M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0), \
    #     f"dims must be multiples of {BLOCK_M}×{BLOCK_N}×{BLOCK_K}"
    # assert K % GROUP_K == 0, "K must be divisible by GROUP_K"
    # assert BLOCK_K % GROUP_K == 0, "BLOCK_K must be divisible by GROUP_K"
    if atype == "e4m3":
        _atype = torch.float8_e4m3fn
    elif atype == "fp16":
        _atype = torch.float16
    elif atype == "bf16":
        _atype = torch.bfloat16
    elif atype == "e5m2":
        _atype = torch.float8_e5m2
    else:
        raise ValueError(f"unsupported atype: {atype}")
    if btype == "e4m3":
        _btype = torch.float8_e4m3fn
    elif btype == "e5m2":
        _btype = torch.float8_e5m2
    else:
        raise ValueError(f"unsupported btype: {btype}")
    torch.manual_seed(42)
    A_float = torch.normal(0, 1, (M, K), device="cuda") / 10
    B_float = torch.normal(0, 1, (K, N), device="cuda") / 10
    A = A_float.to(_atype)
    B = B_float.to(_btype)
    C = torch.zeros(M, N, device="cuda", dtype=torch.float32)
    KG = K // GROUP_K
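    # Scales are stored as uint8 e8m0 biased exponents (127 corresponds to a scale
    # of 1.0); values in [120, 130) give scale factors from 2**-7 to 2**2. For
    # fp16/bf16 A, the scales are pinned to exactly 1.0.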
    sB = torch.randint(120, 130, (N, KG), device='cuda', dtype=torch.uint8)
    if _atype in (torch.bfloat16, torch.float16):
        sA = torch.ones((M, KG), device='cuda', dtype=torch.uint8) * 127
    else:
        sA = torch.randint(120, 130, (M, KG), device='cuda', dtype=torch.uint8)
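    # One-dimensional launch grid; the kernel recovers (pid_m, pid_n) itself via
    # the grouped swizzle above.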
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
    compiled = dot_kernel[grid](
        A, B, C, M, N, K,
        sA, sB,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        sA.stride(0), sA.stride(1),
        sB.stride(0), sB.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_K=GROUP_K, GROUP=GROUP,
        atype=atype, btype=btype, num_warps=NUM_WARPS, num_stages=4)
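    # Compare the output of the first launch against the float32 reference before timing.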
    C1 = scaleDot_ref(A, B, sA, sB, GROUP_K)
    if torch.allclose(C, C1, atol=1e-3, rtol=1e-3):
        correctness = "✓ Pass"
    else:
        abs_diff = torch.abs(C - C1)
        max_diff = torch.max(abs_diff)
        correctness = f"✗ Fail (max diff: {max_diff:.6f})"
    for _ in range(warmup):
        C.zero_()
        compiled = dot_kernel[grid](
            A, B, C, M, N, K,
            sA, sB,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            sA.stride(0), sA.stride(1),
            sB.stride(0), sB.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_K=GROUP_K, GROUP=GROUP,
            atype=atype, btype=btype, num_warps=NUM_WARPS, num_stages=4)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
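    # Timed loop: each iteration also launches C.zero_(), so the measured average
    # includes that kernel in addition to dot_kernel.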
    start_event.record()
    for _ in range(iters):
        C.zero_()
        compiled = dot_kernel[grid](
            A, B, C, M, N, K,
            sA, sB,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            sA.stride(0), sA.stride(1),
            sB.stride(0), sB.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_K=GROUP_K, GROUP=GROUP,
            atype=atype, btype=btype, num_warps=NUM_WARPS, num_stages=4)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_ms / iters
    flops = 2 * M * N * K
    tflops = (flops / (avg_time_ms * 1e-3)) / 1e12
    print(f" {correctness} | Time: {avg_time_ms:.3f} ms | TFLOPS: {tflops:.2f}")
    with open("dot_kernel.ptx", "w") as f:
        f.write(compiled.asm["ptx"])
    return avg_time_ms, tflops
#################################################################################################################
# m16n8k32
if __name__ == "__main__":
    M, K, N = 512, 512, 1024
    a_type, b_type = "e4m3", "e4m3"
    print(f"Matrix dimensions: M={M}, K={K}, N={N}")
    print(f"Data types: A={a_type}, B={b_type}")
    print("Warmup: 10 iterations, Benchmark: 10000 iterations")
    print("-" * 80)
    results = []
    for BLOCK_M in [64, 128]:
        for BLOCK_N in [64, 128, 256]:
            for BLOCK_K in [64, 128, 256]:
                for GROUP_K in [32]:
                    if (BLOCK_K < GROUP_K) or (BLOCK_M > M) or (BLOCK_K > K) or (BLOCK_N > N):
                        continue
                    if BLOCK_K % GROUP_K != 0:
                        continue
                    config_str = f"BLOCK_M={BLOCK_M:3d}, BLOCK_N={BLOCK_N:3d}, BLOCK_K={BLOCK_K:3d}, GROUP_K={GROUP_K}"
                    print(f"{config_str}", end=" ")
                    try:
                        time_ms, tflops = bench(M, N, K, a_type, b_type, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_K,
                                                warmup=10, iters=10000)
                        results.append({
                            'BLOCK_M': BLOCK_M,
                            'BLOCK_N': BLOCK_N,
                            'BLOCK_K': BLOCK_K,
                            'GROUP_K': GROUP_K,
                            'time_ms': time_ms,
                            'tflops': tflops
                        })
                    except Exception as e:
                        print(f" ✗ Error: {e}")
    if results:
        print("\n" + "="*80)
        print("Top 5 configurations by performance:")
        print("-" * 80)
        sorted_results = sorted(results, key=lambda x: x['tflops'], reverse=True)
        for i, r in enumerate(sorted_results[:5], 1):
            print(f"{i}. BLOCK_M={r['BLOCK_M']:3d}, BLOCK_N={r['BLOCK_N']:3d}, "
                  f"BLOCK_K={r['BLOCK_K']:3d}, GROUP_K={r['GROUP_K']} "
                  f"-> {r['time_ms']:.3f} ms, {r['tflops']:.2f} TFLOPS")