import os
from typing import Tuple

import numpy as np
import pytest
import torch

import deep_gemm


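# NOTE: the suite below assumes an importable deep_gemm build, a Hopper-class
# GPU (sm_90a), and CUDA 12.3+; the fixtures and version checks inside the
# classes skip tests when these requirements are not met.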
class TestDeepGEMM:
    """
    Comprehensive test suite for DeepGEMM FP8 GEMM operations.

    Tests environment-dependent behavior and numerical consistency.
    """

    @pytest.fixture(scope="class")
    def setup_gpu(self):
        """Set up the GPU device and check for Hopper support."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        device = torch.cuda.current_device()
        device_props = torch.cuda.get_device_properties(device)

        # DeepGEMM requires the Hopper architecture (sm_90a)
        if device_props.major < 9:
            pytest.skip(f"DeepGEMM requires a Hopper GPU (sm_90+), got sm_{device_props.major}{device_props.minor}")

        return device

    @pytest.fixture
    def random_seed(self):
        """Set random seeds for reproducibility."""
        torch.manual_seed(42)
        np.random.seed(42)
        return 42

    def create_fp8_tensor(self, shape: Tuple[int, ...], dtype=torch.float8_e4m3fn) -> torch.Tensor:
        """Create a random FP8 tensor with proper scaling."""
        # Create random data in BF16 first, then cast to FP8
        data = torch.randn(*shape, dtype=torch.bfloat16, device='cuda')
        # Scale to a reasonable FP8 range
        data = data * 0.1
        return data.to(dtype)

    def create_scaling_factor(self, shape: Tuple[int, ...]) -> torch.Tensor:
        """Create scaling factors for FP8 computations."""
        # Scaling factors should be positive and reasonable
        return torch.ones(*shape, dtype=torch.float32, device='cuda') * 0.125

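    # Added sketch (hedged, not used by the tests in this file): real workloads
    # usually derive per-tensor FP8 scales from the data (amax / fp8_max) rather
    # than using a constant; a minimal illustrative helper could look like this.
    @staticmethod
    def compute_per_tensor_scale(data: torch.Tensor) -> torch.Tensor:
        """Illustrative per-tensor FP8 scale derived from the absolute maximum."""
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
        amax = data.abs().amax().clamp(min=1e-12)
        return (amax / fp8_max).to(torch.float32)
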
    def get_reference_result(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute a reference result using PyTorch's built-in GEMM."""
        # Convert FP8 back to BF16 for the reference computation. The scaling
        # factors are not applied here, so any comparison against the FP8 kernel
        # output must account for x_scale * y_scale.
        x_bf16 = x.to(torch.bfloat16)
        y_bf16 = y.to(torch.bfloat16)
        return torch.matmul(x_bf16, y_bf16.T)

    @pytest.mark.parametrize("m,n,k", [
        (128, 128, 128),     # Small square matrices
        (256, 256, 256),     # Medium square matrices
        (512, 128, 256),     # Rectangular matrices
        (1024, 512, 1024),   # Larger matrices
        (2048, 1024, 2048),  # Large matrices (if memory allows)
    ])
    def test_basic_gemm_fp8(self, setup_gpu, random_seed, m, n, k):
        """Test the basic FP8 GEMM operation with various matrix sizes."""
        device = setup_gpu

        # Create input tensors
        x = self.create_fp8_tensor((m, k))
        y = self.create_fp8_tensor((n, k))  # Note: will be transposed in the NT GEMM

        # Create scaling factors (required for FP8)
        x_scale = self.create_scaling_factor((1,))
        y_scale = self.create_scaling_factor((1,))

        # Output tensor
        out = torch.empty((m, n), dtype=torch.bfloat16, device='cuda')

        # Call DeepGEMM
        deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)

        # Verify the output shape
        assert out.shape == (m, n), f"Expected shape ({m}, {n}), got {out.shape}"

        # Check for NaN/Inf values
        assert torch.isfinite(out).all(), "Output contains NaN or Inf values"

        # Basic sanity check: output values should be reasonable
        assert out.abs().mean() < 100.0, "Output values seem too large"

        print(f"✓ Basic GEMM test passed for shape ({m}, {n}, {k})")
        print(f"  Output range: [{out.min().item():.6f}, {out.max().item():.6f}]")

    def test_m_alignment_requirement(self, setup_gpu):
        """Test M-axis alignment requirements for grouped operations."""
        device = setup_gpu

        # Query the alignment the library requires
        required_alignment = deep_gemm.get_m_alignment_for_contiguous_layout()

        # Round M up to a multiple of the required alignment
        m_aligned = ((1000 + required_alignment - 1) // required_alignment) * required_alignment
        n, k = 512, 256

        x = self.create_fp8_tensor((m_aligned, k))
        y = self.create_fp8_tensor((n, k))

        x_scale = self.create_scaling_factor((1,))
        y_scale = self.create_scaling_factor((1,))

        out = torch.empty((m_aligned, n), dtype=torch.bfloat16, device='cuda')

        # This should work without issues
        deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)

        assert torch.isfinite(out).all()
        print(f"✓ M-alignment test passed with alignment {required_alignment}")

    def test_grouped_gemm_contiguous(self, setup_gpu, random_seed):
        """Test grouped GEMM with the contiguous layout used for MoE models."""
        device = setup_gpu

        # Parameters for the grouped GEMM
        num_experts = 4
        tokens_per_expert = [64, 128, 96, 112]  # Variable token count per expert
        n, k = 256, 512

        # Round each expert's token count up to the required alignment
        alignment = deep_gemm.get_m_alignment_for_contiguous_layout()
        aligned_tokens = [((t + alignment - 1) // alignment) * alignment for t in tokens_per_expert]
        total_aligned = sum(aligned_tokens)

        # Create inputs
        x = self.create_fp8_tensor((total_aligned, k))
        y = self.create_fp8_tensor((num_experts, n, k))  # One weight matrix per expert

        # Scaling factors
        x_scale = self.create_scaling_factor((total_aligned // alignment,))
        y_scale = self.create_scaling_factor((num_experts,))

        # Output tensor
        out = torch.empty((total_aligned, n), dtype=torch.bfloat16, device='cuda')

        # Starting row offset of each expert's block in the contiguous layout
        group_offsets = torch.tensor([0] + aligned_tokens[:-1], dtype=torch.int32, device='cuda')
        group_offsets = torch.cumsum(group_offsets, dim=0)

        # Call the grouped GEMM
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            x, y, out, x_scale, y_scale, group_offsets
        )

        # Verify the output
        assert out.shape == (total_aligned, n)
        assert torch.isfinite(out).all()

        print("✓ Grouped contiguous GEMM test passed")
        print(f"  Experts: {num_experts}, Total tokens: {total_aligned}")

    def test_grouped_gemm_masked(self, setup_gpu, random_seed):
        """Test masked grouped GEMM for inference scenarios."""
        device = setup_gpu

        # Parameters
        num_experts = 8
        max_tokens_per_expert = 128
        n, k = 512, 256

        # Create inputs
        x = self.create_fp8_tensor((num_experts, max_tokens_per_expert, k))
        y = self.create_fp8_tensor((num_experts, n, k))

        # Create a mask (some experts process fewer tokens)
        actual_tokens = [32, 64, 128, 96, 48, 80, 112, 16]
        mask = torch.zeros((num_experts, max_tokens_per_expert), dtype=torch.bool, device='cuda')
        for i, tokens in enumerate(actual_tokens):
            mask[i, :tokens] = True

        # Scaling factors
        x_scale = self.create_scaling_factor((num_experts,))
        y_scale = self.create_scaling_factor((num_experts,))

        # Output tensor
        out = torch.empty((num_experts, max_tokens_per_expert, n), dtype=torch.bfloat16, device='cuda')

        # Call the masked grouped GEMM
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            x, y, out, x_scale, y_scale, mask
        )

        # Verify the output
        assert out.shape == (num_experts, max_tokens_per_expert, n)
        assert torch.isfinite(out).all()

        # Check that the valid (unmasked) regions are finite
        for i, tokens in enumerate(actual_tokens):
            valid_out = out[i, :tokens, :]
            assert torch.isfinite(valid_out).all()

        print("✓ Masked grouped GEMM test passed")
        print(f"  Experts: {num_experts}, Max tokens per expert: {max_tokens_per_expert}")

    def test_numerical_consistency(self, setup_gpu, random_seed):
        """Test numerical consistency across multiple runs."""
        device = setup_gpu
        m, n, k = 512, 256, 512

        # Create fixed inputs
        x = self.create_fp8_tensor((m, k))
        y = self.create_fp8_tensor((n, k))
        x_scale = self.create_scaling_factor((1,))
        y_scale = self.create_scaling_factor((1,))

        results = []

        # Run the same GEMM multiple times
        for _ in range(3):
            out = torch.empty((m, n), dtype=torch.bfloat16, device='cuda')
            deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)
            results.append(out.clone().cpu())

        # Check consistency against the first run
        for i in range(1, len(results)):
            diff = torch.abs(results[0] - results[i])
            max_diff = diff.max().item()
            assert max_diff < 1e-3, f"Results not consistent across runs: max_diff={max_diff}"

        print("✓ Numerical consistency test passed")

    def test_environment_variables(self, setup_gpu):
        """Test that the kernel still runs correctly with DG_JIT_DEBUG enabled."""
        device = setup_gpu

        # Enable JIT debug output for this test only
        original_debug = os.environ.get('DG_JIT_DEBUG')
        os.environ['DG_JIT_DEBUG'] = '1'

        try:
            m, n, k = 256, 128, 256
            x = self.create_fp8_tensor((m, k))
            y = self.create_fp8_tensor((n, k))
            x_scale = self.create_scaling_factor((1,))
            y_scale = self.create_scaling_factor((1,))
            out = torch.empty((m, n), dtype=torch.bfloat16, device='cuda')

            # This should work with debug enabled
            deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)
            assert torch.isfinite(out).all()

        finally:
            # Restore the previous state, removing the variable if it was unset
            if original_debug is None:
                os.environ.pop('DG_JIT_DEBUG', None)
            else:
                os.environ['DG_JIT_DEBUG'] = original_debug

        print("✓ Environment variable test passed")

    def test_sm_count_utility(self, setup_gpu):
        """Test the SM count utility functions."""
        device = setup_gpu

        # Get the current SM count
        current_sms = deep_gemm.get_num_sms()
        assert current_sms > 0, "SM count should be positive"

        # Test setting the SM count
        original_sms = current_sms
        test_sms = min(64, current_sms)  # Use a reasonable number

        deep_gemm.set_num_sms(test_sms)
        new_sms = deep_gemm.get_num_sms()
        assert new_sms == test_sms, f"SM count not set correctly: expected {test_sms}, got {new_sms}"

        # Restore the original value
        deep_gemm.set_num_sms(original_sms)

        print(f"✓ SM count utility test passed (SM count: {current_sms})")

    def test_tma_alignment(self, setup_gpu):
        """Test TMA alignment utilities."""
        device = setup_gpu

        # Test the TMA alignment size
        tma_alignment = deep_gemm.get_tma_aligned_size()
        assert tma_alignment > 0, "TMA alignment should be positive"
        assert tma_alignment % 16 == 0, "TMA alignment should be a multiple of 16"

        # Test TMA-aligned tensor creation
        shape = (256, 512)
        aligned_tensor = deep_gemm.get_col_major_tma_aligned_tensor(shape, torch.float8_e4m3fn)

        assert aligned_tensor.device.type == 'cuda'
        assert aligned_tensor.dtype == torch.float8_e4m3fn
        assert aligned_tensor.shape[0] >= shape[0]  # May be padded
        assert aligned_tensor.shape[1] >= shape[1]  # May be padded

        print(f"✓ TMA alignment test passed (alignment: {tma_alignment})")

    @pytest.mark.parametrize("scaling_mode", ["per_tensor", "per_channel"])
    def test_scaling_modes(self, setup_gpu, random_seed, scaling_mode):
        """Test different scaling granularities."""
        device = setup_gpu
        m, n, k = 256, 128, 256

        x = self.create_fp8_tensor((m, k))
        y = self.create_fp8_tensor((n, k))

        if scaling_mode == "per_tensor":
            x_scale = self.create_scaling_factor((1,))
            y_scale = self.create_scaling_factor((1,))
        else:  # per_channel
            x_scale = self.create_scaling_factor((k,))
            y_scale = self.create_scaling_factor((k,))

        out = torch.empty((m, n), dtype=torch.bfloat16, device='cuda')

        try:
            deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)
            assert torch.isfinite(out).all()
            print(f"✓ Scaling mode test passed: {scaling_mode}")
        except Exception as e:
            # Some scaling granularities might not be supported
            print(f"⚠ Scaling mode {scaling_mode} not supported: {e}")

    def test_error_handling(self, setup_gpu):
        """Test error handling for invalid inputs."""
        device = setup_gpu

        # Test mismatched inner (K) dimensions
        x = self.create_fp8_tensor((128, 256))
        y = self.create_fp8_tensor((64, 128))  # Wrong K dimension
        x_scale = self.create_scaling_factor((1,))
        y_scale = self.create_scaling_factor((1,))
        out = torch.empty((128, 64), dtype=torch.bfloat16, device='cuda')

        with pytest.raises(Exception):
            deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)

        print("✓ Error handling test passed")

    def test_performance_characteristics(self, setup_gpu, random_seed):
        """Test and report performance characteristics."""
        device = setup_gpu

        shapes = [
            (1024, 1024, 1024),
            (2048, 1024, 2048),
            (4096, 2048, 4096),
        ]

        for m, n, k in shapes:
            try:
                x = self.create_fp8_tensor((m, k))
                y = self.create_fp8_tensor((n, k))
                x_scale = self.create_scaling_factor((1,))
                y_scale = self.create_scaling_factor((1,))
                out = torch.empty((m, n), dtype=torch.bfloat16, device='cuda')

                # Warmup
                for _ in range(3):
                    deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)

                # Time the operation with CUDA events
                torch.cuda.synchronize()
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)

                start.record()
                deep_gemm.gemm_fp8_fp8_bf16_nt(x, y, out, x_scale, y_scale)
                end.record()
                torch.cuda.synchronize()

                elapsed_ms = start.elapsed_time(end)

                # Calculate TFLOPS (each GEMM is 2 * M * N * K FLOPs)
                flops = 2 * m * n * k
                tflops = (flops / (elapsed_ms / 1000)) / 1e12

                print(f"✓ Performance test for ({m}, {n}, {k}): {elapsed_ms:.2f}ms, {tflops:.1f} TFLOPS")

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"⚠ Skipping large shape ({m}, {n}, {k}) due to memory constraints")
                else:
                    raise


# Additional utility tests
class TestDeepGEMMUtilities:
    """Test utility functions and edge cases."""

    def test_library_info(self):
        """Test that we can access library information."""
        try:
            import deep_gemm
            # Basic import test
            assert hasattr(deep_gemm, 'gemm_fp8_fp8_bf16_nt')
            assert hasattr(deep_gemm, 'get_num_sms')
            assert hasattr(deep_gemm, 'get_m_alignment_for_contiguous_layout')
            print("✓ Library import and basic API test passed")
        except ImportError:
            pytest.skip("DeepGEMM not installed")

    def test_version_compatibility(self):
        """Test CUDA and PyTorch version compatibility."""
        if torch.cuda.is_available():
            cuda_version = torch.version.cuda
            pytorch_version = torch.__version__

            print(f"CUDA version: {cuda_version}")
            print(f"PyTorch version: {pytorch_version}")

            # DeepGEMM requires CUDA 12.3+
            if cuda_version:
                major, minor = map(int, cuda_version.split('.')[:2])
                if major < 12 or (major == 12 and minor < 3):
                    pytest.skip(f"DeepGEMM requires CUDA 12.3+, got {cuda_version}")

        print("✓ Version compatibility check passed")

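    # Added example (hedged): a minimal FP8 round-trip sanity check that uses
    # only PyTorch itself; it assumes torch exposes float8_e4m3fn (PyTorch 2.1+)
    # and does not depend on the deep_gemm API exercised elsewhere in this file.
    def test_fp8_dtype_roundtrip(self):
        """Sanity-check the BF16 -> FP8 (e4m3) -> BF16 cast the suite relies on."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        data = torch.randn(64, 64, dtype=torch.bfloat16, device='cuda') * 0.1
        fp8 = data.to(torch.float8_e4m3fn)
        back = fp8.to(torch.bfloat16)

        assert torch.isfinite(back).all()
        # float8_e4m3fn can represent magnitudes up to 448
        assert back.abs().max().item() <= 448.0
        print("✓ FP8 dtype round-trip check passed")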


if __name__ == "__main__":
    # Run with: python -m pytest test_deepgemm.py -v
    pytest.main([__file__, "-v", "--tb=short"])