@scturtle
Last active June 23, 2025 11:37
qwen3
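
# Minimal single-file Qwen3 inference: model definition, safetensors weight
# loading, top-k/top-p sampling, and an autoregressive decode loop with a
# simple per-layer KV cache.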
import os
from functools import lru_cache
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen3Config
from transformers import Qwen2TokenizerFast
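
# Rotate the two halves of each head's channels by the per-position angles
# (the "rotate-half" formulation of rotary position embeddings).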
def apply_rotary_emb(x, cos, sin):
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)
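
# Precomputes cos/sin for every position up to max_position_embeddings and
# applies them to q/k by indexing the cache with the token positions.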
class RotaryEmbedding(nn.Module):
    def __init__(self, head_size, rotary_dim, max_position_embeddings, base):
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
        )
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(self, positions, query, key):
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query = apply_rotary_emb(query, cos, sin).view(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key = apply_rotary_emb(key, cos, sin).view(key_shape)
        return query, key
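
# Cache the rotary module so all attention layers share one cos/sin table.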
@lru_cache(1)
def get_rope(head_size, rotary_dim, max_position, base):
    return RotaryEmbedding(head_size, rotary_dim, max_position, base)
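
# RMSNorm computed in float32; when a residual is passed in, the residual add
# is fused here and the pre-norm sum is returned alongside the normed output.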
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.bfloat16))

    def forward(self, x, residual=None):
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x.add_(residual.to(torch.float32))
            residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x if residual is None else (x, residual)
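
# Thin wrapper around scaled_dot_product_attention with a naive KV cache:
# new k/v are concatenated onto the cached ones at every decode step.
# Note: the enable_gqa flag needs a recent PyTorch build.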
class Attention(nn.Module):
    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = None
        self.v_cache = None

    def forward(self, q, k, v):
        b = q.shape[0]
        q = q.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)
        self.k_cache = k
        self.v_cache = v
        o = F.scaled_dot_product_attention(
            q, k, v, is_causal=q.size(2) > 1, scale=self.scale, enable_gqa=True
        )
        return o.transpose(1, 2).reshape(b, -1, self.num_heads * self.head_dim)
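
# Qwen3 attention block: fused QKV projection, per-head RMSNorm on q and k
# (QK-norm), rotary embeddings, then the cached SDPA wrapper above.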
class Qwen3Attention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        max_position,
        head_dim,
        rms_norm_eps,
        qkv_bias,
        rope_theta,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim or hidden_size // num_heads
        self.q_size = num_heads * self.head_dim
        self.kv_size = num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = nn.Linear(
            hidden_size,
            self.q_size + 2 * self.kv_size,
            bias=qkv_bias,
            dtype=torch.bfloat16,
        )
        self.o_proj = nn.Linear(
            num_heads * self.head_dim, hidden_size, bias=False, dtype=torch.bfloat16
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
        )
        self.attn = Attention(
            num_heads, self.head_dim, self.scaling, num_kv_heads=num_kv_heads
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(self, positions, hidden_states):
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q_by_head = q.view(-1, self.num_heads, self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        return self.o_proj(o)
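
# SwiGLU MLP with the gate and up projections fused into one linear layer.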
class Qwen3MLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_up_proj = nn.Linear(
            hidden_size, intermediate_size * 2, bias=False, dtype=torch.bfloat16
        )
        self.down_proj = nn.Linear(
            intermediate_size, hidden_size, bias=False, dtype=torch.bfloat16
        )

    def forward(self, x):
        x, y = self.gate_up_proj(x).chunk(2, -1)
        return self.down_proj(F.silu(x).mul_(y))
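
# One transformer block: pre-norm attention and MLP, with the residual stream
# threaded through the fused RMSNorm defined above.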
class Qwen3DecoderLayer(nn.Module):
    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            qkv_bias=config.attention_bias,
            rms_norm_eps=config.rms_norm_eps,
            head_dim=config.head_dim,
            rope_theta=config.rope_theta,
        )
        assert config.hidden_act == "silu"
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(self, positions, hidden_states, residual):
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual
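
# Full causal LM: token embeddings, decoder stack, final RMSNorm, and LM head
# (tied to the embeddings when the config requests it). packed_modules_mapping
# tells load_model how per-projection checkpoint weights map onto the fused
# qkv_proj / gate_up_proj parameters.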
class Qwen3ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, dtype=torch.bfloat16
        )
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False, dtype=torch.bfloat16
        )
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.embed_tokens.weight.data

    def forward(self, input_ids, positions):
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return self.lm_head(hidden_states)
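
# Copy safetensors weights into the model, slicing q/k/v and gate/up shards
# into the rows of the fused projection matrices; everything else is copied
# one-to-one.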
def load_model(model: nn.Module, path: str, config: Qwen3Config):
    from glob import glob
    from safetensors import safe_open

    name_map = model.packed_modules_mapping
    q_size = config.head_dim * config.num_attention_heads
    kv_size = config.head_dim * config.num_key_value_heads
    qkv_slices = {
        "q": slice(0, q_size),
        "k": slice(q_size, q_size + kv_size),
        "v": slice(q_size + kv_size, None),
    }
    mlp_slices = {
        0: slice(0, config.intermediate_size),
        1: slice(config.intermediate_size, None),
    }
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_tensor = f.get_tensor(weight_name)
                weight_name = weight_name.replace("model.", "")
                for part_name, (packed_name, shard_id) in name_map.items():
                    if part_name in weight_name:
                        param_name = weight_name.replace(part_name, packed_name)
                        param = model.get_parameter(param_name)
                        slice_map = qkv_slices if "qkv" in packed_name else mlp_slices
                        param.data[slice_map[shard_id], :] = loaded_tensor
                        break
                else:
                    param = model.get_parameter(weight_name)
                    param.data.copy_(loaded_tensor)
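
# Temperature / top-k / top-p (nucleus) sampling over the last-token logits.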
def sample(logits, top_k, top_p, temperature):
    logits = logits.float()
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        top_k_values, _ = torch.topk(logits, top_k)
        kth_value = top_k_values[:, -1].unsqueeze(-1)
        indices_to_remove = logits < kth_value
        logits[indices_to_remove] = -float("Inf")
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = -float("Inf")
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token
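
# Load tokenizer, config, and weights from a local Hugging Face snapshot of
# Qwen3-0.6B, build the prompt with the chat template, and pick the sampling
# parameters recommended for thinking vs. non-thinking mode.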
path = os.path.expanduser(
    "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/e6de91484c29aa9480d55605af694f39b081c455/"
)
tokenizer = Qwen2TokenizerFast.from_pretrained(path)
config = Qwen3Config.from_pretrained(path)
# print(config)
model = Qwen3ForCausalLM(config)
load_model(model, path, config)

enable_thinking = False
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "list all prime numbers within 100"}],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=enable_thinking,
)

# device = torch.device("mps")
device = torch.device("cpu")
model.to(device)
model.eval()

input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
generated_ids = input_ids
print(tokenizer.decode(input_ids[0].tolist()), end="", flush=True)

max_new_tokens = 256
if enable_thinking:
    temperature = 0.6
    top_p = 0.95
else:
    temperature = 0.7
    top_p = 0.8
top_k = 20
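
# Autoregressive decoding: prefill on the full prompt at step 0, then feed only
# the newest token and its position; the per-layer KV cache inside Attention
# holds the rest of the context.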
with torch.no_grad():
    for i in range(max_new_tokens):
        seq_len = generated_ids.shape[1]
        positions = torch.arange(
            0 if i == 0 else seq_len - 1, seq_len, device=device
        ).unsqueeze(0)
        current_input_ids = generated_ids if i == 0 else generated_ids[:, -1:]
        logits = model(current_input_ids, positions)[:, -1, :]
        next_token = sample(logits, top_k, top_p, temperature)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
        print(tokenizer.decode(next_token[0].tolist()), end="", flush=True)
        if next_token.item() == tokenizer.eos_token_id:
            break