# Latent Manipulator Cookbook
This guide explains the "Latent Manipulator," an experimental AI architecture designed to "think" in a latent space before generating text, contrasting with standard Transformer models that predict text sequentially. It includes the theory, code for implementation, and links to resources.
Based on the video exploring this concept: <https://www.youtube.com/watch?v=fWiieyG2zes>
## Why a Latent Manipulator? The Theory
Standard Large Language Models (LLMs) like ChatGPT are typically based on the **Transformer** architecture. Their core operation involves predicting the *next word* (or token) in a sequence, given all the preceding words. This means their process of "thinking" or reasoning is intertwined with the act of generating text word-by-word. If you ask ChatGPT, "Can you think quietly before writing?", it might say yes, but architecturally, it *can't* – its computation *is* the generation process.
Humans, however, can often form an "idea" or grasp the semantics of a concept before finding the exact words to express it. The Latent Manipulator architecture attempts to mimic this separation:
1. **Idea Space (Latent Space):** We need a way to represent the *meaning* or *idea* of a piece of text numerically, separate from the text itself. This is achieved using an **Autoencoder**.
    * **Encoder:** Takes text as input and compresses it into a dense numerical vector (e.g., 1024 numbers). This vector lives in the "latent space" and represents the "idea" of the input text.
    * **Decoder:** Takes a vector from the latent space and reconstructs the original text (or a close approximation).
    * **Bottleneck:** The crucial part is the "bottleneck" in the middle of the autoencoder (the latent space itself), which forces the model to learn a compact, meaningful representation of the input.
2. **The Latent Manipulator (Thinking Engine):** This is a separate model (which doesn't *have* to be a Transformer) that operates *entirely within the latent space*.
    * It takes the latent vector representing the *question* (generated by the Encoder).
    * It performs computations on this vector to transform it into a *new* latent vector representing the *answer*.
    * This transformation is the "thinking" process, happening *without generating any text*.
3. **Generating the Answer:** The resulting latent vector (the "idea" of the answer) is then fed into the **Decoder** part of the autoencoder, which converts this "idea" back into human-readable text.
**In essence: Text Question -> Encoder -> Latent Question -> Latent Manipulator -> Latent Answer -> Decoder -> Text Answer.**
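As a minimal sketch of that pipeline (illustrative only; `encoder`, `manipulator`, and `decoder` stand in for the components implemented later in this guide):
```python
# Hypothetical end-to-end pipeline, assuming the three components described above.
def answer(question_text, encoder, manipulator, decoder):
    latent_question = encoder(question_text)      # text -> "idea" vector (e.g. 1024 floats)
    latent_answer = manipulator(latent_question)  # "thinking": idea of question -> idea of answer
    return decoder(latent_answer)                 # idea -> human-readable text
```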
This separation offers potential advantages:
* **True "Thinking":** Allows computation on semantic meaning before articulation.
* **Multilingual Potential:** The latent space could potentially become language-agnostic. You could train different Encoder/Decoder pairs for various languages but use the *same* Latent Manipulator for reasoning, promoting consistency across languages.
* **Efficiency & Control:** Manipulating smaller latent vectors might be more efficient than full text generation for certain reasoning tasks.
## Implementation Guide
### Prerequisites
* Python 3.12.2 (or compatible)
* Transformers library (version 4.37.2 used here)
* PyTorch
* Pandas & PyArrow (for data preparation)
* NumPy
* tqdm (for progress bars)
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
import numpy as np
from tqdm import tqdm
import pandas as pd
import glob
import pyarrow
```
### 1. Setup: Loading the Autoencoder
We use a pre-trained T5 model that has been adapted into a bottlenecked autoencoder. This model provides the `embed` (text to latent) and `generate_from_latent` (latent to text) functionalities.
```python
# Define the Autoencoder Abstraction (Helper Class)
# Note: This requires the specific model code from 'thesephist/contra-bottleneck-t5-large-wikipedia'
# Ensure you have 'trust_remote_code=True' when loading if needed.
class BottleneckT5Autoencoder:
    def __init__(self, model_path: str, device='cpu'):
        self.device = device
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        # Ensure trust_remote_code=True if the model requires custom code
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval()  # Set to evaluation mode

    @torch.no_grad()
    def embed(self, text: str) -> torch.FloatTensor:
        """Encodes a single string into a latent embedding."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(self.device)
        # T5-style models use the pad token as the decoder start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], dtype=torch.long).to(self.device)
        # Generate the latent embedding
        outputs = self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            decoder_input_ids=decoder_input_ids,  # Provide initial decoder input
            encode_only=True,  # Flag to get the bottleneck representation
        )
        # The exact output structure might vary; inspect 'outputs' if needed.
        # Assuming the latent vector is the first element.
        return outputs[0]

    @torch.no_grad()
    def embed_batch(self, texts: list[str]) -> torch.FloatTensor:
        """Encodes a batch of strings into latent embeddings."""
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        # Prepare decoder start tokens for the batch
        decoder_start_token_id = self.model.config.decoder_start_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = self.tokenizer.pad_token_id  # Fallback if not defined
        decoder_input_ids = torch.full(
            (len(texts), 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=self.device
        )
        outputs = self.model(
            **inputs,
            decoder_input_ids=decoder_input_ids,
            encode_only=True,
        )
        return outputs[0]

    @torch.no_grad()
    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=0.4) -> str:
        """Decodes a latent embedding back into text."""
        # Ensure latent is on the correct device and has a batch dimension
        if latent.dim() == 1:
            latent = latent.unsqueeze(0)
        latent = latent.to(self.device)
        # Use the model's generate method with the latent vector.
        # This relies on the custom model code handling the 'latent_vector' parameter.
        output_sequences = self.model.generate(
            encoder_outputs=None,  # We provide the latent directly, not standard encoder outputs
            latent_vector=latent,  # Custom argument for this specific model
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id  # Important for stopping generation
        )
        # Decode the first sequence
        return self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
```
```python
# --- Initialize the Autoencoder ---
# Set your Hugging Face token if needed
# os.environ["HF_TOKEN"] = "your_huggingface_token"
# Determine device (adjust as needed: 'cuda', 'mps', 'cpu')
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():  # For Apple Silicon
    device = 'mps'
else:
    device = 'cpu'
# Load the pre-trained autoencoder model
model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
autoencoder = BottleneckT5Autoencoder(model_path=model_path, device=device)
print("Autoencoder loaded successfully.")
```
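Before generating a full dataset of embeddings, it can help to round-trip a single sentence through the autoencoder. The snippet below is an optional sanity check (the sentence is arbitrary); it assumes the class above loaded correctly and that the latent dimension is 1024, as used throughout this guide. The reconstruction will be close to, but not necessarily identical to, the input.
```python
# --- Optional sanity check: round-trip a sentence through the autoencoder ---
sample_text = "The quick brown fox jumps over the lazy dog."
latent = autoencoder.embed(sample_text)            # latent "idea" vector
print("Latent shape:", latent.shape)               # expected: [1, 1024]
reconstruction = autoencoder.generate_from_latent(latent, temperature=0.4)
print("Reconstruction:", reconstruction)           # should paraphrase the input
```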
### 2. Data Preparation
The training data for the Latent Manipulator consists of pairs of latent embeddings: one for the instruction (question) and one for the response (answer).
#### a) Combining Raw Data (if needed)
The example uses the LaMini dataset, provided as Parquet files. If your data is similar, you first need to combine these into a single file (e.g., JSONL).
```python
# --- Combine Parquet files into a single JSONL (Example) ---
parquet_dir = "/path/to/your/parquet/files/" # Directory containing the .parquet files
output_jsonl_file = "/path/to/save/merged_output.jsonl"
parquet_files = glob.glob(os.path.join(parquet_dir, "train-*.parquet"))
print(f"Found Parquet files: {parquet_files}")
# Use 'append' mode to allow resuming if interrupted
with open(output_jsonl_file, "a", encoding="utf-8") as outfile:
    for file_path in tqdm(parquet_files, desc="Processing Parquet files"):
        try:
            df = pd.read_parquet(file_path)
            print(f"Read {len(df)} rows from {os.path.basename(file_path)}")
            # Iterate through DataFrame rows and write as JSON lines
            for _, row in tqdm(df.iterrows(), total=len(df), desc="Writing JSONL", leave=False):
                # Ensure columns 'instruction' and 'response' exist
                if 'instruction' in row and 'response' in row:
                    json_record = json.dumps({"instruction": row['instruction'], "response": row['response']})
                    outfile.write(json_record + '\n')
                else:
                    print(f"Skipping row due to missing columns: {row.to_dict()}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
print(f"Finished merging Parquet files into {output_jsonl_file}")
```
#### b) Generating Embeddings
Convert the text instruction/response pairs from the JSONL file into latent embeddings using the autoencoder and save them to a NumPy file. This uses checkpointing to handle large datasets and allow resuming.
```python
# --- Generate Embeddings from JSONL ---
# Configuration
checkpoint_interval = 100_000 # Save progress every N lines
input_jsonl_file = "/path/to/save/merged_output.jsonl" # Input JSONL file from previous step
embeddings_file = "/path/to/save/embeddings.npy" # Output NumPy file
checkpoint_file = "/path/to/save/checkpoint.txt" # Checkpoint file
batch_size = 64 # Adjust based on your GPU memory
# --- Function to process batches ---
def process_batch(batch_texts_instructions, batch_texts_responses, autoencoder_model):
    if not batch_texts_instructions or not batch_texts_responses:
        return [], []
    try:
        # Ensure the model's embed_batch method handles lists of texts
        instr_embeddings = autoencoder_model.embed_batch(batch_texts_instructions)
        resp_embeddings = autoencoder_model.embed_batch(batch_texts_responses)
        # Move embeddings to CPU before converting to NumPy
        return instr_embeddings.cpu().numpy(), resp_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error processing batch: {e}")
        # Optionally, try processing item by item as a fallback
        instr_np, resp_np = [], []
        for instr, resp in zip(batch_texts_instructions, batch_texts_responses):
            try:
                # Squeeze the batch dimension so shapes match the batched path ([1024] per item)
                instr_emb = autoencoder_model.embed(instr).squeeze(0).cpu().numpy()
                resp_emb = autoencoder_model.embed(resp).squeeze(0).cpu().numpy()
                instr_np.append(instr_emb)
                resp_np.append(resp_emb)
            except Exception as item_e:
                print(f"Error processing item '{instr[:50]}...': {item_e}")
        return np.array(instr_np) if instr_np else [], np.array(resp_np) if resp_np else []

# --- Main Embedding Generation Logic ---
try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines in file: {total_lines}")
except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
    exit()

if os.path.exists(checkpoint_file):
    with open(checkpoint_file, "r") as f:
        last_processed_line = int(f.read().strip())
    print(f"Resuming from line: {last_processed_line + 1}")
else:
    last_processed_line = 0
    print("No checkpoint found. Starting from the beginning.")

if os.path.exists(embeddings_file) and last_processed_line > 0:
    # Load only if resuming and file exists
    try:
        existing_embeddings = np.load(embeddings_file)
        # Ensure we only load embeddings corresponding to processed lines
        num_expected_embeddings = last_processed_line * 2
        if existing_embeddings.shape[0] >= num_expected_embeddings:
            embeddings_list = list(existing_embeddings[:num_expected_embeddings])
            print(f"Loaded {len(embeddings_list)} embeddings from previous run (up to line {last_processed_line}).")
        else:
            print("Warning: Embedding file size doesn't match checkpoint. Starting embeddings list fresh.")
            embeddings_list = []
            last_processed_line = 0  # Reset checkpoint if mismatch
    except Exception as e:
        print(f"Error loading existing embeddings: {e}. Starting fresh.")
        embeddings_list = []
        last_processed_line = 0  # Reset checkpoint on load error
else:
    embeddings_list = []
    if last_processed_line > 0:
        print("Warning: Checkpoint found but no embeddings file. Resetting checkpoint.")
        last_processed_line = 0  # Reset checkpoint if no embedding file

overall_line_count = last_processed_line
batch_instructions: list[str] = []
batch_responses: list[str] = []

try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        # Skip lines up to the checkpoint
        for _ in range(last_processed_line):
            next(f)
        # Process remaining lines with tqdm
        pbar = tqdm(f, total=total_lines, initial=last_processed_line, unit="line", desc="Processing lines")
        for line in pbar:
            overall_line_count += 1
            try:
                obj = json.loads(line)
                if "instruction" in obj and "response" in obj:
                    batch_instructions.append(obj["instruction"])
                    batch_responses.append(obj["response"])
                else:
                    print(f"Skipping line {overall_line_count}: Missing 'instruction' or 'response'.")
                    continue  # Skip this line
                # Process batch when full
                if len(batch_instructions) == batch_size:
                    instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
                    # Add pairs to the list
                    for i in range(len(instr_np)):
                        embeddings_list.append(instr_np[i])
                        embeddings_list.append(resp_np[i])
                    # Clear batches
                    batch_instructions = []
                    batch_responses = []
                # Checkpoint saving logic
                if overall_line_count % checkpoint_interval == 0:
                    # Record only fully embedded lines so the checkpoint stays consistent
                    # with the saved embeddings (any partially filled batch is excluded).
                    current_line_to_save = overall_line_count - len(batch_instructions)
                    with open(checkpoint_file, "w") as cf:
                        cf.write(str(current_line_to_save))
                    np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
                    pbar.set_postfix_str(f"Checkpoint saved at line {current_line_to_save}")
            except json.JSONDecodeError:
                print(f"Skipping line {overall_line_count}: Invalid JSON.")
            except Exception as e:
                print(f"Error on line {overall_line_count}: {e}")

    # Process any remaining items in the last batch
    if batch_instructions:
        instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
        for i in range(len(instr_np)):
            embeddings_list.append(instr_np[i])
            embeddings_list.append(resp_np[i])

    # Final save
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count))
    np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    print(f"Processing complete. Final line count: {overall_line_count}/{total_lines}. Embeddings saved.")
except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    # Save progress even on error
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count - len(batch_instructions)))  # Save last fully processed line
    if embeddings_list:
        np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    print("Saved progress before exiting due to error.")
```
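After the run finishes (or after any checkpoint), it is worth confirming the file has the layout the training code expects: rows alternating instruction/response with 1024 columns. A minimal, optional check, assuming the same path as `embeddings_file` above:
```python
# --- Optional: verify the saved embeddings file ---
emb = np.load("/path/to/save/embeddings.npy", mmap_mode="r")  # same path as embeddings_file above
print("Shape:", emb.shape)   # expected: (2 * num_pairs, 1024)
print("Dtype:", emb.dtype)   # expected: float32
assert emb.shape[0] % 2 == 0, "Rows should come in (instruction, response) pairs"
print("Number of instruction/response pairs:", emb.shape[0] // 2)
```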
### 3. Latent Manipulator Model Code
This defines the neural network that learns to map question embeddings to answer embeddings. The architecture shown uses multiple feed-forward layers with skip-connection-like features (concatenating intermediate "choked" outputs).
* **Input:** 1024-dimensional latent vector (from Encoder).
* **Architecture:** A series of Linear layers, BatchNorm, LeakyReLU, and Dropout. Intermediate outputs are "choked" (reduced) to 2048 dimensions and concatenated with the original input before final aggregation layers. This complex structure aims to handle deep transformations while mitigating vanishing/exploding gradients.
* **Output:** 1024-dimensional latent vector (to be fed to Decoder).
*(Original code, slightly adjusted for clarity and comments)*
```python
# --- Latent Manipulator Model Definition ---
class LatentManipulator(nn.Module):
    """
    A Feed-Forward Network designed to manipulate latent embeddings.
    Takes a 1024-dim embedding and outputs a 1024-dim embedding.
    Uses intermediate layer outputs (choked) and concatenation for richness.
    """
    def __init__(self, dropout_rate=0.2):  # Reduced dropout from original
        super(LatentManipulator, self).__init__()
        # --- Main Layers (Expand -> Contract) ---
        self.layer1 = nn.Sequential(nn.Linear(1024, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer2 = nn.Sequential(nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer3 = nn.Sequential(nn.Linear(4096, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer4 = nn.Sequential(nn.Linear(6144, 9216), nn.BatchNorm1d(9216), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))  # Widest layer
        self.layer5 = nn.Sequential(nn.Linear(9216, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer6 = nn.Sequential(nn.Linear(6144, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer7 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Choke Layers (Reduce intermediate outputs to 2048) ---
        # These act like shortcuts, bringing information from earlier layers forward.
        self.choke1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke2 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke3 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke4 = nn.Sequential(nn.Linear(9216, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke5 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke6 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke7 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Aggregation Layers (Combine concatenated features) ---
        # Total input size = 1024 (original input) + 7 * 2048 (choked outputs) = 15360
        self.aLayer1 = nn.Sequential(nn.Linear(15360, 8192), nn.BatchNorm1d(8192), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.aLayer2 = nn.Sequential(nn.Linear(8192, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Final Output Layer ---
        self.output_layer = nn.Linear(4096, 1024)  # Output matches input dimension

    def forward(self, x):
        # Pass through main layers
        x1 = self.layer1(x); x2 = self.layer2(x1); x3 = self.layer3(x2)
        x4 = self.layer4(x3); x5 = self.layer5(x4); x6 = self.layer6(x5)
        x7 = self.layer7(x6)
        # Apply choke layers
        c1 = self.choke1(x1); c2 = self.choke2(x2); c3 = self.choke3(x3)
        c4 = self.choke4(x4); c5 = self.choke5(x5); c6 = self.choke6(x6)
        c7 = self.choke7(x7)
        # Concatenate original input and all choked outputs
        concat = torch.cat([x, c1, c2, c3, c4, c5, c6, c7], dim=1)  # Dim 1 is the feature dimension (dim 0 is the batch)
        # Pass through aggregation layers
        out = self.aLayer1(concat)
        out = self.aLayer2(out)
        out = self.output_layer(out)
        return out


# --- Helper: Weight Initialization ---
def init_weights(m):
    """Applies Kaiming Normal initialization for LeakyReLU."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# --- Helper: Count Parameters ---
def count_parameters(model):
    """Counts the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```
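A quick way to confirm the architecture wires up correctly before committing to a long training run is a forward pass on random data. This is an optional smoke test (the batch of random vectors is purely illustrative):
```python
# --- Optional smoke test: check shapes and parameter count ---
model = LatentManipulator(dropout_rate=0.2)
model.apply(init_weights)
model.eval()  # eval mode so BatchNorm uses its (default) running statistics
with torch.no_grad():
    dummy = torch.randn(4, 1024)   # a fake batch of 4 question embeddings
    out = model(dummy)
print("Output shape:", out.shape)  # expected: torch.Size([4, 1024])
print(f"Trainable parameters: {count_parameters(model):,}")
```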
### 4. Training the Latent Manipulator
This involves setting up a dataset to efficiently load the pre-computed embeddings and a training loop.
#### a) Dataset Class
Uses NumPy's memory mapping (`mmap_mode='r'`) to avoid loading the entire (potentially huge) embeddings file into RAM. It loads only the required question-answer pair for each `__getitem__` call.
```python
# --- NumPy Embedding Dataset ---
class NPYEmbeddingDataset(Dataset):
    """
    Lazily loads pairs of embeddings (instruction, response) from a NumPy file
    using memory mapping for efficiency with large files.
    Assumes embeddings are stored sequentially: [instr1, resp1, instr2, resp2, ...].
    """
    def __init__(self, npy_file):
        self.npy_file = npy_file
        # Open with mmap_mode to get shape/dtype without loading all data.
        # Note: np.load on a .npy file returns a memmap array (not a context manager).
        try:
            data = np.load(npy_file, mmap_mode='r')
            self.shape = data.shape
            self.dtype = data.dtype
            del data  # Release the handle; workers reopen the file in __getitem__
        except FileNotFoundError:
            print(f"Error: NPY embedding file not found at {npy_file}")
            raise
        except Exception as e:
            print(f"Error loading NPY file: {e}")
            raise
        if len(self.shape) != 2 or self.shape[0] % 2 != 0:
            raise ValueError(f"Expected a 2D numpy array with an even number of rows (embeddings). Got shape: {self.shape}")
        self.num_pairs = self.shape[0] // 2
        self.embedding_dim = self.shape[1]
        print(f"Dataset initialized: {self.num_pairs} pairs, embedding dim {self.embedding_dim}")

    def __len__(self):
        """Returns the number of instruction-response pairs."""
        return self.num_pairs

    def __getitem__(self, idx):
        """Loads the idx-th instruction and response embedding pair."""
        if idx >= self.num_pairs:
            raise IndexError("Index out of bounds")
        # Load with mmap_mode again inside __getitem__ for multi-process loading safety.
        # This ensures each worker gets its own file handle if num_workers > 0.
        data = np.load(self.npy_file, mmap_mode='r')
        # Calculate row indices for the pair
        q_idx = idx * 2
        a_idx = idx * 2 + 1
        # Extract the embeddings and convert to tensors
        q_emb = torch.from_numpy(data[q_idx].copy()).float()  # Use .copy() with mmap
        a_emb = torch.from_numpy(data[a_idx].copy()).float()
        return q_emb, a_emb
```
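Before wiring the dataset into a `DataLoader`, it can be useful to pull a single pair and check the shapes. A small, optional check (using the same embeddings path as above):
```python
# --- Optional: inspect one pair from the dataset ---
dataset = NPYEmbeddingDataset("/path/to/save/embeddings.npy")  # same path as embeddings_file above
q_emb, a_emb = dataset[0]
print("Question embedding:", q_emb.shape, q_emb.dtype)  # expected: torch.Size([1024]) torch.float32
print("Answer embedding:  ", a_emb.shape, a_emb.dtype)
```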
#### b) Training Loop
Standard PyTorch training loop using the defined dataset and model. Key features:
* **Loss:** Mean Squared Error (`MSELoss`) because we are comparing output embeddings to target embeddings (regression).
* **Optimizer:** AdamW is a good default choice.
* **Learning Rate Scheduling:** `ReduceLROnPlateau` adjusts the learning rate based on validation loss (or average training loss here) stagnation. A *warmup* phase is also added to start with a lower LR and gradually increase it, improving stability early in training.
* **Gradient Clipping:** Prevents exploding gradients, crucial for deep networks.
* **Checkpointing:** Saves the model state periodically, especially when the loss improves.
```python
# --- Utility to save checkpoints ---
def save_checkpoint(state, filename="checkpoint.pt"):
    """Saves model and optimizer state."""
    try:
        torch.save(state, filename)
        print(f"Checkpoint saved to {filename}")
    except Exception as e:
        print(f"Error saving checkpoint: {e}")


# --- Training Function ---
def train(model, dataloader, epochs=10, base_lr=1e-4, warmup_epochs=1, clip_value=5.0, device=None, checkpoint_dir="checkpoints"):
    """Trains the LatentManipulator model."""
    if device is None:
        if torch.cuda.is_available(): device = torch.device("cuda")
        elif torch.backends.mps.is_available(): device = torch.device("mps")
        else: device = torch.device("cpu")
    print(f"Training on device: {device}")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.01)
    criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)  # More aggressive reduction factor
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float('inf')
    print(f"Starting training for {epochs} epochs...")

    for epoch in range(epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0
        # Learning Rate Warmup
        if epoch < warmup_epochs:
            lr = base_lr * (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            current_lr = lr
        else:
            # Get current LR from the optimizer after warmup, potentially adjusted by the scheduler
            current_lr = optimizer.param_groups[0]['lr']
        print(f"\n--- Epoch {epoch+1}/{epochs} --- LR: {current_lr:.6f}")

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} Training", leave=True)
        for batch_idx, (q_emb, a_emb) in enumerate(pbar):
            q_emb, a_emb = q_emb.to(device), a_emb.to(device)
            optimizer.zero_grad()
            outputs = model(q_emb)
            loss = criterion(outputs, a_emb)
            # Check for NaN loss
            if torch.isnan(loss):
                print(f"NaN loss detected at Epoch {epoch+1}, Batch {batch_idx}. Stopping training.")
                # Optionally save state before exiting
                save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'avg_loss': float('inf')}, os.path.join(checkpoint_dir, "checkpoint_error_nan.pt"))
                return  # Stop training
            loss.backward()
            # Gradient Clipping
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            running_loss += loss.item()
            # Update progress bar
            if (batch_idx + 1) % 100 == 0:  # Update less frequently
                pbar.set_postfix_str(f"Loss: {loss.item():.4f}, GradNorm: {grad_norm:.4f}")

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")
        # Step the scheduler based on the average loss for the epoch
        scheduler.step(avg_loss)
        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            print(f"Loss improved from {best_loss:.6f} to {avg_loss:.6f}. Saving checkpoint...")
            best_loss = avg_loss
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}_best.pt")
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'avg_loss': avg_loss
            }, filename=checkpoint_path)
        else:
            print(f"Loss did not improve from {best_loss:.6f}.")
        # Optional: Save a checkpoint every epoch regardless
        # checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        # save_checkpoint({ ... }, filename=checkpoint_path)

    print("\nTraining complete.")


# --- Main Training Execution Block ---
if __name__ == "__main__":
    npy_file = "/path/to/save/embeddings.npy"  # Make sure this path is correct
    checkpoint_dir = "latent_manipulator_checkpoints"  # Directory to save model checkpoints
    try:
        dataset = NPYEmbeddingDataset(npy_file)
        # Adjust batch_size and num_workers based on your system
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True if device == 'cuda' else False)
        model = LatentManipulator(dropout_rate=0.2)  # Instantiate the model
        model.apply(init_weights)  # Initialize weights
        print("Model Architecture:\n", model)
        total_params = count_parameters(model)
        print(f"\nTotal Trainable Parameters: {total_params:,}")
        # Start training
        train(model, dataloader, epochs=10, base_lr=1e-4, clip_value=5.0, device=device, checkpoint_dir=checkpoint_dir)
    except FileNotFoundError:
        print(f"Error: Embeddings file not found at {npy_file}. Please generate embeddings first.")
    except Exception as e:
        print(f"An error occurred during training setup or execution: {e}")
```
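If a run is interrupted, the saved checkpoints can be reloaded to continue training. The sketch below assumes the checkpoint dictionary written by `save_checkpoint()` above; the path is illustrative. Note that `train()` as written creates its own optimizer, so you would either pass the loaded model and accept a fresh optimizer, or adapt `train()` to take a preloaded one.
```python
# --- Optional: resume from a saved checkpoint (sketch) ---
# Assumes the checkpoint dictionary format used by save_checkpoint() above.
resume_path = "latent_manipulator_checkpoints/checkpoint_epoch_5_best.pt"  # illustrative path

model = LatentManipulator(dropout_rate=0.2)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

checkpoint = torch.load(resume_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint.get("epoch", 0)
print(f"Resuming from epoch {start_epoch} (loss {checkpoint.get('avg_loss')})")
```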
### 5. Inference: Using the Trained Model
To get an answer for a new question:
1. Load the trained `LatentManipulator` model from a checkpoint.
2. Load the `BottleneckT5Autoencoder` (needed for encoding the question and decoding the answer).
3. Encode the input question text into its latent vector using the autoencoder.
4. Pass this latent vector through the loaded `LatentManipulator` to get the predicted latent vector for the answer.
5. Decode this answer latent vector back into text using the autoencoder.
```python
# --- Inference Script ---
# Make sure latent_manipulator.py (containing the model definition) is accessible
# and you have the BottleneckT5Autoencoder class defined/imported
# --- Function to load the Latent Manipulator model ---
def load_manipulator(checkpoint_path, device):
    """Loads the trained LatentManipulator from a checkpoint file."""
    # Instantiate the model architecture (ensure dropout_rate matches training)
    model = LatentManipulator(dropout_rate=0.2)
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()  # Set to evaluation mode
        print(f"Loaded LatentManipulator from epoch {checkpoint.get('epoch', 'N/A')} with loss {checkpoint.get('avg_loss', 'N/A')}")
        return model
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        raise
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise


# --- Main Inference Execution ---
if __name__ == "__main__":
    # --- Configuration ---
    autoencoder_model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
    manipulator_checkpoint_path = "latent_manipulator_checkpoints/checkpoint_epoch_10_best.pt"  # Path to your best saved checkpoint

    # Determine device
    if torch.cuda.is_available(): device = 'cuda'
    elif torch.backends.mps.is_available(): device = 'mps'
    else: device = 'cpu'
    print(f"Using device for inference: {device}")

    # --- Load Models ---
    try:
        # Load the Autoencoder (needed for embed/generate)
        autoencoder = BottleneckT5Autoencoder(model_path=autoencoder_model_path, device=device)
        # Load the trained Latent Manipulator
        manipulator_model = load_manipulator(manipulator_checkpoint_path, device)
    except Exception as e:
        print(f"Failed to load models: {e}")
        exit()

    # --- Get Input and Generate ---
    while True:
        try:
            input_text = input("Enter your question (or type 'quit' to exit): ")
            if input_text.lower() == 'quit':
                break
            if not input_text:
                continue
            # 1. Encode the input question
            input_embedding_latent = autoencoder.embed(input_text)  # Shape [1, 1024]
            # Ensure it's on the correct device (embed should handle this, but double-check)
            input_embedding_latent = input_embedding_latent.to(device)
            # 2. Manipulate the latent vector to get the answer latent
            with torch.no_grad():
                output_embedding_latent = manipulator_model(input_embedding_latent)  # Shape [1, 1024]
            # 3. Decode the answer latent back to text.
            # The autoencoder class handles device placement and the batch dimension internally.
            output_text = autoencoder.generate_from_latent(output_embedding_latent, temperature=0.5)  # Adjust temperature as needed
            print("\nInput: ", input_text)
            print("Output: ", output_text)
            print("-" * 30)
        except KeyboardInterrupt:
            print("\nExiting.")
            break
        except Exception as e:
            print(f"An error occurred during generation: {e}")
```
## Resources
* **Autoencoder Model:** [thesephist/contra-bottleneck-t5-large-wikipedia](https://huggingface.co/thesephist/contra-bottleneck-t5-large-wikipedia) on Hugging Face.
* **Latent Manipulator Checkpoint & Data:** [Gal-Lahat/LatentManipulator-checkpoint_epoch_10](https://huggingface.co/Gal-Lahat/LatentManipulator-checkpoint_epoch_10.pt) on Hugging Face (includes the trained manipulator checkpoint and the generated `embeddings.npy`).
* **Raw Training Data (Example):** [MBZUAI/LaMini-instruction](https://huggingface.co/datasets/MBZUAI/LaMini-instruction) dataset.
## Conclusion
The Latent Manipulator presents an intriguing alternative to standard sequential text generation. By separating the "thinking" (latent space transformation) from the "speaking" (text decoding), it opens up possibilities for potentially more efficient, controllable, and perhaps even language-agnostic reasoning in AI models. While still experimental, this approach highlights the ongoing exploration into different ways AI can process and generate information.
This article was compiled with the help of an LLM.