Created
May 8, 2025 14:58
-
-
Save drbh/c754a4ba52bcc46190ae4a45516fb190 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
# ]
# ///
import time

import torch

# reuse the models from the previous snippets or copy the class
# definitions here to run this script independently
from rmsnorm_baseline import BaselineModel
from rmsnorm_kernel import KernelModel
# Precision shared by every model and input in this script; half precision
# gives the custom kernel the most room to show a performance advantage.
DTYPE = torch.float16
# Every model and tensor below is placed on this device.
DEVICE = "cuda"
# Average forward-pass latency; CUDA inputs are timed with torch.cuda.Event.
def benchmark_model(model, input_tensor, num_runs=100, warmup_runs=10):
    """Return the mean forward-pass latency of *model* on *input_tensor* in ms.

    The model is put in eval mode and converted to the device and dtype of
    ``input_tensor``. ``warmup_runs`` untimed forward passes are executed
    first, then ``num_runs`` timed passes whose total duration is averaged.
    All passes run under ``torch.no_grad()`` so autograd bookkeeping is not
    part of the measurement and does not penalise one implementation more
    than another.

    When ``input_tensor`` is on a CUDA device, the timed region is bracketed
    by ``torch.cuda.Event`` records and the device is synchronized before
    reading the elapsed time. For any other device, a host clock
    (``time.perf_counter``) is used instead.

    Args:
        model: Module to benchmark. ``eval()`` and ``.to()`` are applied to it.
        input_tensor: Input passed unchanged to every forward call.
        num_runs: Number of timed forward passes; must be at least 1.
        warmup_runs: Number of untimed forward passes executed first.

    Returns:
        Average time per timed forward pass, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        # The average below would otherwise divide by zero.
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    model.eval()
    model = model.to(input_tensor.device).to(input_tensor.dtype)
    use_cuda_timing = input_tensor.device.type == "cuda"

    with torch.no_grad():
        # Untimed passes so one-time setup costs are excluded from the timing.
        for _ in range(warmup_runs):
            model(input_tensor)

        if use_cuda_timing:
            # Drain warm-up work so it is not attributed to the timed region.
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_runs):
                model(input_tensor)
            end_event.record()
            # Both events must have completed before elapsed_time() is read.
            torch.cuda.synchronize()
            elapsed_ms = start_event.elapsed_time(end_event)
        else:
            start = time.perf_counter()
            for _ in range(num_runs):
                model(input_tensor)
            elapsed_ms = (time.perf_counter() - start) * 1e3

    return elapsed_ms / num_runs
# Problem dimensions for the benchmark models.
# RMSNorm throughput is particularly sensitive to hidden_size_bench.
input_size_bench, hidden_size_bench, output_size_bench = 4096, 4096, 10
# Epsilon passed as ``eps`` to both benchmark models.
eps_val_bench = 1e-5
# Build the two benchmark models with identical dimensions and epsilon, and
# make sure each one ends up entirely on DEVICE in DTYPE before any timing.
_bench_dims = (input_size_bench, hidden_size_bench, output_size_bench)

baseline_model_bench = BaselineModel(*_bench_dims, eps=eps_val_bench)
baseline_model_bench = baseline_model_bench.to(DEVICE).to(DTYPE)

# The kernel variant is additionally given its device/dtype at construction.
kernel_model_bench = KernelModel(
    *_bench_dims, device=DEVICE, dtype=DTYPE, eps=eps_val_bench
)
kernel_model_bench = kernel_model_bench.to(DEVICE).to(DTYPE)
# Push one large batch through each model (kernel first, then baseline) so
# GPU start-up and model loading costs are paid before benchmarking begins.
warmup_input = torch.randn(4096, input_size_bench, device=DEVICE, dtype=DTYPE)
for _warm_model in (kernel_model_bench, baseline_model_bench):
    _ = _warm_model(warmup_input)
# Batch sizes to benchmark: powers of two from 2**8 (256) to 2**15 (32768).
batch_sizes = [1 << exponent for exponent in range(8, 16)]

# Results-table header followed by a divider line.
print(
    f"{'Batch Size':<12} | {'Baseline Time (ms)':<18} | {'Kernel Time (ms)':<18} | {'Speedup'}"
)
print("-" * 74)
def _format_speedup(baseline_ms, kernel_ms):
    """Describe how *kernel_ms* compares with *baseline_ms* as a short string.

    The ratio is computed from the unrounded timings so short runs do not
    lose precision. Timings that agree to 4 decimal places are reported as
    identical, matching the precision shown in the table. A non-positive
    timing yields "N/A" instead of a division by zero or a float format
    applied to a string.
    """
    if baseline_ms <= 0 or kernel_ms <= 0:
        return "N/A"
    if round(kernel_ms, 4) == round(baseline_ms, 4):
        return "1.00x (identical)"
    if kernel_ms < baseline_ms:
        return f"{baseline_ms / kernel_ms:.2f}x"
    return f"{kernel_ms / baseline_ms:.2f}x slower"


for batch_size in batch_sizes:
    # Ensure no earlier GPU work overlaps with this iteration's measurements.
    torch.cuda.synchronize()
    # Fresh random input on the benchmark device and dtype.
    bench_input = torch.randn(batch_size, input_size_bench, device=DEVICE, dtype=DTYPE)
    baseline_time = benchmark_model(baseline_model_bench, bench_input)
    kernel_time = benchmark_model(kernel_model_bench, bench_input)
    # Compare raw timings; round only for display.
    speedup = _format_speedup(baseline_time, kernel_time)
    print(
        f"{batch_size:<12} | {round(baseline_time, 4):<18} | "
        f"{round(kernel_time, 4):<18} | {speedup}"
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment