# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
# ]
# ///
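# The block above is PEP 723 inline script metadata, so a compatible runner
# can resolve the dependencies automatically. For example, assuming `uv` is
# installed and the file is saved under a hypothetical name like
# gelu_check.py:
#
#   uv run gelu_check.py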
import torch
import torch.nn.functional as F
from kernels import get_kernel

DEVICE = "cuda"

# Make reproducible
torch.manual_seed(42)

# Download optimized activation kernels from the Hub
activation_kernels = get_kernel("kernels-community/activation")
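# Note: get_kernel fetches a pre-built kernel from the Hugging Face Hub and
# loads it as a regular Python module, so no local compilation is needed
# (assuming a build matching your torch/CUDA setup is published in that repo).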
# Create a random tensor on the GPU
x = torch.randn((4, 4), dtype=torch.float16, device=DEVICE)

# Prepare an output tensor
y = torch.empty_like(x)

# Run the fast GELU kernel
activation_kernels.gelu_fast(y, x)
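# gelu_fast writes its result into the pre-allocated output tensor in place
# (destination-passing style), which is why y is created with empty_like
# above rather than returned by the call.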
# Get the expected output using PyTorch's built-in GELU
expected = F.gelu(x)

# Compare the kernel output with PyTorch's result
torch.testing.assert_close(y, expected, rtol=1e-2, atol=1e-2)
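# The loose tolerances account for float16 rounding and for the fact that
# "fast" GELU variants typically use the tanh approximation, while F.gelu
# defaults to the exact erf-based formulation.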
print("✅ Kernel output matches PyTorch GELU!") | |
# Optional: print both tensors for inspection | |
print("\nInput tensor:") | |
print(x) | |
print("\nFast GELU kernel output:") | |
print(y) | |
print("\nPyTorch GELU output:") | |
print(expected) | |
# List available functions in the loaded kernel module | |
print("\nAvailable functions in 'kernels-community/activation':") | |
print(dir(activation_kernels)) |