# /// script
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
# ]
# ///
import torch
# reuse the models from the previous snippets or copy the class
# definitions here to run this script independently
from rmsnorm_baseline import BaselineModel
from rmsnorm_kernel import KernelModel
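# The two imports above come from the earlier snippets. If they are not on
# your path, a minimal stand-in for BaselineModel might look like the sketch
# below (commented out). It assumes the earlier snippet stacked
# Linear -> RMSNorm -> Linear, which may not match your exact definitions.
#
# import torch.nn as nn
#
# class BaselineModel(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size, eps=1e-5):
#         super().__init__()
#         self.fc1 = nn.Linear(input_size, hidden_size)
#         self.norm = nn.RMSNorm(hidden_size, eps=eps)  # requires PyTorch >= 2.4
#         self.fc2 = nn.Linear(hidden_size, output_size)
#
#     def forward(self, x):
#         return self.fc2(self.norm(self.fc1(x)))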
DEVICE = "cuda"
DTYPE = torch.float16 # Use float16 for better kernel performance potential
# Use torch.cuda.Event for accurate GPU timing (ensure function is defined)
def benchmark_model(model, input_tensor, num_runs=100, warmup_runs=10):
    model.eval()  # Set model to evaluation mode
    dtype = input_tensor.dtype
    model = model.to(input_tensor.device).to(dtype)
    with torch.no_grad():  # Inference only; skip autograd bookkeeping
        # Warmup runs
        for _ in range(warmup_runs):
            _ = model(input_tensor)
        torch.cuda.synchronize()
        # Timed runs
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(num_runs):
            _ = model(input_tensor)
        end_event.record()
        torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_runs
    return avg_time_ms
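# benchmark_model works on any nn.Module whose forward takes one tensor.
# A quick standalone check (hypothetical toy module, not part of the
# original gist) would look like:
#
# toy = torch.nn.Linear(4096, 4096).to(DEVICE).to(DTYPE)
# toy_input = torch.randn(256, 4096, device=DEVICE, dtype=DTYPE)
# print(f"toy avg latency: {benchmark_model(toy, toy_input):.4f} ms")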
input_size_bench = 4096
hidden_size_bench = 4096 # RMSNorm performance is sensitive to this dimension
output_size_bench = 10
eps_val_bench = 1e-5
# Create larger models and input for the benchmark
# Ensure both models are fully converted to the target DEVICE and DTYPE
baseline_model_bench = (
    BaselineModel(
        input_size_bench, hidden_size_bench, output_size_bench, eps=eps_val_bench
    )
    .to(DEVICE)
    .to(DTYPE)
)
kernel_model_bench = (
    KernelModel(
        input_size_bench,
        hidden_size_bench,
        output_size_bench,
        device=DEVICE,
        dtype=DTYPE,
        eps=eps_val_bench,
    )
    .to(DEVICE)
    .to(DTYPE)
)
# Call both models once with a large batch to warm up the GPU
# and ensure any lazy kernel loading happens before timing
warmup_input = torch.randn(4096, input_size_bench, device=DEVICE, dtype=DTYPE)
_ = kernel_model_bench(warmup_input)
_ = baseline_model_bench(warmup_input)
batch_sizes = [
    256,
    512,
    1024,
    2048,
    4096,
    8192,
    16384,
    32768,
]
print(
    f"{'Batch Size':<12} | {'Baseline Time (ms)':<18} | {'Kernel Time (ms)':<18} | {'Speedup'}"
)
print("-" * 74)
for batch_size in batch_sizes:
    # Make sure all previous GPU work is complete before creating new inputs
    torch.cuda.synchronize()
    # Create a random input tensor on the correct device and dtype
    bench_input = torch.randn(batch_size, input_size_bench, device=DEVICE, dtype=DTYPE)
    baseline_time = round(benchmark_model(baseline_model_bench, bench_input), 4)
    kernel_time = round(benchmark_model(kernel_model_bench, bench_input), 4)
    # Report the speedup relative to the baseline
    if kernel_time <= 0:
        speedup = "N/A"
    elif kernel_time < baseline_time:
        speedup = f"{baseline_time / kernel_time:.2f}x"
    elif kernel_time == baseline_time:
        speedup = "1.00x (identical)"
    else:
        speedup = f"{kernel_time / baseline_time:.2f}x slower"
    print(f"{batch_size:<12} | {baseline_time:<18} | {kernel_time:<18} | {speedup}")