Created
March 17, 2026 15:43
-
-
Save ianbarber/d3dc33ac59a6f3f75aa0490e8fb795c0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Triton kernel for fused RMS normalization. | |
| RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight | |
| """ | |
| import triton | |
| import triton.language as tl | |
| import torch | |
@triton.jit
def _rms_norm_fwd_kernel(
    X_ptr,
    W_ptr,
    Y_ptr,
    RRMS_ptr,  # reciprocal RMS, saved for backward
    stride_x_row,
    N_COLS: tl.constexpr,
    eps,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
):
    """Normalize one row of X per program instance.

    Computes y = x * rsqrt(mean(x^2) + eps) (optionally * w) in float32, stores
    y cast to Y's element type, and stores the per-row reciprocal RMS to
    RRMS_ptr[row].

    Layout assumptions (enforced by the caller, not checked here):
      * columns of X, Y and W are unit-stride;
      * Y uses the same row stride as X (``stride_x_row`` is reused for Y);
      * BLOCK_SIZE >= N_COLS so a single block covers the whole row.
    ``eps`` is a runtime scalar so that changing it does not trigger a
    recompilation (constexpr arguments are specialized per value).
    """
    # program_id is int32; promote before multiplying by the row stride so the
    # offset does not overflow for tensors with more than 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)
    row_start = row_idx * stride_x_row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < N_COLS

    # Load the row and accumulate in float32 regardless of the input dtype.
    # Masked lanes read 0.0 so they do not contribute to the sum of squares.
    x = tl.load(X_ptr + row_start + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # RMS = sqrt(mean(x^2) + eps); divide by the true column count, not BLOCK_SIZE.
    mean_sq = tl.sum(x * x, axis=0) / N_COLS
    rrms = 1.0 / tl.sqrt(mean_sq + eps)

    y = x * rrms
    if HAS_WEIGHT:
        w = tl.load(W_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32)
        y = y * w

    # y is always float32 here; cast explicitly to the output buffer's dtype.
    tl.store(Y_ptr + row_start + col_offsets, y.to(Y_ptr.dtype.element_ty), mask=mask)
    tl.store(RRMS_ptr + row_idx, rrms)
def triton_rms_norm_forward(
    x: torch.Tensor,
    normalized_shape: list[int],
    weight: torch.Tensor | None,
    eps: float | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused RMSNorm forward pass using Triton.

    Normalizes ``x`` over its trailing ``len(normalized_shape)`` dimensions,
    computing ``y = x / sqrt(mean(x**2) + eps) * weight`` with one Triton
    program per normalized row.

    Args:
        x: Input tensor; its trailing dimensions must equal ``normalized_shape``.
        normalized_shape: Trailing dimensions to normalize over (a bare ``int``
            is accepted as shorthand for a single dimension).
        weight: Optional elementwise scale with ``prod(normalized_shape)``
            elements, or ``None`` for no scaling.
        eps: Value added to the mean square for numerical stability;
            ``None`` selects ``1e-6``.

    Returns:
        ``(y, rrms)`` where ``y`` has the shape and dtype of ``x`` and ``rrms``
        is a 1-D float32 tensor holding one reciprocal RMS per normalized row
        (saved for the backward pass).

    Raises:
        ValueError: If the trailing dimensions of ``x`` do not match
            ``normalized_shape``, or ``weight`` has the wrong number of elements.
    """
    if eps is None:
        # NOTE(review): torch.nn.functional.rms_norm defaults to
        # torch.finfo(x.dtype).eps; this wrapper keeps its historical 1e-6.
        eps = 1e-6

    if isinstance(normalized_shape, int):
        normalized_shape = [normalized_shape]
    norm_shape = torch.Size(normalized_shape)
    n_norm_dims = len(norm_shape)
    orig_shape = x.shape
    if n_norm_dims == 0 or x.dim() < n_norm_dims or orig_shape[-n_norm_dims:] != norm_shape:
        raise ValueError(
            f"normalized_shape {list(norm_shape)} does not match the trailing "
            f"dimensions of input with shape {list(orig_shape)}"
        )

    # Collapse leading dims into rows and all normalized (trailing) dims into
    # columns: (B, ..., *normalized_shape) -> (M, N).
    n_cols = norm_shape.numel()
    n_rows = orig_shape[:-n_norm_dims].numel()

    if weight is not None:
        if weight.numel() != n_cols:
            raise ValueError(
                f"weight has {weight.numel()} elements, expected {n_cols} "
                f"(prod of normalized_shape {list(norm_shape)})"
            )
        # The kernel indexes weight as a flat, unit-stride vector.
        weight = weight.reshape(n_cols).contiguous()

    rrms = torch.empty(n_rows, dtype=torch.float32, device=x.device)
    if n_rows == 0 or n_cols == 0:
        # Nothing to launch (a zero-size grid / zero BLOCK_SIZE is invalid).
        # The mean over zero columns is NaN, as torch.mean gives for empty input.
        return torch.empty_like(x), rrms.fill_(float("nan"))

    # The kernel assumes unit column stride and reuses X's row stride for Y.
    # reshape() alone may return a strided view (e.g. of a sliced or transposed
    # tensor) while empty_like() of a non-dense view is contiguous, so the
    # strides could disagree and Y writes would go out of bounds. Forcing X
    # contiguous makes both buffers share the same (N, 1) layout.
    x_2d = x.reshape(n_rows, n_cols).contiguous()
    y = torch.empty_like(x_2d)

    # One block covers a whole row: next power of 2 >= n_cols.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _rms_norm_fwd_kernel[(n_rows,)](
        x_2d, weight, y, rrms,
        x_2d.stride(0),
        N_COLS=n_cols,
        eps=float(eps),
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_WEIGHT=weight is not None,
    )
    return y.reshape(orig_shape), rrms
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment