gemma3n
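# Minimal single-file reimplementation of the Gemma 3n (E4B) text decoder in plain
# PyTorch: it rebuilds the architecture, loads weights straight from a local
# Hugging Face checkpoint, and runs a short greedy-decoding demo.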
import os
import math
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoProcessor
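
# Model hyperparameters. The values appear to mirror the google/gemma-3n-e4b-it text
# config: 35 decoder layers in a repeating pattern of four sliding-window layers
# followed by one full-attention layer, four AltUp streams, and KV sharing for the
# final 15 layers.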
@dataclass
class config:
    activation_sparsity_pattern = [0.95] * 10 + [0.0] * 25
    altup_active_idx = 0
    altup_coef_clip = 120.0
    altup_correct_scale = True
    altup_lr_multiplier = 1.0
    altup_num_inputs = 4
    final_logit_softcapping = 30.0
    head_dim = 256
    hidden_size = 2048
    hidden_size_per_layer_input = 256
    intermediate_size = 16384
    laurel_rank = 64
    layer_types = (["sliding_attention"] * 4 + ["full_attention"]) * 7
    max_position_embeddings = 32768
    num_attention_heads = 8
    num_hidden_layers = 35
    num_key_value_heads = 2
    num_kv_shared_layers = 15
    rms_norm_eps = 1e-06
    rope_local_base_freq = 10000.0
    rope_theta = 1000000.0
    sliding_window = 512
    vocab_size = 262400
    vocab_size_per_layer_input = 262144
    _attn_implementation = "eager"
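
# RMSNorm: normalize in float32 by the root-mean-square of the last dimension, then
# apply a learned per-channel scale (or a fixed scale of 1.0 when with_scale=False,
# as used for the value norm in attention).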
class Gemma3nRMSNorm(nn.Module):
    def __init__(self, dim: int, with_scale: bool = True):
        super().__init__()
        self.eps = config.rms_norm_eps
        self.with_scale = with_scale
        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim, dtype=torch.bfloat16))
        else:
            self.register_buffer(
                "weight", torch.tensor(1.0, dtype=torch.bfloat16), persistent=False
            )

    def _norm(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()) * self.weight.float()
        return output.type_as(x)
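
# Embedding lookup multiplied by a constant scale; instantiated below with
# sqrt(hidden_size) for the main vocabulary and sqrt(hidden_size_per_layer_input)
# for the per-layer vocabulary.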
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, padding_idx, embed_scale):
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
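
# LAUREL block (learned augmented residual): project down to rank 64, back up to the
# model width, RMS-normalize, and add the result to the input as a cheap extra
# residual path alongside attention.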
class Gemma3nTextLaurelBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_left = nn.Linear(
            config.hidden_size, config.laurel_rank, dtype=torch.bfloat16, bias=False
        )
        self.linear_right = nn.Linear(
            config.laurel_rank, config.hidden_size, dtype=torch.bfloat16, bias=False
        )
        self.post_laurel_norm = Gemma3nRMSNorm(config.hidden_size)

    def forward(self, hidden_states):
        laurel_hidden_states = self.linear_left(hidden_states)
        laurel_hidden_states = self.linear_right(laurel_hidden_states)
        normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
        return hidden_states + normed_laurel_hidden_states
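
# Gated MLP (gate/up/down projections with tanh-approximated GELU). Layers with a
# nonzero entry in activation_sparsity_pattern (the first 10) sparsify the gate:
# values below mean + std * Phi^-1(0.95) along the feature dimension are zeroed out
# (via ReLU of the inputs minus that cutoff), keeping roughly the top 5% of
# activations per token.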
class Gemma3nTextMLP(nn.Module):
    def __init__(self, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.up_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.down_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
        self.act_fn = nn.GELU("tanh")

    def forward(self, hidden_states):
        gate_proj = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate_proj = self._gaussian_topk(gate_proj)
        activations = self.act_fn(gate_proj)
        up_proj = self.up_proj(hidden_states)
        down_proj = self.down_proj(activations * up_proj)
        return down_proj

    def _gaussian_topk(self, inputs):
        target_sparsity_tensor = torch.tensor(
            self.activation_sparsity, dtype=torch.float32, device=inputs.device
        )
        normal_dist = torch.distributions.normal.Normal(0, 1)
        std_multiplier = normal_dist.icdf(target_sparsity_tensor)
        std_multiplier = std_multiplier.type(inputs.dtype)
        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = inputs_mean + inputs_std * std_multiplier
        return nn.functional.relu(inputs - cutoff_x)
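
# AltUp ("alternating updates"): the model carries altup_num_inputs (4) parallel
# hidden-state streams. predict() mixes the streams with coefficients produced by a
# small router on the active stream; correct() feeds the layer output back into every
# stream with per-stream weights, clipping the correction coefficients to +/-120.
# Only the active stream (altup_active_idx = 0) runs through attention and the MLP.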
class Gemma3nTextAltUp(nn.Module):
    def __init__(self):
        super().__init__()
        self.correct_output_scale = nn.Parameter(
            torch.zeros(config.hidden_size, dtype=torch.bfloat16)
        )
        self.correction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.prediction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs**2,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.modality_router = nn.Linear(
            config.hidden_size,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.router_norm = Gemma3nRMSNorm(config.hidden_size)
        self.register_buffer(
            "router_input_scale",
            torch.tensor(config.hidden_size**-1.0, dtype=torch.bfloat16),
            persistent=False,
        )

    def compute_router_modalities(self, x):
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def predict(self, hidden_states):
        modalities = self.compute_router_modalities(
            hidden_states[config.altup_active_idx]
        )
        all_coefs = (
            self.prediction_coefs(modalities)
            .reshape(
                *modalities.shape[:-1], config.altup_num_inputs, config.altup_num_inputs
            )
            .permute(0, 1, 3, 2)
        )
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)
        predictions += hidden_states
        return predictions.contiguous().type_as(hidden_states)

    def correct(self, predictions, activated):
        modalities = self.compute_router_modalities(activated)
        innovation = activated - predictions[config.altup_active_idx]
        innovation = innovation.repeat(config.altup_num_inputs, 1, 1, 1)
        if config.altup_coef_clip is not None:
            self.correction_coefs.weight.data.clamp_(
                -config.altup_coef_clip, config.altup_coef_clip
            )
        all_coefs = self.correction_coefs(modalities) + 1.0
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions
        return corrected.contiguous().type_as(activated)

    def scale_corrected_output(self, corrected):
        return (
            corrected.type_as(self.correct_output_scale) * self.correct_output_scale
        ).type_as(corrected)
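
# Standard rotary position embeddings. Two instances are created in the model below:
# one with base 10,000 for sliding-window layers ("local") and one with base
# 1,000,000 for full-attention layers ("global").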
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)

class Gemma3nTextRotaryEmbedding(nn.Module):
    def __init__(self, is_local=False):
        super().__init__()
        base = config.rope_local_base_freq if is_local else config.rope_theta
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, config.head_dim, 2, dtype=torch.int64).to(
                    dtype=torch.float
                )
                / config.head_dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
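
# Grouped-query attention: 8 query heads, 2 KV heads (expanded via enable_gqa in
# SDPA), per-head RMSNorm on queries/keys/values, and an attention scale fixed at
# 1.0. Layers in the last num_kv_shared_layers (15) do not compute their own K/V:
# they reuse the current-step K/V of the nearest earlier layer of the same attention
# type (handed over through layer_cache) together with that layer's past cache entry.
# The sliding-window mask is not applied (see the TODO below), so outputs can drift
# from the reference implementation once the context grows past sliding_window (512)
# tokens.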
class Gemma3nTextAttention(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        self.head_dim = config.head_dim
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.q_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.k_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, with_scale=False)
        first_kv_shared_layer_idx = (
            config.num_hidden_layers - config.num_kv_shared_layers
        )
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
        layer_type = config.layer_types[layer_idx]
        self.kv_shared_layer_index = (
            first_kv_shared_layer_idx
            - 1
            - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
            if self.is_kv_shared_layer
            else None
        )

    def forward(
        self,
        hidden_states,
        position_embeddings,
        past_key_values=None,
        layer_cache=None,
    ):
        bsz, q_len, _ = hidden_states.shape
        hidden_shape = (bsz, q_len, -1, config.head_dim)
        cos, sin = position_embeddings
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin)
        query_states = query_states.transpose(1, 2)
        if self.is_kv_shared_layer:
            key_states, value_states = layer_cache[self.kv_shared_layer_index]
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin)
            key_states = key_states.transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)
            if layer_cache is not None:
                layer_cache[self.layer_idx] = (key_states, value_states)
        if past_key_values is not None:
            past_key_value = past_key_values[
                (
                    self.kv_shared_layer_index
                    if self.is_kv_shared_layer
                    else self.layer_idx
                )
            ]
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        present_key_value = (
            None if self.is_kv_shared_layer else (key_states, value_states)
        )
        # TODO: sliding window mask
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=q_len > 1,
            scale=1.0,
            enable_gqa=True,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, present_key_value
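
# One decoder layer: AltUp predict -> attention and LAUREL on the active stream ->
# gated MLP -> AltUp correct, after which the per-layer input embedding is gated,
# projected back to the model width, normalized, and added to the non-active streams.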
class Gemma3nTextDecoderLayer(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = nn.GELU("tanh")
        self.mlp = Gemma3nTextMLP(layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.altup = Gemma3nTextAltUp()
        self.laurel = Gemma3nTextLaurelBlock()
        self.self_attn = Gemma3nTextAttention(layer_idx)
        self.per_layer_input_gate = nn.Linear(
            self.hidden_size,
            self.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection = nn.Linear(
            self.hidden_size_per_layer_input,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        position_embeddings_global,
        position_embeddings_local,
        per_layer_input,
        past_key_values=None,
        layer_cache=None,
    ):
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[config.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global
        attn, present_key_value = self.self_attn(
            hidden_states=active_prediction_normed,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            layer_cache=layer_cache,
        )
        attn = self.post_attention_layernorm(attn)
        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
        first_prediction = corrected_predictions[config.altup_active_idx]
        first_prediction_clone = first_prediction.clone()
        if config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction
        return corrected_predictions, present_key_value
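
# Full text model: scaled token embeddings, a second embedding table that produces a
# 256-dim per-layer input for each of the 35 layers, the decoder stack, and the AltUp
# stream handling. The three extra streams are projections of the token embedding
# rescaled to the same RMS magnitude on the way in, then un-projected and averaged
# with the active stream on the way out.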
class Gemma3nTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.padding_idx = 0
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )
        self.per_layer_model_projection = nn.Linear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection_norm = Gemma3nRMSNorm(
            config.hidden_size_per_layer_input,
        )
        self.layers = nn.ModuleList(
            [
                Gemma3nTextDecoderLayer(layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3nRMSNorm(config.hidden_size)
        self.altup_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.altup_unembed_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.register_buffer(
            "per_layer_projection_scale",
            torch.tensor(self.hidden_size**-0.5),
            persistent=False,
        )
        self.register_buffer(
            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
        )
        self.rotary_emb = Gemma3nTextRotaryEmbedding()
        self.rotary_emb_local = Gemma3nTextRotaryEmbedding(is_local=True)

    def get_per_layer_inputs(self, input_ids: torch.LongTensor):
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.type(
            inputs_embeds.dtype
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        if per_layer_projection.shape != per_layer_inputs.shape:
            per_layer_inputs = per_layer_inputs[..., : config.num_hidden_layers, :]
        return (
            per_layer_projection + per_layer_inputs
        ) * self.per_layer_input_scale.type(inputs_embeds.dtype)

    def forward(
        self,
        input_ids,
        past_key_values=None,
        cache_position=None,
    ):
        if cache_position is None:
            cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
        position_ids = cache_position.unsqueeze(0)
        inputs_embeds = self.embed_tokens(input_ids)
        per_layer_inputs = self.get_per_layer_inputs(input_ids)
        per_layer_inputs = self.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )
        hidden_states_0 = inputs_embeds
        position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(torch.finfo().min)
        temp_hidden_states = [hidden_states_0]
        for i in range(1, config.altup_num_inputs):
            altup_proj = self.altup_projections[i - 1](hidden_states_0)
            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states, dim=0)
        layer_cache = {}
        next_key_values = []
        for i, decoder_layer in enumerate(self.layers):
            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
            layer_outputs, present_key_value = decoder_layer(
                hidden_states,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                per_layer_input=per_layer_input,
                past_key_values=past_key_values,
                layer_cache=layer_cache,
            )
            hidden_states = layer_outputs
            if present_key_value is not None:
                next_key_values.append(present_key_value)
        target_magnitude = (
            torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        )
        temp_hidden_states = [hidden_states[0]]
        for i in range(1, config.altup_num_inputs):
            altup_unemb_proj = self.altup_unembed_projections[i - 1](hidden_states[i])
            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_key_values
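
# Causal LM wrapper: the lm_head weight is tied to the input embedding in load_model
# below, and final logits are soft-capped to the range (-30, 30) via
# final_logit_softcapping * tanh(logits / final_logit_softcapping).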
class Gemma3nForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = Gemma3nTextModel()
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.final_logit_softcapping = config.final_logit_softcapping

    def forward(
        self,
        input_ids,
        past_key_values=None,
        cache_position=None,
    ):
        hidden_states, next_past_key_values = self.language_model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        logits = self.lm_head(hidden_states)
        logits = logits / self.final_logit_softcapping
        logits = torch.tanh(logits)
        logits = logits * self.final_logit_softcapping
        return logits, next_past_key_values
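
# Loads the language-model tensors directly from the safetensors shards of a local
# Hugging Face snapshot. Several large modules are constructed on the "meta" device,
# so load_state_dict(..., assign=True) adopts the checkpoint tensors instead of
# copying into preallocated storage.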
def load_model(model: nn.Module, path: str):
    from glob import glob
    from safetensors import safe_open

    state_dict = {}
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_tensor = f.get_tensor(weight_name)
                if not weight_name.startswith("model.language_model."):
                    continue
                weight_name = weight_name.replace(
                    "model.language_model.", "language_model."
                )
                param = model.get_parameter(weight_name)
                assert param.dtype == loaded_tensor.dtype
                state_dict[weight_name] = loaded_tensor
    state_dict["lm_head.weight"] = state_dict["language_model.embed_tokens.weight"]
    model.load_state_dict(state_dict, assign=True)
    model.lm_head.weight.data = model.language_model.embed_tokens.weight.data
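
# Demo setup. The snapshot path below is specific to one local Hugging Face cache;
# point it at wherever google/gemma-3n-e4b-it is downloaded on your machine. The
# processor is only used for chat templating, tokenization, and detokenization.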
path = os.path.expanduser(
    "~/.cache/huggingface/hub/models--google--gemma-3n-e4b-it/snapshots/17962761f8fbb66a0c7b8e9b4978b91909855a05"
)
model = Gemma3nForCausalLM()
load_model(model, path)
model.eval()
processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it", padding_side="left")

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Answer to the Ultimate Question of Life, the Universe, and Everything is",
            },
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
)
input_ids = inputs["input_ids"]
print(processor.decode(input_ids[0].tolist()), end="", flush=True)
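
# Greedy decoding: prefill on the full prompt once, then feed one token at a time,
# carrying the per-layer KV cache and the absolute position (cache_position) by hand.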
max_new_tokens = 30
past_key_values = None
cache_position = None
with torch.no_grad():
    outputs, past_key_values = model(input_ids)
    next_logits = outputs[:, -1, :]
    next_id = torch.argmax(next_logits, dim=-1).unsqueeze(-1)
    print(processor.decode([next_id.item()]), end="", flush=True)
    cache_position = torch.tensor([input_ids.shape[1]])
    for _ in range(max_new_tokens - 1):
        outputs, past_key_values = model(next_id, past_key_values, cache_position)
        next_id = torch.argmax(outputs[:, -1, :], dim=-1).unsqueeze(-1)
        cache_position += 1
        print(processor.decode([next_id.item()]), end="", flush=True)
        if next_id.item() == processor.tokenizer.eos_token_id:
            break
print("\n")