# gemma3n: minimal text-only Gemma 3n (E4B) inference in plain PyTorch
import os
import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoProcessor

@dataclass
class config:
    activation_sparsity_pattern = [0.95] * 10 + [0.0] * 25
    altup_active_idx = 0
    altup_coef_clip = 120.0
    altup_correct_scale = True
    altup_lr_multiplier = 1.0
    altup_num_inputs = 4
    final_logit_softcapping = 30.0
    head_dim = 256
    hidden_size = 2048
    hidden_size_per_layer_input = 256
    intermediate_size = 16384
    laurel_rank = 64
    layer_types = (["sliding_attention"] * 4 + ["full_attention"]) * 7
    max_position_embeddings = 32768
    num_attention_heads = 8
    num_hidden_layers = 35
    num_key_value_heads = 2
    num_kv_shared_layers = 15
    rms_norm_eps = 1e-06
    rope_local_base_freq = 10000.0
    rope_theta = 1000000.0
    sliding_window = 512
    vocab_size = 262400
    vocab_size_per_layer_input = 262144
    _attn_implementation = "eager"
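
# Note: the values above appear to mirror the Hugging Face text config for
# google/gemma-3n-e4b-it. A quick sanity check (commented out so the script is
# unchanged):
#   assert len(config.layer_types) == config.num_hidden_layers == 35
#   assert len(config.activation_sparsity_pattern) == config.num_hidden_layers
# Every 5th layer uses full attention; the other four use a 512-token sliding
# window, and the last num_kv_shared_layers=15 layers share KV caches (see below).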

class Gemma3nRMSNorm(nn.Module):
    def __init__(self, dim: int, with_scale: bool = True):
        super().__init__()
        self.eps = config.rms_norm_eps
        self.with_scale = with_scale
        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim, dtype=torch.bfloat16))
        else:
            self.register_buffer(
                "weight", torch.tensor(1.0, dtype=torch.bfloat16), persistent=False
            )

    def _norm(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()) * self.weight.float()
        return output.type_as(x)
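
# Gemma3nRMSNorm computes x / sqrt(mean(x^2) + eps), optionally times a learned
# per-channel weight, in float32 before casting back to the input dtype. Note the
# weight multiplies directly here (there is no `1 + weight` offset).
# A minimal equivalent sketch with hypothetical tensors `x` and `weight` (commented out):
#   x_f = x.float()
#   y = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + config.rms_norm_eps)
#   y = (y * weight.float()).type_as(x)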

class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, padding_idx, embed_scale):
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)

class Gemma3nTextLaurelBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_left = nn.Linear(
            config.hidden_size, config.laurel_rank, dtype=torch.bfloat16, bias=False
        )
        self.linear_right = nn.Linear(
            config.laurel_rank, config.hidden_size, dtype=torch.bfloat16, bias=False
        )
        self.post_laurel_norm = Gemma3nRMSNorm(config.hidden_size)

    def forward(self, hidden_states):
        laurel_hidden_states = self.linear_left(hidden_states)
        laurel_hidden_states = self.linear_right(laurel_hidden_states)
        normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
        return hidden_states + normed_laurel_hidden_states
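
# LAUREL (learned augmented residual): a low-rank bypass
# (hidden_size -> laurel_rank=64 -> hidden_size) whose RMS-normed output is added
# back to its input, giving each layer an extra cheap residual path alongside the
# attention and MLP branches.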

class Gemma3nTextMLP(nn.Module):
    def __init__(self, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.up_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.down_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
        self.act_fn = nn.GELU("tanh")

    def forward(self, hidden_states):
        gate_proj = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate_proj = self._gaussian_topk(gate_proj)
        activations = self.act_fn(gate_proj)
        up_proj = self.up_proj(hidden_states)
        down_proj = self.down_proj(activations * up_proj)
        return down_proj

    def _gaussian_topk(self, inputs):
        target_sparsity_tensor = torch.tensor(
            self.activation_sparsity, dtype=torch.float32, device=inputs.device
        )
        normal_dist = torch.distributions.normal.Normal(0, 1)
        std_multiplier = normal_dist.icdf(target_sparsity_tensor)
        std_multiplier = std_multiplier.type(inputs.dtype)
        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = inputs_mean + inputs_std * std_multiplier
        return nn.functional.relu(inputs - cutoff_x)
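
# _gaussian_topk implements the activation sparsity used in the first 10 layers
# (sparsity 0.95): the cutoff is mean + std * Phi^{-1}(0.95) per row, so values
# below it become zero and the survivors are shifted down by the cutoff via
# ReLU(x - cutoff), keeping roughly the top 5% of gate activations.
# Sketch of the cutoff multiplier (commented out):
#   std_multiplier = torch.distributions.Normal(0, 1).icdf(torch.tensor(0.95))  # ~1.645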

class Gemma3nTextAltUp(nn.Module):
    def __init__(self):
        super().__init__()
        self.correct_output_scale = nn.Parameter(
            torch.zeros(config.hidden_size, dtype=torch.bfloat16)
        )
        self.correction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.prediction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs**2,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.modality_router = nn.Linear(
            config.hidden_size,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.router_norm = Gemma3nRMSNorm(config.hidden_size)
        self.register_buffer(
            "router_input_scale",
            torch.tensor(config.hidden_size**-1.0, dtype=torch.bfloat16),
            persistent=False,
        )

    def compute_router_modalities(self, x):
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def predict(self, hidden_states):
        modalities = self.compute_router_modalities(
            hidden_states[config.altup_active_idx]
        )
        all_coefs = (
            self.prediction_coefs(modalities)
            .reshape(
                *modalities.shape[:-1], config.altup_num_inputs, config.altup_num_inputs
            )
            .permute(0, 1, 3, 2)
        )
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)
        predictions += hidden_states
        return predictions.contiguous().type_as(hidden_states)

    def correct(self, predictions, activated):
        modalities = self.compute_router_modalities(activated)
        innovation = activated - predictions[config.altup_active_idx]
        innovation = innovation.repeat(config.altup_num_inputs, 1, 1, 1)
        if config.altup_coef_clip is not None:
            self.correction_coefs.weight.data.clamp_(
                -config.altup_coef_clip, config.altup_coef_clip
            )
        all_coefs = self.correction_coefs(modalities) + 1.0
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions
        return corrected.contiguous().type_as(activated)

    def scale_corrected_output(self, corrected):
        return (
            corrected.type_as(self.correct_output_scale) * self.correct_output_scale
        ).type_as(corrected)
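
# AltUp ("alternating updates") keeps altup_num_inputs=4 parallel copies of the
# hidden state, shaped [4, batch, seq, hidden]. Only copy 0 (altup_active_idx) is
# run through attention/MLP in each layer; predict() mixes the 4 copies with
# routed per-token coefficients, and correct() adds the resulting "innovation"
# (activated output minus the active prediction) back into all copies, scaled by
# per-copy coefficients clipped to +/- altup_coef_clip.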

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)

class Gemma3nTextRotaryEmbedding(nn.Module):
    def __init__(self, is_local=False):
        super().__init__()
        base = config.rope_local_base_freq if is_local else config.rope_theta
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, config.head_dim, 2, dtype=torch.int64).to(
                    dtype=torch.float
                )
                / config.head_dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
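
# Two rotary embeddings are built over the same head_dim: a "local" one with base
# 10000 for sliding-window layers and a "global" one with base 1000000 for
# full-attention layers. rotate_half/apply_rotary_pos_emb above follow the usual
# half-split RoPE convention, applied per head before the transpose to
# [batch, heads, seq, head_dim].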

class Gemma3nTextAttention(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        self.head_dim = config.head_dim
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.q_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.k_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, with_scale=False)
        first_kv_shared_layer_idx = (
            config.num_hidden_layers - config.num_kv_shared_layers
        )
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
        layer_type = config.layer_types[layer_idx]
        self.kv_shared_layer_index = (
            first_kv_shared_layer_idx
            - 1
            - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
            if self.is_kv_shared_layer
            else None
        )

    def forward(
        self,
        hidden_states,
        position_embeddings,
        past_key_values=None,
        layer_cache=None,
    ):
        bsz, q_len, _ = hidden_states.shape
        hidden_shape = (bsz, q_len, -1, config.head_dim)
        cos, sin = position_embeddings
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin)
        query_states = query_states.transpose(1, 2)
        if self.is_kv_shared_layer:
            key_states, value_states = layer_cache[self.kv_shared_layer_index]
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin)
            key_states = key_states.transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)
            if layer_cache is not None:
                layer_cache[self.layer_idx] = (key_states, value_states)
        if past_key_values is not None:
            past_key_value = past_key_values[
                (
                    self.kv_shared_layer_index
                    if self.is_kv_shared_layer
                    else self.layer_idx
                )
            ]
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        present_key_value = (
            None if self.is_kv_shared_layer else (key_states, value_states)
        )
        # TODO: sliding window mask
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=q_len > 1,
            scale=1.0,
            enable_gqa=True,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, present_key_value
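
# KV sharing: the last num_kv_shared_layers=15 layers compute no K/V of their own.
# They reuse the current-step key/value states of the most recent earlier layer with
# the same attention type (stashed in `layer_cache` within a forward pass) and index
# the persistent cache with that layer's slot; their present_key_value is None, so
# the returned cache has exactly 20 entries, one per non-shared layer. The sliding
# window mask is left as a TODO above, which should only matter once the context
# exceeds sliding_window=512 tokens.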

class Gemma3nTextDecoderLayer(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = nn.GELU("tanh")
        self.mlp = Gemma3nTextMLP(layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.altup = Gemma3nTextAltUp()
        self.laurel = Gemma3nTextLaurelBlock()
        self.self_attn = Gemma3nTextAttention(layer_idx)
        self.per_layer_input_gate = nn.Linear(
            self.hidden_size,
            self.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection = nn.Linear(
            self.hidden_size_per_layer_input,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        position_embeddings_global,
        position_embeddings_local,
        per_layer_input,
        past_key_values=None,
        layer_cache=None,
    ):
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[config.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global
        attn, present_key_value = self.self_attn(
            hidden_states=active_prediction_normed,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            layer_cache=layer_cache,
        )
        attn = self.post_attention_layernorm(attn)
        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
        first_prediction = corrected_predictions[config.altup_active_idx]
        first_prediction_clone = first_prediction.clone()
        if config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction
        return corrected_predictions, present_key_value
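
# Layer flow: AltUp predict -> norm -> attention, with the LAUREL branch averaged in
# via the 1/sqrt(2) factor -> MLP -> AltUp correct. The per-layer input (a 256-dim
# embedding specific to this layer and token) is then gated by the active stream,
# projected back to hidden_size, normed, and added to the three non-active AltUp
# streams (indices 1..3).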

class Gemma3nTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.padding_idx = 0
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )
        self.per_layer_model_projection = nn.Linear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection_norm = Gemma3nRMSNorm(
            config.hidden_size_per_layer_input,
        )
        self.layers = nn.ModuleList(
            [
                Gemma3nTextDecoderLayer(layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3nRMSNorm(config.hidden_size)
        self.altup_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.altup_unembed_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.register_buffer(
            "per_layer_projection_scale",
            torch.tensor(self.hidden_size**-0.5),
            persistent=False,
        )
        self.register_buffer(
            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
        )
        self.rotary_emb = Gemma3nTextRotaryEmbedding()
        self.rotary_emb_local = Gemma3nTextRotaryEmbedding(is_local=True)

    def get_per_layer_inputs(self, input_ids: torch.LongTensor):
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.type(
            inputs_embeds.dtype
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        if per_layer_projection.shape != per_layer_inputs.shape:
            per_layer_inputs = per_layer_inputs[..., : config.num_hidden_layers, :]
        return (
            per_layer_projection + per_layer_inputs
        ) * self.per_layer_input_scale.type(inputs_embeds.dtype)

    def forward(
        self,
        input_ids,
        past_key_values=None,
        cache_position=None,
    ):
        if cache_position is None:
            cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
        position_ids = cache_position.unsqueeze(0)
        inputs_embeds = self.embed_tokens(input_ids)
        per_layer_inputs = self.get_per_layer_inputs(input_ids)
        per_layer_inputs = self.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )
        hidden_states_0 = inputs_embeds
        position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(torch.finfo().min)
        temp_hidden_states = [hidden_states_0]
        for i in range(1, config.altup_num_inputs):
            altup_proj = self.altup_projections[i - 1](hidden_states_0)
            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states, dim=0)
        layer_cache = {}
        next_key_values = []
        for i, decoder_layer in enumerate(self.layers):
            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
            layer_outputs, present_key_value = decoder_layer(
                hidden_states,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                per_layer_input=per_layer_input,
                past_key_values=past_key_values,
                layer_cache=layer_cache,
            )
            hidden_states = layer_outputs
            if present_key_value is not None:
                next_key_values.append(present_key_value)
        target_magnitude = (
            torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        )
        temp_hidden_states = [hidden_states[0]]
        for i in range(1, config.altup_num_inputs):
            altup_unemb_proj = self.altup_unembed_projections[i - 1](hidden_states[i])
            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_key_values
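
# Before the layer stack, streams 1..3 are created via altup_projections and rescaled
# to match the RMS magnitude of the embedding stream; after the last layer they are
# mapped back through altup_unembed_projections, rescaled the same way, and the four
# streams are averaged before the final norm. Note that `epsilon_tensor` is
# torch.finfo().min (a large negative number), so the torch.maximum guard is
# effectively a no-op here.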

class Gemma3nForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = Gemma3nTextModel()
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.final_logit_softcapping = config.final_logit_softcapping

    def forward(
        self,
        input_ids,
        past_key_values=None,
        cache_position=None,
    ):
        hidden_states, next_past_key_values = self.language_model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        logits = self.lm_head(hidden_states)
        logits = logits / self.final_logit_softcapping
        logits = torch.tanh(logits)
        logits = logits * self.final_logit_softcapping
        return logits, next_past_key_values
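
# Final logit soft-capping squashes logits into (-30, 30) while preserving their
# order (tanh is monotonic), so the greedy argmax decoding below is unaffected:
#   capped = final_logit_softcapping * torch.tanh(logits / final_logit_softcapping)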

def load_model(model: nn.Module, path: str):
    from glob import glob
    from safetensors import safe_open

    state_dict = {}
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_tensor = f.get_tensor(weight_name)
                if not weight_name.startswith("model.language_model."):
                    continue
                weight_name = weight_name.replace(
                    "model.language_model.", "language_model."
                )
                param = model.get_parameter(weight_name)
                assert param.dtype == loaded_tensor.dtype
                state_dict[weight_name] = loaded_tensor
    state_dict["lm_head.weight"] = state_dict["language_model.embed_tokens.weight"]
    model.load_state_dict(state_dict, assign=True)
    model.lm_head.weight.data = model.language_model.embed_tokens.weight.data
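
# Loading notes: only the checkpoint's "model.language_model.*" tensors are read
# (the audio/vision towers are skipped), and load_state_dict(..., assign=True) lets
# the loaded tensors replace the meta-device placeholders directly. The LM head is
# tied to the input embedding matrix, so no separate lm_head weight is loaded from
# disk.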

path = os.path.expanduser(
    "~/.cache/huggingface/hub/models--google--gemma-3n-e4b-it/snapshots/17962761f8fbb66a0c7b8e9b4978b91909855a05"
)
model = Gemma3nForCausalLM()
load_model(model, path)
model.eval()
processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it", padding_side="left")

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Answer to the Ultimate Question of Life, the Universe, and Everything is",
            },
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
)
input_ids = inputs["input_ids"]
print(processor.decode(input_ids[0].tolist()), end="", flush=True)
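
# Greedy decoding with a hand-rolled KV cache: the first call runs the full prompt
# (prefill) and returns one (key, value) pair per non-shared layer; afterwards the
# model is fed a single token per step, with cache_position giving that token's
# absolute position so the rotary embeddings stay aligned with the prefill.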
max_new_tokens = 30
past_key_values = None
cache_position = None
with torch.no_grad():
    outputs, past_key_values = model(input_ids)
    next_logits = outputs[:, -1, :]
    next_id = torch.argmax(next_logits, dim=-1).unsqueeze(-1)
    print(processor.decode([next_id.item()]), end="", flush=True)
    cache_position = torch.tensor([input_ids.shape[1]])
    for _ in range(max_new_tokens - 1):
        outputs, past_key_values = model(next_id, past_key_values, cache_position)
        next_id = torch.argmax(outputs[:, -1, :], dim=-1).unsqueeze(-1)
        cache_position += 1
        print(processor.decode([next_id.item()]), end="", flush=True)
        if next_id.item() == processor.tokenizer.eos_token_id:
            break
print("\n")