Skip to content

Instantly share code, notes, and snippets.

@drbh
Created May 8, 2025 14:59

Revisions

  1. drbh created this gist May 8, 2025.
    72 changes: 72 additions & 0 deletions rmsnorm_baseline.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,72 @@
    # /// script
    # dependencies = [
    # "numpy",
    # "torch",
    # "kernels",
    # ]
    # ///
    import torch
    import torch.nn as nn

    DEVICE = "cuda"

    DTYPE = torch.float16 # Use float16 for better kernel performance potential


    # Simple PyTorch implementation of RMSNorm for baseline comparison
    class RMSNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-5):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.eps = variance_epsilon
    self.hidden_size = hidden_size

    def forward(self, x):
    # Assumes x is (batch_size, ..., hidden_size)
    input_dtype = x.dtype
    # Calculate variance in float32 for stability
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + self.eps)

    # Apply weight and convert back to original dtype
    return (self.weight * x).to(input_dtype)


    class BaselineModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, eps=1e-5):
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.norm = RMSNorm(hidden_size, variance_epsilon=eps)
    self.activation = nn.GELU()
    self.linear2 = nn.Linear(hidden_size, output_size)

    # ensure all linear layers weights are 1 for testing
    with torch.no_grad():
    self.linear1.weight.fill_(1)
    self.linear1.bias.fill_(0)
    self.linear2.weight.fill_(1)
    self.linear2.bias.fill_(0)
    self.norm.weight.fill_(1)

    def forward(self, x):
    x = self.linear1(x)
    x = self.norm(x) # Apply RMSNorm
    x = self.activation(x)
    x = self.linear2(x)
    return x


    # Example usage
    input_size = 128
    hidden_size = 256
    output_size = 10
    eps_val = 1e-5

    baseline_model = (
    BaselineModel(input_size, hidden_size, output_size, eps=eps_val)
    .to(DEVICE)
    .to(DTYPE)
    )
    dummy_input = torch.randn(32, input_size, device=DEVICE, dtype=DTYPE) # Batch of 32
    output = baseline_model(dummy_input)
    print("Baseline RMSNorm model output shape:", output.shape)