PyTorch implementation of a pretraining-free (directly fine-tunable) WaveNet, a tiny transformer-style model for classification
""" | |
WaveNet: An Ultra-Small Language Model (PyTorch Implementation) | |
Based on the paper: https://arxiv.org/abs/2411.02674 | |
Hugging Face Transformers compatible implementation. | |
""" | |
import math | |
from typing import Dict, Optional, Tuple, Union | |
import torch
import torch.nn as nn
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
from transformers.activations import ACT2FN # Use HF activations | |
from transformers.configuration_utils import PretrainedConfig | |
from transformers.modeling_outputs import ( | |
BaseModelOutputWithPooling, | |
SequenceClassifierOutput, | |
) | |
from transformers.modeling_utils import PreTrainedModel | |
from transformers.utils import logging | |
logger = logging.get_logger(__name__) | |
# --- Configuration --- | |
class WaveNetConfig(PretrainedConfig): | |
""" | |
Configuration class for WaveNet models. Inherits from `PretrainedConfig`. | |
Args: | |
vocab_size (`int`, *optional*, defaults to 30522): Vocabulary size. | |
hidden_size (`int`, *optional*, defaults to 768): Dimension of embeddings and hidden states. | |
max_position_embeddings (`int`, *optional*, defaults to 512): Maximum sequence length. | |
initializer_range (`float`, *optional*, defaults to 0.02): Standard deviation for weight initialization. | |
layer_norm_eps (`float`, *optional*, defaults to 1e-12): Epsilon for LayerNorm. | |
dropout_prob (`float`, *optional*, defaults to 0.1): Dropout probability for embeddings and FFN. | |
operation_type (`str`, *optional*, defaults to "modulation"): Wave superposition type ("interference" or "modulation"). | |
ffn_intermediate_factor (`int`, *optional*, defaults to 4): Factor to determine FFN intermediate size (intermediate = hidden * factor). | |
classifier_dropout (`float`, *optional*, defaults to None): Dropout for the classification head. If None, uses `dropout_prob`. | |
# Other standard PretrainedConfig args (num_labels, pad_token_id, return_dict, output_hidden_states, etc.) | |
# are handled via **kwargs by the PretrainedConfig base class. | |
""" | |
model_type = "wavenet" | |
def __init__( | |
self, | |
vocab_size: int = 30522, | |
hidden_size: int = 768, | |
max_position_embeddings: int = 512, | |
initializer_range: float = 0.02, | |
layer_norm_eps: float = 1e-12, | |
dropout_prob: float = 0.1, | |
operation_type: str = "modulation", | |
ffn_intermediate_factor: int = 4, | |
classifier_dropout: Optional[float] = None, | |
**kwargs, | |
): | |
# Pass standard args (like num_labels, pad_token_id, return_dict, output_*) to base | |
# PretrainedConfig will handle them. | |
super().__init__(**kwargs) | |
# Set model-specific attributes | |
self.vocab_size = vocab_size | |
self.hidden_size = hidden_size | |
self.max_position_embeddings = max_position_embeddings | |
self.initializer_range = initializer_range | |
self.layer_norm_eps = layer_norm_eps | |
self.dropout_prob = dropout_prob | |
self.operation_type = operation_type | |
self.ffn_intermediate_size = hidden_size * ffn_intermediate_factor | |
self.classifier_dropout = ( | |
classifier_dropout if classifier_dropout is not None else dropout_prob | |
) | |
if self.operation_type not in ["interference", "modulation"]: | |
raise ValueError("operation_type must be 'interference' or 'modulation'") | |
# Attributes like self.num_labels, self.pad_token_id, self.return_dict, | |
# self.output_hidden_states, self.output_attentions are set by PretrainedConfig | |
# based on what's in kwargs. For example, if 'num_labels' is in kwargs, | |
# PretrainedConfig will use it, otherwise it will use its own default (e.g., 2). | |
# The same applies to pad_token_id (PretrainedConfig default is None). | |
# --- Helper: Complex Vector Encoder --- | |
class ComplexVectorEncoder(nn.Module): | |
"""Encodes real-valued token embeddings into complex vector representations.""" | |
def __init__(self, config: WaveNetConfig): | |
super().__init__() | |
self.config = config | |
self.hidden_size: int = config.hidden_size | |
def forward( | |
self, | |
real_embeddings: torch.Tensor, | |
attention_mask: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
""" | |
Args: | |
real_embeddings: Shape `(batch_size, seq_len, hidden_size)` | |
attention_mask: Shape `(batch_size, seq_len)`, 1 for valid, 0 for pad. | |
Returns: | |
complex_vectors: Shape `(batch_size, seq_len, hidden_size)`, dtype=complex | |
""" | |
# 1. Global Semantics (Magnitude G) | |
if attention_mask is not None: | |
# Ensure attention_mask is float for multiplication if embeddings are float | |
expanded_mask = ( | |
attention_mask.unsqueeze(-1) | |
.expand_as(real_embeddings) | |
.to(real_embeddings.dtype) | |
) | |
masked_embeddings = real_embeddings * expanded_mask | |
else: | |
masked_embeddings = real_embeddings | |
# Sum of squares over the sequence length dimension for each hidden dimension | |
# masked_embeddings is (batch_size, seq_len, hidden_size) | |
# We want G_k = sqrt(sum_over_tokens(w_token,k^2)) | |
# So sum should be over dim=1 (seq_len) | |
sum_sq_embeddings_per_dim = torch.sum( | |
masked_embeddings**2, dim=1 | |
) # (batch_size, hidden_size) | |
G = torch.sqrt( | |
sum_sq_embeddings_per_dim.clamp(min=1e-12) | |
) # (batch_size, hidden_size) | |
G_broadcast = G.unsqueeze( | |
1 | |
) # (batch_size, 1, hidden_size), for broadcasting over seq_len | |
G_broadcast_clamped = G_broadcast.clamp(min=1e-6) # Avoid division by zero | |
# 2. Local Semantics (Phase alpha) | |
# real_embeddings is (batch_size, seq_len, hidden_size) | |
# G_broadcast_clamped is (batch_size, 1, hidden_size) | |
cos_val = real_embeddings / G_broadcast_clamped # w_j,k / G_k | |
cos_val = torch.clamp( | |
cos_val, -1.0 + 1e-7, 1.0 - 1e-7 | |
) # Ensure valid input for acos/sqrt | |
# sin_val = -sqrt(1 - cos_val^2) as per paper's atan2(y,x) where y is the -sqrt term | |
sin_val_sq = (1 - cos_val**2).clamp( | |
min=0.0 | |
) # Ensure non-negative before sqrt | |
y_term_for_atan2 = -torch.sqrt(sin_val_sq) | |
x_term_for_atan2 = cos_val | |
alpha = torch.atan2( | |
y_term_for_atan2, x_term_for_atan2 | |
) # (batch_size, seq_len, hidden_size) | |
# 3. Form Complex Vector Z = G * e^(i*alpha) | |
# G_broadcast is (batch_size, 1, hidden_size) | |
# alpha is (batch_size, seq_len, hidden_size) | |
# Result should be (batch_size, seq_len, hidden_size) | |
real_part = G_broadcast * torch.cos(alpha) | |
imag_part = G_broadcast * torch.sin(alpha) | |
complex_vectors = torch.complex(real_part, imag_part) | |
return complex_vectors | |
# --- WaveNet Embeddings (Token + Positional) --- | |
class WaveNetEmbeddings(nn.Module): | |
"""Constructs token and learned positional embeddings.""" | |
def __init__(self, config: WaveNetConfig): | |
super().__init__() | |
self.word_embeddings = nn.Embedding( | |
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id | |
) | |
self.position_embeddings = nn.Embedding( | |
config.max_position_embeddings, config.hidden_size | |
) | |
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |
self.dropout = nn.Dropout(config.dropout_prob) | |
self.register_buffer( | |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) | |
) | |
# For Hugging Face compatibility, to know the max length | |
self.max_position_embeddings = config.max_position_embeddings | |
def forward( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
inputs_embeds: Optional[torch.Tensor] = None, | |
# token_type_ids is not used by WaveNet but might be passed by HF pipelines | |
token_type_ids: Optional[torch.LongTensor] = None, | |
) -> torch.Tensor: | |
if input_ids is not None: | |
input_shape = input_ids.size() | |
elif inputs_embeds is not None: | |
input_shape = inputs_embeds.size()[:-1] | |
else: | |
raise ValueError("You have to specify either input_ids or inputs_embeds") | |
seq_length = input_shape[1] | |
if position_ids is None: | |
# Create position_ids on the fly | |
# Ensure it's on the same device as input_ids or inputs_embeds | |
device = input_ids.device if input_ids is not None else inputs_embeds.device | |
right_padded_idx = torch.arange(seq_length, dtype=torch.long, device=device) | |
position_ids = right_padded_idx.unsqueeze(0).expand(input_shape) | |
if inputs_embeds is None: | |
inputs_embeds = self.word_embeddings(input_ids) | |
position_embeddings = self.position_embeddings(position_ids) | |
embeddings = inputs_embeds + position_embeddings | |
embeddings = self.LayerNorm(embeddings) | |
embeddings = self.dropout(embeddings) | |
return embeddings | |
# --- WaveNet Core Layer (Single Layer) --- | |
class WaveNetSingleLayer(nn.Module): | |
"""Implements the core WaveNet layer logic.""" | |
def __init__(self, config: WaveNetConfig): | |
super().__init__() | |
self.config = config | |
self.hidden_size: int = config.hidden_size | |
self.variant_project1 = nn.Linear(config.hidden_size, config.hidden_size) | |
self.variant_project2 = nn.Linear(config.hidden_size, config.hidden_size) | |
self.complex_encoder = ComplexVectorEncoder(config) | |
self.ffn_dense1 = nn.Linear( | |
2 * config.hidden_size, | |
config.ffn_intermediate_size, # Input is concat(real, imag) | |
) | |
        self.ffn_activation = ACT2FN[getattr(config, "hidden_act", "gelu")]
self.ffn_dense2 = nn.Linear(config.ffn_intermediate_size, config.hidden_size) | |
self.output_dropout = nn.Dropout(config.dropout_prob) | |
self.output_layernorm = nn.LayerNorm( | |
config.hidden_size, eps=config.layer_norm_eps | |
) | |
def forward( | |
self, | |
initial_embeddings: torch.Tensor, | |
attention_mask: Optional[torch.Tensor] = None, # This is the padding mask | |
) -> torch.Tensor: | |
"""Processes input embeddings through the WaveNet layer.""" | |
# Project to get inputs for two variants | |
projected_embeddings1 = self.variant_project1(initial_embeddings) | |
projected_embeddings2 = self.variant_project2(initial_embeddings) | |
# Generate complex vectors | |
# The attention_mask here is the padding mask, used by ComplexVectorEncoder | |
# to correctly calculate global semantics G | |
Z1 = self.complex_encoder(projected_embeddings1, attention_mask=attention_mask) | |
Z2 = self.complex_encoder(projected_embeddings2, attention_mask=attention_mask) | |
# Wave Superposition | |
if self.config.operation_type == "interference": | |
Z_combined = Z1 + Z2 | |
elif self.config.operation_type == "modulation": | |
Z_combined = Z1 * Z2 | |
else: | |
raise ValueError(f"Unknown operation_type: {self.config.operation_type}") | |
# Feed Forward Network on real and imaginary parts | |
ffn_input = torch.cat((Z_combined.real, Z_combined.imag), dim=-1) | |
hidden_states_ffn = self.ffn_dense1(ffn_input) | |
hidden_states_ffn = self.ffn_activation(hidden_states_ffn) | |
hidden_states_ffn = self.ffn_dense2(hidden_states_ffn) | |
# Apply dropout, residual connection, then LayerNorm (Post-LN style) | |
hidden_states = initial_embeddings + self.output_dropout(hidden_states_ffn) | |
processed_representation = self.output_layernorm(hidden_states) | |
return processed_representation | |
# --- Base Model Class --- | |
class WaveNetPreTrainedModel(PreTrainedModel): | |
config_class = WaveNetConfig | |
base_model_prefix = "wavenet" | |
supports_gradient_checkpointing = True | |
_no_split_modules = ["WaveNetEmbeddings", "WaveNetSingleLayer"] | |
def _init_weights(self, module: nn.Module): | |
std = self.config.initializer_range | |
if isinstance(module, nn.Linear): | |
module.weight.data.normal_(mean=0.0, std=std) | |
if module.bias is not None: | |
module.bias.data.zero_() | |
elif isinstance(module, nn.Embedding): | |
module.weight.data.normal_(mean=0.0, std=std) | |
if module.padding_idx is not None: | |
module.weight.data[module.padding_idx].zero_() | |
elif isinstance(module, nn.LayerNorm): | |
module.bias.data.zero_() | |
module.weight.data.fill_(1.0) | |
def _set_gradient_checkpointing(self, module, value=False): | |
if isinstance(module, WaveNetModel): | |
module.gradient_checkpointing = value | |
# --- Core WaveNet Model --- | |
class WaveNetModel(WaveNetPreTrainedModel): | |
def __init__(self, config: WaveNetConfig): | |
super().__init__(config) | |
self.config = config | |
self.embeddings = WaveNetEmbeddings(config) | |
self.wave_layer = WaveNetSingleLayer(config) | |
self.gradient_checkpointing = False | |
self.post_init() | |
def get_input_embeddings(self) -> nn.Embedding: | |
return self.embeddings.word_embeddings | |
def set_input_embeddings(self, value: nn.Embedding): | |
self.embeddings.word_embeddings = value | |
def forward( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
token_type_ids: Optional[torch.LongTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
head_mask: Optional[torch.Tensor] = None, | |
inputs_embeds: Optional[torch.Tensor] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPooling]: | |
output_attentions = ( | |
output_attentions | |
if output_attentions is not None | |
else self.config.output_attentions | |
) | |
output_hidden_states = ( | |
output_hidden_states | |
if output_hidden_states is not None | |
else self.config.output_hidden_states | |
) | |
return_dict = ( | |
return_dict if return_dict is not None else self.config.use_return_dict | |
) | |
if input_ids is not None and inputs_embeds is not None: | |
raise ValueError( | |
"You cannot specify both input_ids and inputs_embeds at the same time" | |
) | |
elif input_ids is None and inputs_embeds is None: | |
raise ValueError("You have to specify either input_ids or inputs_embeds") | |
if input_ids is not None: | |
device = input_ids.device | |
batch_size, seq_length = input_ids.shape | |
else: | |
device = inputs_embeds.device | |
batch_size, seq_length = inputs_embeds.shape[:2] | |
if attention_mask is None: | |
attention_mask = torch.ones( | |
((batch_size, seq_length)), device=device, dtype=torch.long | |
) | |
# WaveNetEmbeddings expects token_type_ids, even if unused by its logic, for HF compatibility | |
initial_embeddings = self.embeddings( | |
input_ids=input_ids, | |
position_ids=position_ids, | |
inputs_embeds=inputs_embeds, | |
token_type_ids=token_type_ids, | |
) | |
hidden_states_current_layer = initial_embeddings | |
all_hidden_states = (initial_embeddings,) if output_hidden_states else None | |
if self.gradient_checkpointing and self.training: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
return module(*inputs) # attention_mask is the second input | |
return custom_forward | |
            hidden_states_current_layer = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.wave_layer),
                hidden_states_current_layer,
                attention_mask,  # Pass the padding mask
                # PretrainedConfig has no dict-style .get(); read the optional flag
                # with getattr. PyTorch defaults use_reentrant to True, but HF code
                # paths typically set it to False.
                use_reentrant=getattr(self.config, "use_reentrant", False),
            )
else: | |
hidden_states_current_layer = self.wave_layer( | |
hidden_states_current_layer, attention_mask=attention_mask | |
) | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states_current_layer,) | |
pooled_output = hidden_states_current_layer[:, 0] # CLS token pooling | |
if not return_dict: | |
outputs = (hidden_states_current_layer, pooled_output) | |
if output_hidden_states: | |
outputs = outputs + (all_hidden_states,) | |
return outputs | |
return BaseModelOutputWithPooling( | |
last_hidden_state=hidden_states_current_layer, | |
pooler_output=pooled_output, | |
hidden_states=all_hidden_states, | |
attentions=None, # WaveNet does not use attention | |
) | |
# --- Classification Head Model --- | |
class WaveNetForSequenceClassification(WaveNetPreTrainedModel): | |
def __init__(self, config: WaveNetConfig): | |
super().__init__(config) | |
self.num_labels = config.num_labels | |
self.config = config | |
self.wavenet = WaveNetModel(config) | |
classifier_dropout = ( | |
config.classifier_dropout | |
if config.classifier_dropout is not None | |
else config.dropout_prob | |
) | |
self.dropout = nn.Dropout(classifier_dropout) | |
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |
self.post_init() | |
def forward( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
token_type_ids: Optional[torch.LongTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
head_mask: Optional[torch.Tensor] = None, | |
inputs_embeds: Optional[torch.Tensor] = None, | |
labels: Optional[torch.LongTensor] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]: | |
return_dict = ( | |
return_dict if return_dict is not None else self.config.use_return_dict | |
) | |
outputs = self.wavenet( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
token_type_ids=token_type_ids, | |
position_ids=position_ids, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=True, # Force return_dict for easier access to pooler_output | |
) | |
pooled_output = outputs.pooler_output | |
pooled_output = self.dropout(pooled_output) | |
logits = self.classifier(pooled_output) | |
loss: Optional[torch.Tensor] = None | |
if labels is not None: | |
if self.config.problem_type is None: | |
if self.num_labels == 1: | |
self.config.problem_type = "regression" | |
elif self.num_labels > 1 and ( | |
labels.dtype == torch.long or labels.dtype == torch.int | |
): | |
self.config.problem_type = "single_label_classification" | |
else: | |
self.config.problem_type = "multi_label_classification" | |
if self.config.problem_type == "regression": | |
loss_fct = MSELoss() | |
if self.num_labels == 1: | |
loss = loss_fct(logits.squeeze(), labels.squeeze().float()) | |
else: | |
loss = loss_fct(logits, labels.float()) | |
elif self.config.problem_type == "single_label_classification": | |
loss_fct = CrossEntropyLoss() | |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |
elif self.config.problem_type == "multi_label_classification": | |
loss_fct = BCEWithLogitsLoss() | |
loss = loss_fct(logits, labels.float()) | |
        if not return_dict:
            # The core model was called with return_dict=True, so expose the
            # requested pieces as a plain tuple: (loss?, logits, hidden_states?).
            if output_hidden_states:
                output = (logits,) + (outputs.hidden_states,)
            else:
                output = (logits,)
            return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( | |
loss=loss, | |
logits=logits, | |
hidden_states=outputs.hidden_states, | |
attentions=outputs.attentions, | |
) | |
if __name__ == "__main__": | |
config_params = { | |
"vocab_size": 1000, | |
"hidden_size": 64, | |
"max_position_embeddings": 128, | |
"num_labels": 3, | |
"operation_type": "modulation", | |
"dropout_prob": 0.1, | |
"output_hidden_states": True, | |
"pad_token_id": 0, # Explicitly set pad_token_id | |
"return_dict": True, # Changed from use_return_dict | |
# "hidden_act": "gelu" # Could be added if WaveNetSingleLayer needs it from config | |
} | |
config = WaveNetConfig(**config_params) | |
print("--- Testing WaveNetForSequenceClassification (Modulation) ---") | |
model = WaveNetForSequenceClassification(config) | |
model.eval() | |
batch_size = 2 | |
seq_len = 10 | |
dummy_input_ids = torch.randint( | |
0, config.vocab_size, (batch_size, seq_len), dtype=torch.long | |
) | |
dummy_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) | |
dummy_attention_mask[0, 5:] = 0 # Pad last 5 tokens of first example | |
dummy_labels = torch.randint(0, config.num_labels, (batch_size,), dtype=torch.long) | |
print(f"Input IDs shape: {dummy_input_ids.shape}") | |
print(f"Attention Mask: \n{dummy_attention_mask}") | |
with torch.no_grad(): | |
outputs_dict = model( | |
input_ids=dummy_input_ids, | |
attention_mask=dummy_attention_mask, | |
labels=dummy_labels, | |
return_dict=True, # Test with explicit return_dict=True | |
) | |
print(f"\nLogits shape: {outputs_dict.logits.shape}") | |
print(f"Loss: {outputs_dict.loss}") | |
if outputs_dict.hidden_states is not None: | |
print(f"Number of hidden states returned: {len(outputs_dict.hidden_states)}") | |
print( | |
f"Shape of initial embedding (hidden_states[0]): {outputs_dict.hidden_states[0].shape}" | |
) | |
print( | |
f"Shape of last hidden state (hidden_states[-1]): {outputs_dict.hidden_states[-1].shape}" | |
) | |
else: | |
print("No hidden states returned.") | |
print("\n--- Testing Model Saving and Loading ---") | |
import tempfile | |
with tempfile.TemporaryDirectory() as tmpdirname: | |
model.save_pretrained(tmpdirname) | |
loaded_model = WaveNetForSequenceClassification.from_pretrained(tmpdirname) | |
loaded_model.eval() | |
# Ensure model is on the same device as inputs for comparison | |
if dummy_input_ids.device != loaded_model.device: | |
loaded_model.to(dummy_input_ids.device) | |
with torch.no_grad(): | |
outputs_loaded = loaded_model( | |
input_ids=dummy_input_ids, | |
attention_mask=dummy_attention_mask, | |
labels=dummy_labels, # Pass labels to ensure loss is comparable if needed | |
return_dict=True, | |
) | |
if torch.allclose(outputs_dict.logits, outputs_loaded.logits, atol=1e-5): | |
print("Model saving and loading test passed (logits match).") | |
else: | |
print("Model saving and loading test FAILED (logits differ).") | |
print(f"Original logits: {outputs_dict.logits}") | |
print(f"Loaded logits: {outputs_loaded.logits}") | |
if outputs_dict.loss is not None and outputs_loaded.loss is not None: | |
if torch.allclose(outputs_dict.loss, outputs_loaded.loss, atol=1e-5): | |
print("Model saving and loading test passed (loss match).") | |
else: | |
print("Model saving and loading test FAILED (loss differ).") | |
print(f"Original loss: {outputs_dict.loss}") | |
print(f"Loaded loss: {outputs_loaded.loss}") | |
print("\n--- Testing Core WaveNetModel ---") | |
core_model = WaveNetModel(config) | |
core_model.eval() | |
with torch.no_grad(): | |
core_outputs = core_model( | |
input_ids=dummy_input_ids, | |
attention_mask=dummy_attention_mask, | |
return_dict=True, | |
output_hidden_states=True, # Ensure hidden states are requested | |
) | |
print(f"Core model last_hidden_state shape: {core_outputs.last_hidden_state.shape}") | |
print(f"Core model pooler_output shape: {core_outputs.pooler_output.shape}") | |
if core_outputs.hidden_states is not None: | |
print(f"Core model num hidden_states: {len(core_outputs.hidden_states)}") | |
print( | |
f"Core model initial embedding shape: {core_outputs.hidden_states[0].shape}" | |
) | |
print( | |
f"Core model final hidden_state shape: {core_outputs.hidden_states[-1].shape}" | |
) | |
else: | |
print("Core model no hidden_states returned.") | |
print("\n--- Testing with return_dict=False for SequenceClassification ---") | |
with torch.no_grad(): | |
outputs_tuple = model( | |
input_ids=dummy_input_ids, | |
attention_mask=dummy_attention_mask, | |
labels=dummy_labels, | |
return_dict=False, | |
output_hidden_states=True, | |
) | |
# Expected tuple: (loss, logits, hidden_states) | |
print(f"Output tuple type: {type(outputs_tuple)}") | |
print(f"Length of tuple: {len(outputs_tuple)}") | |
print(f"Loss (from tuple): {outputs_tuple[0]}") | |
print(f"Logits shape (from tuple): {outputs_tuple[1].shape}") | |
if len(outputs_tuple) > 2 and outputs_tuple[2] is not None: | |
print(f"Number of hidden_states (from tuple): {len(outputs_tuple[2])}") |
I tried having an LLM add docstrings, but of course it screwed up the implementation itself in doing so.
**Running tests**

The output looks excellent and indicates that the model structure and Hugging Face integration are working correctly for this single-layer WaveNet. Let's break down why:
**Input Shapes & Masking:**
- `Input IDs shape: torch.Size([2, 10])`: Correct for batch size 2, sequence length 10.
- `Attention Mask`: Correctly shows the first sequence is padded (zeros from index 5 onwards) and the second is not. This is crucial for the `ComplexVectorEncoder` to correctly calculate global semantics (a quick check is sketched below).
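To make the masking claim concrete, here is a minimal, hypothetical check (not part of the gist's test block) that padded positions are excluded from the global-semantics magnitude `G`. It assumes `WaveNetConfig` and `ComplexVectorEncoder` from the file above are in scope; the sizes and inputs are arbitrary.

```python
import torch

# Hypothetical check: excluding padded tokens should change |Z| (which equals G
# broadcast over the sequence) compared with treating every position as valid.
cfg = WaveNetConfig(vocab_size=100, hidden_size=16, pad_token_id=0)
enc = ComplexVectorEncoder(cfg)

x = torch.randn(1, 4, cfg.hidden_size)   # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 0, 0]])      # last two positions are padding

z_masked = enc(x, attention_mask=mask)
z_unmasked = enc(x)

# With random inputs the magnitudes should differ, since G sums energy only
# over the unmasked tokens in the first case.
print(torch.allclose(z_masked.abs(), z_unmasked.abs()))  # expected: False
```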
**`WaveNetForSequenceClassification` (with `return_dict=True`):**
- `Logits shape: torch.Size([2, 3])`: Correct (batch_size, num_labels).
- `Loss: 1.042...`: A reasonable scalar loss value for an untrained model with 3 classes (the expected loss for random guessing is -ln(1/3) ≈ 1.098; see the quick check after this list).
- `Number of hidden states returned: 2`: Correct. Since `output_hidden_states=True` and it's a single-layer model, `hidden_states[0]` is the initial embedding output from `WaveNetEmbeddings` and `hidden_states[1]` is the output of the single `WaveNetSingleLayer`.
- `Shape of initial embedding (hidden_states[0]): torch.Size([2, 10, 64])`: Correct (batch_size, seq_len, hidden_size).
- `Shape of last hidden state (hidden_states[-1]): torch.Size([2, 10, 64])`: Correct; this is the output of the `WaveNetSingleLayer`.
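The random-guessing baseline mentioned above is just the cross-entropy of a uniform prediction over 3 classes; a two-line check:

```python
import math
import torch
import torch.nn.functional as F

# Cross-entropy of uniform guessing over 3 classes.
print(-math.log(1 / 3))                                       # ≈ 1.0986

# Same number from all-zero logits (uniform after softmax), any target label.
print(F.cross_entropy(torch.zeros(1, 3), torch.tensor([0])))  # tensor(1.0986)
```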
**Model Saving and Loading:**
- `Model saving and loading test passed (logits match).`
- `Model saving and loading test passed (loss match).`
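Since `save_pretrained` / `from_pretrained` round-trip cleanly, you may also want the `Auto*` factories to resolve the custom classes. This optional registration sketch is an assumption about intended usage, not something the gist does; the checkpoint path is a placeholder.

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

# Register the custom config/model so Auto* classes can load saved checkpoints.
AutoConfig.register("wavenet", WaveNetConfig)
AutoModelForSequenceClassification.register(WaveNetConfig, WaveNetForSequenceClassification)

# loaded = AutoModelForSequenceClassification.from_pretrained("path/to/saved/dir")
```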
**Core `WaveNetModel` (with `return_dict=True`):**
- `Core model last_hidden_state shape: torch.Size([2, 10, 64])`: Correct.
- `Core model pooler_output shape: torch.Size([2, 64])`: Correct (batch_size, hidden_size), as it takes the first token's representation from the `last_hidden_state` (see the check below).
- `Core model num hidden_states: 2`: Correct, for the same reasons as above (initial embeddings + one layer output).
- `Core model initial embedding shape: torch.Size([2, 10, 64])`: Correct.
- `Core model final hidden_state shape: torch.Size([2, 10, 64])`: Correct.
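Because the "pooler" here is plain CLS-token slicing (no extra dense layer), the two tensors should be bit-identical. A small sketch of that check, reusing `core_model` and the dummy inputs from the `__main__` block above:

```python
import torch

with torch.no_grad():
    out = core_model(
        input_ids=dummy_input_ids,
        attention_mask=dummy_attention_mask,
        return_dict=True,
    )

# pooler_output is literally last_hidden_state[:, 0] in this implementation.
print(torch.equal(out.pooler_output, out.last_hidden_state[:, 0]))  # True
```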
**`WaveNetForSequenceClassification` (with `return_dict=False`):**
- `Output tuple type: <class 'tuple'>`: Correct.
- `Length of tuple: 3`: Correct. When `labels` are provided and `output_hidden_states=True`, the expected tuple is `(loss, logits, hidden_states)`; see the unpacking example after this list.
- `Loss (from tuple): 1.042...`: Matches the loss from the `return_dict=True` case, which is consistent.
- `Logits shape (from tuple): torch.Size([2, 3])`: Correct.
- `Number of hidden_states (from tuple): 2`: Correct.
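For clarity, the tuple can be unpacked directly. This is a sketch reusing `model` and the dummy tensors from the `__main__` block, not part of the test script itself:

```python
import torch

with torch.no_grad():
    loss, logits, hidden_states = model(
        input_ids=dummy_input_ids,
        attention_mask=dummy_attention_mask,
        labels=dummy_labels,
        return_dict=False,
        output_hidden_states=True,
    )

print(loss.shape, logits.shape, len(hidden_states))
# torch.Size([]) torch.Size([2, 3]) 2
```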
**Key Takeaways from the Output:**
- `output_hidden_states` works: the model correctly returns the initial embeddings and the output of the single wave layer when requested.
- `return_dict` works: both `True` and `False` modes behave as expected in the classification model.
- `pooler_output` in the core model is correctly extracted.

The output strongly suggests that the implementation of the WaveNet components (embeddings, complex vector encoding, wave superposition, FFN, residual connections, and layer norms) and their integration into the Hugging Face `PreTrainedModel` structure are sound.
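Since the gist's point is pretraining-free, direct fine-tuning, here is a minimal end-to-end sketch with the Hugging Face `Trainer`. Everything in it is an assumption about usage rather than part of the gist: the tokenizer (`bert-base-uncased`, borrowed only for its vocabulary), the dataset (`glue`/`sst2`), and the hyperparameters are illustrative placeholders.

```python
from datasets import load_dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments

# Borrow an existing vocabulary; the WaveNet model itself is trained from scratch.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config = WaveNetConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    max_position_embeddings=256,
    num_labels=2,
    pad_token_id=tokenizer.pad_token_id,
)
model = WaveNetForSequenceClassification(config)

dataset = load_dataset("glue", "sst2")
dataset = dataset.map(
    lambda batch: tokenizer(batch["sentence"], truncation=True, max_length=256),
    batched=True,
)

args = TrainingArguments(
    output_dir="wavenet-sst2",
    per_device_train_batch_size=32,
    num_train_epochs=3,
    learning_rate=5e-4,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,  # lets Trainer pad batches with DataCollatorWithPadding
)
trainer.train()
```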