Skip to content

Instantly share code, notes, and snippets.

@ianbarber
Created March 17, 2026 15:43
Show Gist options
  • Select an option

  • Save ianbarber/d3dc33ac59a6f3f75aa0490e8fb795c0 to your computer and use it in GitHub Desktop.

Select an option

Save ianbarber/d3dc33ac59a6f3f75aa0490e8fb795c0 to your computer and use it in GitHub Desktop.
"""
Triton kernel for fused RMS normalization.

RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight, with the mean taken
over the normalized (trailing) elements of each row.
"""
import triton
import triton.language as tl
import torch
@triton.jit
def _rms_norm_fwd_kernel(
    X_ptr,
    W_ptr,
    Y_ptr,
    RRMS_ptr,  # reciprocal RMS, saved for backward
    stride_x_row,
    N_COLS: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
):
    """Normalize one row per program: y = x * rsqrt(mean(x^2) + eps) [* w].

    Layout assumptions made by the indexing below:
      * Elements within a row are adjacent (column stride 1), for X, Y and W.
      * Y uses the same row stride as X (``stride_x_row``).
      * A single block covers the whole row, so ``BLOCK_SIZE >= N_COLS``;
        lanes past ``N_COLS`` are masked out.

    The reduction and scaling are done in float32; the result is cast to
    Y's element type on store. One reciprocal-RMS value per row is written
    to ``RRMS_ptr[row]``.
    """
    # Widen the row index to 64-bit so that `row_idx * stride_x_row` cannot
    # wrap around on inputs holding 2**31 or more elements.
    row_idx = tl.program_id(0).to(tl.int64)
    row_start = row_idx * stride_x_row

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < N_COLS

    # Load the row; masked lanes read 0.0 so they do not affect the sum.
    x = tl.load(X_ptr + row_start + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # Reciprocal RMS: 1 / sqrt(mean(x^2) + eps). Divide by N_COLS (not
    # BLOCK_SIZE) so the padding lanes are excluded from the mean.
    mean_sq = tl.sum(x * x, axis=0) / N_COLS
    rrms = 1.0 / tl.sqrt(mean_sq + eps)

    # Normalize
    y = x * rrms

    # Apply weight if present
    if HAS_WEIGHT:
        w = tl.load(W_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32)
        y = y * w

    # Store output in Y's own element type, and rrms for the backward pass.
    # `y` is always float32 here, so cast explicitly to the destination type.
    tl.store(Y_ptr + row_start + col_offsets, y.to(Y_ptr.dtype.element_ty), mask=mask)
    tl.store(RRMS_ptr + row_idx, rrms)
def triton_rms_norm_forward(
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor | None,
eps: float | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fused RMSNorm forward pass using Triton."""
if eps is None:
eps = 1e-6
# Flatten batch dims: (B, ..., N) -> (M, N)
orig_shape = x.shape
n_cols = normalized_shape[-1]
x_2d = x.reshape(-1, n_cols)
n_rows = x_2d.shape[0]
# Allocate outputs
y = torch.empty_like(x_2d)
rrms = torch.empty(n_rows, dtype=torch.float32, device=x.device)
# Pick block size (next power of 2 >= n_cols)
BLOCK_SIZE = triton.next_power_of_2(n_cols)
_rms_norm_fwd_kernel[(n_rows,)](
x_2d, weight, y, rrms,
x_2d.stride(0),
N_COLS=n_cols,
eps=eps,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=weight is not None,
)
return y.reshape(orig_shape), rrms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment