# /// script
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
# ]
# ///
import torch
import torch.nn as nn

DEVICE = "cuda"
DTYPE = torch.float16  # Use float16 for better kernel performance potential


# Simple PyTorch implementation of RMSNorm for baseline comparison
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = variance_epsilon
        self.hidden_size = hidden_size

    def forward(self, x):
        # Assumes x is (batch_size, ..., hidden_size)
        input_dtype = x.dtype
        # Calculate variance in float32 for stability
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Apply weight and convert back to original dtype
        return (self.weight * x).to(input_dtype)


class BaselineModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, eps=1e-5):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.norm = RMSNorm(hidden_size, variance_epsilon=eps)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, output_size)

        # Ensure all linear layer weights are 1 (and biases 0) for testing
        with torch.no_grad():
            self.linear1.weight.fill_(1)
            self.linear1.bias.fill_(0)
            self.linear2.weight.fill_(1)
            self.linear2.bias.fill_(0)
            self.norm.weight.fill_(1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.norm(x)  # Apply RMSNorm
        x = self.activation(x)
        x = self.linear2(x)
        return x


# Example usage
input_size = 128
hidden_size = 256
output_size = 10
eps_val = 1e-5

baseline_model = (
    BaselineModel(input_size, hidden_size, output_size, eps=eps_val)
    .to(DEVICE)
    .to(DTYPE)
)

dummy_input = torch.randn(32, input_size, device=DEVICE, dtype=DTYPE)  # Batch of 32
output = baseline_model(dummy_input)
print("Baseline RMSNorm model output shape:", output.shape)
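

# --- Hub kernel comparison (sketch) ---
# A minimal sketch of how the `kernels` dependency declared above could be
# used to swap the pure PyTorch RMSNorm for a Triton kernel from the
# Hugging Face Hub. `get_kernel` is the `kernels` package's entry point;
# the repo id "kernels-community/triton-layer-norm" and the `rms_norm_fn`
# call below are assumptions about that kernel's interface and may need
# adjusting against the actual repo.
from kernels import get_kernel

ln_kernel = get_kernel("kernels-community/triton-layer-norm")  # assumed repo id

hidden = torch.randn(32, hidden_size, device=DEVICE, dtype=DTYPE)
ref = baseline_model.norm(hidden)  # baseline PyTorch RMSNorm
# Assumed signature: rms_norm_fn(x, weight, bias, eps=...); RMSNorm has no bias
ker = ln_kernel.rms_norm_fn(hidden, baseline_model.norm.weight, None, eps=eps_val)
print("max abs diff, kernel vs baseline:", (ref - ker).abs().max().item())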