@pszemraj
Last active May 8, 2025 04:01
PyTorch implementation of WaveNet: a tiny, pretraining-free (directly fine-tuned), attention-free model for text classification
"""
WaveNet: An Ultra-Small Language Model (PyTorch Implementation)
Based on the paper: https://arxiv.org/abs/2411.02674
Hugging Face Transformers compatible implementation.
"""
import math
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint  # explicit import for the gradient-checkpointing path below
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN # Use HF activations
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
# --- Configuration ---
class WaveNetConfig(PretrainedConfig):
"""
Configuration class for WaveNet models. Inherits from `PretrainedConfig`.
Args:
vocab_size (`int`, *optional*, defaults to 30522): Vocabulary size.
hidden_size (`int`, *optional*, defaults to 768): Dimension of embeddings and hidden states.
max_position_embeddings (`int`, *optional*, defaults to 512): Maximum sequence length.
initializer_range (`float`, *optional*, defaults to 0.02): Standard deviation for weight initialization.
layer_norm_eps (`float`, *optional*, defaults to 1e-12): Epsilon for LayerNorm.
dropout_prob (`float`, *optional*, defaults to 0.1): Dropout probability for embeddings and FFN.
operation_type (`str`, *optional*, defaults to "modulation"): Wave superposition type ("interference" or "modulation").
ffn_intermediate_factor (`int`, *optional*, defaults to 4): Factor to determine FFN intermediate size (intermediate = hidden * factor).
classifier_dropout (`float`, *optional*, defaults to None): Dropout for the classification head. If None, uses `dropout_prob`.
# Other standard PretrainedConfig args (num_labels, pad_token_id, return_dict, output_hidden_states, etc.)
# are handled via **kwargs by the PretrainedConfig base class.
"""
model_type = "wavenet"
def __init__(
self,
vocab_size: int = 30522,
hidden_size: int = 768,
max_position_embeddings: int = 512,
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
dropout_prob: float = 0.1,
operation_type: str = "modulation",
ffn_intermediate_factor: int = 4,
classifier_dropout: Optional[float] = None,
**kwargs,
):
# Pass standard args (like num_labels, pad_token_id, return_dict, output_*) to base
# PretrainedConfig will handle them.
super().__init__(**kwargs)
# Set model-specific attributes
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.dropout_prob = dropout_prob
self.operation_type = operation_type
self.ffn_intermediate_size = hidden_size * ffn_intermediate_factor
self.classifier_dropout = (
classifier_dropout if classifier_dropout is not None else dropout_prob
)
if self.operation_type not in ["interference", "modulation"]:
raise ValueError("operation_type must be 'interference' or 'modulation'")
# Attributes like self.num_labels, self.pad_token_id, self.return_dict,
# self.output_hidden_states, self.output_attentions are set by PretrainedConfig
# based on what's in kwargs. For example, if 'num_labels' is in kwargs,
# PretrainedConfig will use it, otherwise it will use its own default (e.g., 2).
# The same applies to pad_token_id (PretrainedConfig default is None).
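# Configuration usage sketch (added for illustration; not part of the original gist,
# and the values below are arbitrary). Standard PretrainedConfig kwargs such as
# num_labels or pad_token_id pass straight through **kwargs, while the FFN
# intermediate size is derived as hidden_size * ffn_intermediate_factor:
#
#   cfg = WaveNetConfig(vocab_size=1000, hidden_size=64, num_labels=2, pad_token_id=0)
#   assert cfg.ffn_intermediate_size == 256  # 64 * 4 (default factor)
#   assert cfg.operation_type == "modulation"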
# --- Helper: Complex Vector Encoder ---
class ComplexVectorEncoder(nn.Module):
"""Encodes real-valued token embeddings into complex vector representations."""
def __init__(self, config: WaveNetConfig):
super().__init__()
self.config = config
self.hidden_size: int = config.hidden_size
def forward(
self,
real_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
real_embeddings: Shape `(batch_size, seq_len, hidden_size)`
attention_mask: Shape `(batch_size, seq_len)`, 1 for valid, 0 for pad.
Returns:
complex_vectors: Shape `(batch_size, seq_len, hidden_size)`, dtype=complex
"""
# 1. Global Semantics (Magnitude G)
if attention_mask is not None:
# Ensure attention_mask is float for multiplication if embeddings are float
expanded_mask = (
attention_mask.unsqueeze(-1)
.expand_as(real_embeddings)
.to(real_embeddings.dtype)
)
masked_embeddings = real_embeddings * expanded_mask
else:
masked_embeddings = real_embeddings
# Sum of squares over the sequence length dimension for each hidden dimension
# masked_embeddings is (batch_size, seq_len, hidden_size)
# We want G_k = sqrt(sum_over_tokens(w_token,k^2))
# So sum should be over dim=1 (seq_len)
sum_sq_embeddings_per_dim = torch.sum(
masked_embeddings**2, dim=1
) # (batch_size, hidden_size)
G = torch.sqrt(
sum_sq_embeddings_per_dim.clamp(min=1e-12)
) # (batch_size, hidden_size)
G_broadcast = G.unsqueeze(
1
) # (batch_size, 1, hidden_size), for broadcasting over seq_len
G_broadcast_clamped = G_broadcast.clamp(min=1e-6) # Avoid division by zero
# 2. Local Semantics (Phase alpha)
# real_embeddings is (batch_size, seq_len, hidden_size)
# G_broadcast_clamped is (batch_size, 1, hidden_size)
cos_val = real_embeddings / G_broadcast_clamped # w_j,k / G_k
cos_val = torch.clamp(
cos_val, -1.0 + 1e-7, 1.0 - 1e-7
) # Ensure valid input for acos/sqrt
# sin_val = -sqrt(1 - cos_val^2) as per paper's atan2(y,x) where y is the -sqrt term
sin_val_sq = (1 - cos_val**2).clamp(
min=0.0
) # Ensure non-negative before sqrt
y_term_for_atan2 = -torch.sqrt(sin_val_sq)
x_term_for_atan2 = cos_val
alpha = torch.atan2(
y_term_for_atan2, x_term_for_atan2
) # (batch_size, seq_len, hidden_size)
# 3. Form Complex Vector Z = G * e^(i*alpha)
# G_broadcast is (batch_size, 1, hidden_size)
# alpha is (batch_size, seq_len, hidden_size)
# Result should be (batch_size, seq_len, hidden_size)
real_part = G_broadcast * torch.cos(alpha)
imag_part = G_broadcast * torch.sin(alpha)
complex_vectors = torch.complex(real_part, imag_part)
return complex_vectors
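# Worked example (added for clarity; follows the formulas implemented above, with
# hand-picked numbers). For a single hidden dimension k the encoder computes
#   G_k         = sqrt( sum_j w_{j,k}^2 )                              (global magnitude)
#   alpha_{j,k} = atan2( -sqrt(1 - (w_{j,k}/G_k)^2), w_{j,k}/G_k )     (per-token phase)
#   Z_{j,k}     = G_k * exp(i * alpha_{j,k})
# e.g. with two tokens whose values along dimension k are w = [3.0, 4.0]:
#   G_k = 5.0, cos(alpha_1) = 0.6, cos(alpha_2) = 0.8, both phases lie in [-pi, 0],
#   so Z_1 = 5 * (0.6 - 0.8i) = 3 - 4i and Z_2 = 5 * (0.8 - 0.6i) = 4 - 3i.
# Note that the real part recovers the original embedding value, while the imaginary
# part carries -sqrt(G_k^2 - w_{j,k}^2).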
# --- WaveNet Embeddings (Token + Positional) ---
class WaveNetEmbeddings(nn.Module):
"""Constructs token and learned positional embeddings."""
def __init__(self, config: WaveNetConfig):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
)
# For Hugging Face compatibility, to know the max length
self.max_position_embeddings = config.max_position_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
# token_type_ids is not used by WaveNet but might be passed by HF pipelines
token_type_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length = input_shape[1]
if position_ids is None:
# Create position_ids on the fly
# Ensure it's on the same device as input_ids or inputs_embeds
device = input_ids.device if input_ids is not None else inputs_embeds.device
right_padded_idx = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = right_padded_idx.unsqueeze(0).expand(input_shape)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = inputs_embeds + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
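# Shape sketch (added; assumes a toy config like the one in the earlier sketch, with
# pad_token_id set so nn.Embedding gets a valid padding_idx):
#
#   cfg = WaveNetConfig(vocab_size=1000, hidden_size=64, pad_token_id=0)
#   emb = WaveNetEmbeddings(cfg)
#   out = emb(input_ids=torch.zeros(2, 10, dtype=torch.long))
#   out.shape  # -> torch.Size([2, 10, 64]): LayerNorm(word_emb + learned pos_emb) + dropout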
# --- WaveNet Core Layer (Single Layer) ---
class WaveNetSingleLayer(nn.Module):
"""Implements the core WaveNet layer logic."""
def __init__(self, config: WaveNetConfig):
super().__init__()
self.config = config
self.hidden_size: int = config.hidden_size
self.variant_project1 = nn.Linear(config.hidden_size, config.hidden_size)
self.variant_project2 = nn.Linear(config.hidden_size, config.hidden_size)
self.complex_encoder = ComplexVectorEncoder(config)
self.ffn_dense1 = nn.Linear(
2 * config.hidden_size,
config.ffn_intermediate_size, # Input is concat(real, imag)
)
self.ffn_activation = ACT2FN[
self.config.hidden_act if hasattr(self.config, "hidden_act") else "gelu"
]
self.ffn_dense2 = nn.Linear(config.ffn_intermediate_size, config.hidden_size)
self.output_dropout = nn.Dropout(config.dropout_prob)
self.output_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def forward(
self,
initial_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, # This is the padding mask
) -> torch.Tensor:
"""Processes input embeddings through the WaveNet layer."""
# Project to get inputs for two variants
projected_embeddings1 = self.variant_project1(initial_embeddings)
projected_embeddings2 = self.variant_project2(initial_embeddings)
# Generate complex vectors
# The attention_mask here is the padding mask, used by ComplexVectorEncoder
# to correctly calculate global semantics G
Z1 = self.complex_encoder(projected_embeddings1, attention_mask=attention_mask)
Z2 = self.complex_encoder(projected_embeddings2, attention_mask=attention_mask)
# Wave Superposition
if self.config.operation_type == "interference":
Z_combined = Z1 + Z2
elif self.config.operation_type == "modulation":
Z_combined = Z1 * Z2
else:
raise ValueError(f"Unknown operation_type: {self.config.operation_type}")
# Feed Forward Network on real and imaginary parts
ffn_input = torch.cat((Z_combined.real, Z_combined.imag), dim=-1)
hidden_states_ffn = self.ffn_dense1(ffn_input)
hidden_states_ffn = self.ffn_activation(hidden_states_ffn)
hidden_states_ffn = self.ffn_dense2(hidden_states_ffn)
# Apply dropout, residual connection, then LayerNorm (Post-LN style)
hidden_states = initial_embeddings + self.output_dropout(hidden_states_ffn)
processed_representation = self.output_layernorm(hidden_states)
return processed_representation
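# Superposition sketch (added for clarity; scalar complex numbers instead of full
# tensors). With z1 = r1*e^{i*a1} and z2 = r2*e^{i*a2}:
#   interference (z1 + z2): real and imaginary parts add component-wise
#   modulation   (z1 * z2): magnitudes multiply (r1*r2) and phases add (a1 + a2)
# e.g.
#   z1 = 1 + 1j   # sqrt(2) * e^{i*pi/4}
#   z2 = 2j       # 2 * e^{i*pi/2}
#   z1 + z2       # -> 1 + 3j
#   z1 * z2       # -> -2 + 2j, i.e. 2*sqrt(2) * e^{i*3*pi/4}
# The layer then feeds cat(real, imag) of the combined wave through the FFN and adds
# a residual connection back to the original (real-valued) embeddings.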
# --- Base Model Class ---
class WaveNetPreTrainedModel(PreTrainedModel):
config_class = WaveNetConfig
base_model_prefix = "wavenet"
supports_gradient_checkpointing = True
_no_split_modules = ["WaveNetEmbeddings", "WaveNetSingleLayer"]
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, WaveNetModel):
module.gradient_checkpointing = value
# --- Core WaveNet Model ---
class WaveNetModel(WaveNetPreTrainedModel):
def __init__(self, config: WaveNetConfig):
super().__init__(config)
self.config = config
self.embeddings = WaveNetEmbeddings(config)
self.wave_layer = WaveNetSingleLayer(config)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings.word_embeddings
def set_input_embeddings(self, value: nn.Embedding):
self.embeddings.word_embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPooling]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if input_ids is not None:
device = input_ids.device
batch_size, seq_length = input_ids.shape
else:
device = inputs_embeds.device
batch_size, seq_length = inputs_embeds.shape[:2]
if attention_mask is None:
attention_mask = torch.ones(
((batch_size, seq_length)), device=device, dtype=torch.long
)
        # WaveNetEmbeddings accepts token_type_ids for HF compatibility, even though it does not use them
initial_embeddings = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids,
)
hidden_states_current_layer = initial_embeddings
all_hidden_states = (initial_embeddings,) if output_hidden_states else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # attention_mask is the second input
return custom_forward
hidden_states_current_layer = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.wave_layer),
hidden_states_current_layer,
attention_mask, # Pass the padding mask
                use_reentrant=getattr(
                    self.config, "use_reentrant", False
                ),  # PyTorch default is True; HF often sets False
)
else:
hidden_states_current_layer = self.wave_layer(
hidden_states_current_layer, attention_mask=attention_mask
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states_current_layer,)
pooled_output = hidden_states_current_layer[:, 0] # CLS token pooling
if not return_dict:
outputs = (hidden_states_current_layer, pooled_output)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
return outputs
return BaseModelOutputWithPooling(
last_hidden_state=hidden_states_current_layer,
pooler_output=pooled_output,
hidden_states=all_hidden_states,
attentions=None, # WaveNet does not use attention
)
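# Usage sketch for the core model (added; toy sizes, not from the original gist).
# The pooler_output is simply the hidden state at position 0 (CLS-style pooling):
#
#   cfg = WaveNetConfig(vocab_size=1000, hidden_size=64, pad_token_id=0)
#   core = WaveNetModel(cfg).eval()
#   out = core(input_ids=torch.randint(0, cfg.vocab_size, (2, 10)), return_dict=True)
#   out.last_hidden_state.shape  # -> torch.Size([2, 10, 64])
#   out.pooler_output.shape      # -> torch.Size([2, 64])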
# --- Classification Head Model ---
class WaveNetForSequenceClassification(WaveNetPreTrainedModel):
def __init__(self, config: WaveNetConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.wavenet = WaveNetModel(config)
classifier_dropout = (
config.classifier_dropout
if config.classifier_dropout is not None
else config.dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.wavenet(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True, # Force return_dict for easier access to pooler_output
)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss: Optional[torch.Tensor] = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze().float())
else:
loss = loss_fct(logits, labels.float())
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels.float())
        if not return_dict:
            # Mirror the standard HF tuple layout: (loss?, logits, hidden_states?)
            output = (logits,)
            if output_hidden_states:
                output = output + (outputs.hidden_states,)
            return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
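# Fine-tuning sketch (added; the tokenizer/dataset names below are placeholders, not
# part of this gist). Because the class follows the standard
# *ForSequenceClassification interface, it can be trained directly with
# transformers.Trainer on any dataset that yields input_ids / attention_mask / labels:
#
#   from transformers import Trainer, TrainingArguments
#   cfg = WaveNetConfig(vocab_size=tokenizer.vocab_size, num_labels=2, pad_token_id=tokenizer.pad_token_id)
#   model = WaveNetForSequenceClassification(cfg)
#   trainer = Trainer(
#       model=model,
#       args=TrainingArguments(output_dir="wavenet-clf", num_train_epochs=3),
#       train_dataset=tokenized_train,  # assumed to exist
#       eval_dataset=tokenized_eval,    # assumed to exist
#   )
#   trainer.train()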
if __name__ == "__main__":
config_params = {
"vocab_size": 1000,
"hidden_size": 64,
"max_position_embeddings": 128,
"num_labels": 3,
"operation_type": "modulation",
"dropout_prob": 0.1,
"output_hidden_states": True,
"pad_token_id": 0, # Explicitly set pad_token_id
"return_dict": True, # Changed from use_return_dict
# "hidden_act": "gelu" # Could be added if WaveNetSingleLayer needs it from config
}
config = WaveNetConfig(**config_params)
print("--- Testing WaveNetForSequenceClassification (Modulation) ---")
model = WaveNetForSequenceClassification(config)
model.eval()
batch_size = 2
seq_len = 10
dummy_input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), dtype=torch.long
)
dummy_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
dummy_attention_mask[0, 5:] = 0 # Pad last 5 tokens of first example
dummy_labels = torch.randint(0, config.num_labels, (batch_size,), dtype=torch.long)
print(f"Input IDs shape: {dummy_input_ids.shape}")
print(f"Attention Mask: \n{dummy_attention_mask}")
with torch.no_grad():
outputs_dict = model(
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
labels=dummy_labels,
return_dict=True, # Test with explicit return_dict=True
)
print(f"\nLogits shape: {outputs_dict.logits.shape}")
print(f"Loss: {outputs_dict.loss}")
if outputs_dict.hidden_states is not None:
print(f"Number of hidden states returned: {len(outputs_dict.hidden_states)}")
print(
f"Shape of initial embedding (hidden_states[0]): {outputs_dict.hidden_states[0].shape}"
)
print(
f"Shape of last hidden state (hidden_states[-1]): {outputs_dict.hidden_states[-1].shape}"
)
else:
print("No hidden states returned.")
print("\n--- Testing Model Saving and Loading ---")
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
loaded_model = WaveNetForSequenceClassification.from_pretrained(tmpdirname)
loaded_model.eval()
# Ensure model is on the same device as inputs for comparison
if dummy_input_ids.device != loaded_model.device:
loaded_model.to(dummy_input_ids.device)
with torch.no_grad():
outputs_loaded = loaded_model(
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
labels=dummy_labels, # Pass labels to ensure loss is comparable if needed
return_dict=True,
)
if torch.allclose(outputs_dict.logits, outputs_loaded.logits, atol=1e-5):
print("Model saving and loading test passed (logits match).")
else:
print("Model saving and loading test FAILED (logits differ).")
print(f"Original logits: {outputs_dict.logits}")
print(f"Loaded logits: {outputs_loaded.logits}")
if outputs_dict.loss is not None and outputs_loaded.loss is not None:
if torch.allclose(outputs_dict.loss, outputs_loaded.loss, atol=1e-5):
print("Model saving and loading test passed (loss match).")
else:
print("Model saving and loading test FAILED (loss differ).")
print(f"Original loss: {outputs_dict.loss}")
print(f"Loaded loss: {outputs_loaded.loss}")
print("\n--- Testing Core WaveNetModel ---")
core_model = WaveNetModel(config)
core_model.eval()
with torch.no_grad():
core_outputs = core_model(
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
return_dict=True,
output_hidden_states=True, # Ensure hidden states are requested
)
print(f"Core model last_hidden_state shape: {core_outputs.last_hidden_state.shape}")
print(f"Core model pooler_output shape: {core_outputs.pooler_output.shape}")
if core_outputs.hidden_states is not None:
print(f"Core model num hidden_states: {len(core_outputs.hidden_states)}")
print(
f"Core model initial embedding shape: {core_outputs.hidden_states[0].shape}"
)
print(
f"Core model final hidden_state shape: {core_outputs.hidden_states[-1].shape}"
)
else:
print("Core model no hidden_states returned.")
print("\n--- Testing with return_dict=False for SequenceClassification ---")
with torch.no_grad():
outputs_tuple = model(
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
labels=dummy_labels,
return_dict=False,
output_hidden_states=True,
)
# Expected tuple: (loss, logits, hidden_states)
print(f"Output tuple type: {type(outputs_tuple)}")
print(f"Length of tuple: {len(outputs_tuple)}")
print(f"Loss (from tuple): {outputs_tuple[0]}")
print(f"Logits shape (from tuple): {outputs_tuple[1].shape}")
if len(outputs_tuple) > 2 and outputs_tuple[2] is not None:
print(f"Number of hidden_states (from tuple): {len(outputs_tuple[2])}")
pszemraj commented May 8, 2025

I tried having an LLM add docstrings, but of course it screwed up the impl itself in doing so.
