# Latent Manipulator Cookbook
This guide explains the "Latent Manipulator," an experimental AI architecture designed to "think" in a latent space before generating text, contrasting with standard Transformer models that predict text sequentially. It includes the theory, code for implementation, and links to resources.

Based on the video exploring this concept: https://www.youtube.com/watch?v=fWiieyG2zes
## Why a Latent Manipulator? The Theory

Standard Large Language Models (LLMs) like ChatGPT are typically based on the **Transformer** architecture. Their core operation is predicting the *next word* (or token) in a sequence, given all the preceding words. This means their process of "thinking" or reasoning is intertwined with the act of generating text word by word. If you ask ChatGPT, "Can you think quietly before writing?", it might say yes, but architecturally it *can't*: its computation *is* the generation process.

Humans, however, can often form an "idea" or grasp the semantics of a concept before finding the exact words to express it. The Latent Manipulator architecture attempts to mimic this separation:

1. **Idea Space (Latent Space):** We need a way to represent the *meaning* or *idea* of a piece of text numerically, separate from the text itself. This is achieved using an **Autoencoder**.
    * **Encoder:** Takes text as input and compresses it into a dense numerical vector (e.g., 1024 numbers). This vector lives in the "latent space" and represents the "idea" of the input text.
    * **Decoder:** Takes a vector from the latent space and reconstructs the original text (or a close approximation).
    * **Bottleneck:** The crucial part is the "bottleneck" in the middle of the autoencoder (the latent space itself), which forces the model to learn a compact, meaningful representation of the input.
2. **The Latent Manipulator (Thinking Engine):** This is a separate model (which doesn't *have* to be a Transformer) that operates *entirely within the latent space*.
    * It takes the latent vector representing the *question* (generated by the Encoder).
    * It performs computations on this vector to transform it into a *new* latent vector representing the *answer*.
    * This transformation is the "thinking" process, happening *without generating any text*.
3. **Generating the Answer:** The resulting latent vector (the "idea" of the answer) is then fed into the **Decoder** part of the autoencoder, which converts this "idea" back into human-readable text.

**In essence: Text Question -> Encoder -> Latent Question -> Latent Manipulator -> Latent Answer -> Decoder -> Text Answer.** A minimal sketch of this pipeline is shown below.
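To make the flow concrete, here is a minimal sketch of the full pipeline in code. It assumes the `BottleneckT5Autoencoder` and `LatentManipulator` objects built later in this guide are already loaded; `answer` is just an illustrative helper name.

```python
import torch

# Minimal pipeline sketch: text -> latent -> manipulated latent -> text.
def answer(question: str, autoencoder, manipulator) -> str:
    latent_question = autoencoder.embed(question)            # Encoder: text question -> latent question
    with torch.no_grad():
        latent_answer = manipulator(latent_question)         # "Thinking": latent question -> latent answer
    return autoencoder.generate_from_latent(latent_answer)   # Decoder: latent answer -> text answer
```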
This separation offers potential advantages:

* **True "Thinking":** Allows computation on semantic meaning before articulation.
* **Multilingual Potential:** The latent space could become language-agnostic. You could train different Encoder/Decoder pairs for various languages but use the *same* Latent Manipulator for reasoning, promoting consistency across languages.
* **Efficiency & Control:** Manipulating small latent vectors may be more efficient than full text generation for certain reasoning tasks.
## Implementation Guide

### Prerequisites

* Python 3.12.2 (or compatible)
* Transformers library (version 4.37.2 used here)
* PyTorch
* Pandas & PyArrow (for data preparation)
* NumPy
* tqdm (for progress bars)
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
import numpy as np
from tqdm import tqdm
import pandas as pd
import glob
import pyarrow
```
### 1. Setup: Loading the Autoencoder

We use a pre-trained T5 model that has been adapted into a bottlenecked autoencoder. This model provides the `embed` (text to latent) and `generate_from_latent` (latent to text) functionality.
```python
# Define the Autoencoder Abstraction (Helper Class)
# Note: This requires the specific model code from 'thesephist/contra-bottleneck-t5-large-wikipedia'
# Ensure you have 'trust_remote_code=True' when loading if needed.
class BottleneckT5Autoencoder:
    def __init__(self, model_path: str, device='cpu'):
        self.device = device
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        # Ensure trust_remote_code=True if the model requires custom code
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval()  # Set to evaluation mode

    @torch.no_grad()
    def embed(self, text: str) -> torch.FloatTensor:
        """Encodes a single string into a latent embedding."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(self.device)
        # Decoder input starts with the beginning-of-sequence token for T5-like models
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], dtype=torch.long).to(self.device)
        # Generate the latent embedding
        outputs = self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            decoder_input_ids=decoder_input_ids,  # Provide initial decoder input
            encode_only=True,  # Flag to get the bottleneck representation
        )
        # The exact output structure might vary; inspect 'outputs' if needed.
        # Assuming the latent vector is the first element.
        return outputs[0]

    @torch.no_grad()
    def embed_batch(self, texts: list[str]) -> torch.FloatTensor:
        """Encodes a batch of strings into latent embeddings."""
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        # Prepare decoder start tokens for the batch
        decoder_start_token_id = self.model.config.decoder_start_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = self.tokenizer.pad_token_id  # Fallback if not defined
        decoder_input_ids = torch.full(
            (len(texts), 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=self.device
        )
        outputs = self.model(
            **inputs,
            decoder_input_ids=decoder_input_ids,
            encode_only=True,
        )
        return outputs[0]

    @torch.no_grad()
    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=0.4) -> str:
        """Decodes a latent embedding back into text."""
        # Ensure latent is on the correct device and has a batch dimension
        if latent.dim() == 1:
            latent = latent.unsqueeze(0)
        latent = latent.to(self.device)
        # Use the model's generate method with the latent vector
        # This relies on the custom model code handling the 'latent_vector' parameter
        output_sequences = self.model.generate(
            encoder_outputs=None,  # We provide latent directly, not standard encoder outputs
            latent_vector=latent,  # Custom argument for this specific model
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id  # Important for stopping generation
        )
        # Decode the first sequence
        return self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
```
```python
# --- Initialize the Autoencoder ---
# Set your Hugging Face token if needed
# os.environ["HF_TOKEN"] = "your_huggingface_token"

# Determine device (adjust as needed: 'cuda', 'mps', 'cpu')
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():  # For Apple Silicon
    device = 'mps'
else:
    device = 'cpu'

# Load the pre-trained autoencoder model
model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
autoencoder = BottleneckT5Autoencoder(model_path=model_path, device=device)
print("Autoencoder loaded successfully.")
```
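As an optional sanity check (a sketch, not part of the original pipeline), you can round-trip a sentence through the autoencoder: encode it, inspect the latent shape, and decode it again. The reconstruction will only be approximate.

```python
# Optional round-trip sanity check for the autoencoder.
test_text = "The quick brown fox jumps over the lazy dog."
latent = autoencoder.embed(test_text)
print("Latent shape:", latent.shape)  # Expected: a 1024-dimensional bottleneck vector
reconstruction = autoencoder.generate_from_latent(latent, temperature=0.4)
print("Reconstruction:", reconstruction)
```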
### 2. Data Preparation

The training data for the Latent Manipulator consists of pairs of latent embeddings: one for the instruction (question) and one for the response (answer).

#### a) Combining Raw Data (if needed)

The example uses the LaMini dataset, provided as Parquet files. If your data is similar, you first need to combine these into a single file (e.g., JSONL).
```python
# --- Combine Parquet files into a single JSONL (Example) ---
parquet_dir = "/path/to/your/parquet/files/"  # Directory containing the .parquet files
output_jsonl_file = "/path/to/save/merged_output.jsonl"

parquet_files = glob.glob(os.path.join(parquet_dir, "train-*.parquet"))
print(f"Found Parquet files: {parquet_files}")

# Use 'append' mode to allow resuming if interrupted
with open(output_jsonl_file, "a", encoding="utf-8") as outfile:
    for file_path in tqdm(parquet_files, desc="Processing Parquet files"):
        try:
            df = pd.read_parquet(file_path)
            print(f"Read {len(df)} rows from {os.path.basename(file_path)}")
            # Iterate through DataFrame rows and write as JSON lines
            for _, row in tqdm(df.iterrows(), total=len(df), desc="Writing JSONL", leave=False):
                # Ensure columns 'instruction' and 'response' exist
                if 'instruction' in row and 'response' in row:
                    json_record = json.dumps({"instruction": row['instruction'], "response": row['response']})
                    outfile.write(json_record + '\n')
                else:
                    print(f"Skipping row due to missing columns: {row.to_dict()}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

print(f"Finished merging Parquet files into {output_jsonl_file}")
```
#### b) Generating Embeddings

Convert the text instruction/response pairs from the JSONL file into latent embeddings using the autoencoder and save them to a NumPy file. This uses checkpointing to handle large datasets and allow resuming.
```python
# --- Generate Embeddings from JSONL ---
# Configuration
checkpoint_interval = 100_000  # Save progress every N lines
input_jsonl_file = "/path/to/save/merged_output.jsonl"  # Input JSONL file from previous step
embeddings_file = "/path/to/save/embeddings.npy"  # Output NumPy file
checkpoint_file = "/path/to/save/checkpoint.txt"  # Checkpoint file
batch_size = 64  # Adjust based on your GPU memory

# --- Function to process batches ---
def process_batch(batch_texts_instructions, batch_texts_responses, autoencoder_model):
    if not batch_texts_instructions or not batch_texts_responses:
        return [], []
    try:
        # Ensure the model's embed_batch method handles lists of texts
        instr_embeddings = autoencoder_model.embed_batch(batch_texts_instructions)
        resp_embeddings = autoencoder_model.embed_batch(batch_texts_responses)
        # Move embeddings to CPU before converting to NumPy
        return instr_embeddings.cpu().numpy(), resp_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error processing batch: {e}")
        # Optionally, try processing item by item as a fallback
        instr_np, resp_np = [], []
        for instr, resp in zip(batch_texts_instructions, batch_texts_responses):
            try:
                instr_emb = autoencoder_model.embed(instr).cpu().numpy()
                resp_emb = autoencoder_model.embed(resp).cpu().numpy()
                instr_np.append(instr_emb)
                resp_np.append(resp_emb)
            except Exception as item_e:
                print(f"Error processing item '{instr[:50]}...': {item_e}")
        return np.array(instr_np) if instr_np else [], np.array(resp_np) if resp_np else []

# --- Main Embedding Generation Logic ---
try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines in file: {total_lines}")
except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
    exit()

if os.path.exists(checkpoint_file):
    with open(checkpoint_file, "r") as f:
        last_processed_line = int(f.read().strip())
    print(f"Resuming from line: {last_processed_line + 1}")
else:
    last_processed_line = 0
    print("No checkpoint found. Starting from the beginning.")

if os.path.exists(embeddings_file) and last_processed_line > 0:
    # Load only if resuming and file exists
    try:
        existing_embeddings = np.load(embeddings_file)
        # Ensure we only load embeddings corresponding to processed lines
        num_expected_embeddings = last_processed_line * 2
        if existing_embeddings.shape[0] >= num_expected_embeddings:
            embeddings_list = list(existing_embeddings[:num_expected_embeddings])
            print(f"Loaded {len(embeddings_list)} embeddings from previous run (up to line {last_processed_line}).")
        else:
            print("Warning: Embedding file size doesn't match checkpoint. Starting embeddings list fresh.")
            embeddings_list = []
            last_processed_line = 0  # Reset checkpoint if mismatch
    except Exception as e:
        print(f"Error loading existing embeddings: {e}. Starting fresh.")
        embeddings_list = []
        last_processed_line = 0  # Reset checkpoint on load error
else:
    embeddings_list = []
    if last_processed_line > 0:
        print("Warning: Checkpoint found but no embeddings file. Resetting checkpoint.")
        last_processed_line = 0  # Reset checkpoint if no embedding file

overall_line_count = last_processed_line
batch_instructions: list[str] = []
batch_responses: list[str] = []

try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        # Skip lines up to the checkpoint
        for _ in range(last_processed_line):
            next(f)
        # Process remaining lines with tqdm
        pbar = tqdm(f, total=total_lines, initial=last_processed_line, unit="line", desc="Processing lines")
        for line in pbar:
            overall_line_count += 1
            try:
                obj = json.loads(line)
                if "instruction" in obj and "response" in obj:
                    batch_instructions.append(obj["instruction"])
                    batch_responses.append(obj["response"])
                else:
                    print(f"Skipping line {overall_line_count}: Missing 'instruction' or 'response'.")
                    continue  # Skip this line
                # Process batch when full
                if len(batch_instructions) == batch_size:
                    instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
                    # Add pairs to the list
                    for i in range(len(instr_np)):
                        embeddings_list.append(instr_np[i])
                        embeddings_list.append(resp_np[i])
                    # Clear batches
                    batch_instructions = []
                    batch_responses = []
                # Checkpoint saving logic
                if overall_line_count % checkpoint_interval == 0:
                    # Only count lines whose embeddings are already stored; items still
                    # waiting in the current batch are excluded so they are re-read
                    # (not lost) if the run resumes from this checkpoint.
                    current_line_to_save = overall_line_count - len(batch_instructions)
                    with open(checkpoint_file, "w") as cf:
                        cf.write(str(current_line_to_save))
                    np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
                    pbar.set_postfix_str(f"Checkpoint saved at line {current_line_to_save}")
            except json.JSONDecodeError:
                print(f"Skipping line {overall_line_count}: Invalid JSON.")
            except Exception as e:
                print(f"Error on line {overall_line_count}: {e}")

    # Process any remaining items in the last batch
    if batch_instructions:
        instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
        for i in range(len(instr_np)):
            embeddings_list.append(instr_np[i])
            embeddings_list.append(resp_np[i])

    # Final save
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count))
    np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    print(f"Processing complete. Final line count: {overall_line_count}/{total_lines}. Embeddings saved.")
except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    # Save progress even on error
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count - len(batch_instructions)))  # Save last fully processed line
    if embeddings_list:
        np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    print("Saved progress before exiting due to error.")
```
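Before moving on to training, it can help to verify that the saved file matches the layout the dataset class in the next section assumes (`[instr1, resp1, instr2, resp2, ...]` with 1024-dimensional rows). This check is an optional sketch.

```python
# Optional check of the saved embeddings file.
emb = np.load(embeddings_file, mmap_mode='r')
print("Embeddings array shape:", emb.shape)  # Expected: (2 * num_pairs, 1024)
assert emb.ndim == 2 and emb.shape[0] % 2 == 0, "Expected an even number of rows (instruction/response pairs)"
print("Number of instruction/response pairs:", emb.shape[0] // 2)
```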
### 3. Latent Manipulator Model Code

This defines the neural network that learns to map question embeddings to answer embeddings. The architecture shown uses multiple feed-forward layers with skip-connection-like features (concatenating intermediate "choked" outputs).

* **Input:** 1024-dimensional latent vector (from the Encoder).
* **Architecture:** A series of Linear layers, BatchNorm, LeakyReLU, and Dropout. Intermediate outputs are "choked" (reduced) to 2048 dimensions and concatenated with the original input before the final aggregation layers. This structure aims to handle deep transformations while mitigating vanishing/exploding gradients.
* **Output:** 1024-dimensional latent vector (to be fed to the Decoder).
*(Original code, lightly adjusted for clarity and comments.)*
```python
# --- Latent Manipulator Model Definition ---
class LatentManipulator(nn.Module):
    """
    A Feed-Forward Network designed to manipulate latent embeddings.
    Takes a 1024-dim embedding and outputs a 1024-dim embedding.
    Uses intermediate layer outputs (choked) and concatenation for richness.
    """
    def __init__(self, dropout_rate=0.2):  # Reduced dropout from original
        super(LatentManipulator, self).__init__()
        # --- Main Layers (Expand -> Contract) ---
        self.layer1 = nn.Sequential(nn.Linear(1024, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer2 = nn.Sequential(nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer3 = nn.Sequential(nn.Linear(4096, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer4 = nn.Sequential(nn.Linear(6144, 9216), nn.BatchNorm1d(9216), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))  # Widest layer
        self.layer5 = nn.Sequential(nn.Linear(9216, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer6 = nn.Sequential(nn.Linear(6144, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer7 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Choke Layers (Reduce intermediate outputs to 2048) ---
        # These act like shortcuts, bringing information from earlier layers forward.
        self.choke1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke2 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke3 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke4 = nn.Sequential(nn.Linear(9216, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke5 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke6 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke7 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Aggregation Layers (Combine concatenated features) ---
        # Total input size = 1024 (original input) + 7 * 2048 (choked outputs) = 15360
        self.aLayer1 = nn.Sequential(nn.Linear(15360, 8192), nn.BatchNorm1d(8192), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.aLayer2 = nn.Sequential(nn.Linear(8192, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        # --- Final Output Layer ---
        self.output_layer = nn.Linear(4096, 1024)  # Output matches input dimension

    def forward(self, x):
        # Pass through main layers
        x1 = self.layer1(x); x2 = self.layer2(x1); x3 = self.layer3(x2)
        x4 = self.layer4(x3); x5 = self.layer5(x4); x6 = self.layer6(x5)
        x7 = self.layer7(x6)
        # Apply choke layers
        c1 = self.choke1(x1); c2 = self.choke2(x2); c3 = self.choke3(x3)
        c4 = self.choke4(x4); c5 = self.choke5(x5); c6 = self.choke6(x6)
        c7 = self.choke7(x7)
        # Concatenate original input and all choked outputs
        concat = torch.cat([x, c1, c2, c3, c4, c5, c6, c7], dim=1)  # Dim 1 for batch processing
        # Pass through aggregation layers
        out = self.aLayer1(concat)
        out = self.aLayer2(out)
        out = self.output_layer(out)
        return out

# --- Helper: Weight Initialization ---
def init_weights(m):
    """Applies Kaiming Normal initialization for LeakyReLU."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# --- Helper: Count Parameters ---
def count_parameters(model):
    """Counts the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```
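A quick way to validate the architecture before training is to push a small random batch through an untrained model and check the output shape and parameter count. This is an optional sketch; `eval()` is used so BatchNorm does not require batch statistics.

```python
# Shape and parameter-count check for the LatentManipulator.
model = LatentManipulator(dropout_rate=0.2)
model.apply(init_weights)
model.eval()  # BatchNorm uses running stats in eval mode, so any batch size works
dummy = torch.randn(4, 1024)  # A batch of 4 fake "question" latents
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # Expected: torch.Size([4, 1024])
print(f"Trainable parameters: {count_parameters(model):,}")
```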
### 4. Training the Latent Manipulator

This involves setting up a dataset to efficiently load the pre-computed embeddings and a training loop.

#### a) Dataset Class

Uses NumPy's memory mapping (`mmap_mode='r'`) to avoid loading the entire (potentially huge) embeddings file into RAM. It loads only the required question-answer pair for each `__getitem__` call.
```python
# --- NumPy Embedding Dataset ---
class NPYEmbeddingDataset(Dataset):
    """
    Lazily loads pairs of embeddings (instruction, response) from a NumPy file
    using memory mapping for efficiency with large files.
    Assumes embeddings are stored sequentially: [instr1, resp1, instr2, resp2, ...].
    """
    def __init__(self, npy_file):
        self.npy_file = npy_file
        # Open with mmap_mode to read shape/dtype without loading all data.
        # Note: np.load with mmap_mode returns a memmap array, not a context manager.
        try:
            data = np.load(npy_file, mmap_mode='r')
            self.shape = data.shape
            self.dtype = data.dtype
        except FileNotFoundError:
            print(f"Error: NPY embedding file not found at {npy_file}")
            raise
        except Exception as e:
            print(f"Error loading NPY file: {e}")
            raise
        if len(self.shape) != 2 or self.shape[0] % 2 != 0:
            raise ValueError(f"Expected a 2D numpy array with an even number of rows (embeddings). Got shape: {self.shape}")
        self.num_pairs = self.shape[0] // 2
        self.embedding_dim = self.shape[1]
        print(f"Dataset initialized: {self.num_pairs} pairs, embedding dim {self.embedding_dim}")

    def __len__(self):
        """Returns the number of instruction-response pairs."""
        return self.num_pairs

    def __getitem__(self, idx):
        """Loads the idx-th instruction and response embedding pair."""
        if idx >= self.num_pairs:
            raise IndexError("Index out of bounds")
        # Load with mmap_mode again inside __getitem__ for multi-process loading safety.
        # This ensures each worker gets its own file handle if num_workers > 0.
        data = np.load(self.npy_file, mmap_mode='r')
        # Calculate row indices for the pair
        q_idx = idx * 2
        a_idx = idx * 2 + 1
        # Extract the embeddings and convert to tensors
        q_emb = torch.from_numpy(data[q_idx].copy()).float()  # Use .copy() with mmap
        a_emb = torch.from_numpy(data[a_idx].copy()).float()
        return q_emb, a_emb
```
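As an optional smoke test (a sketch; adjust the path to your own embeddings file), you can index the dataset directly and confirm each item is a pair of 1024-dimensional tensors.

```python
# Optional smoke test for NPYEmbeddingDataset.
dataset = NPYEmbeddingDataset("/path/to/save/embeddings.npy")
q_emb, a_emb = dataset[0]
print(q_emb.shape, a_emb.shape)  # Expected: torch.Size([1024]) torch.Size([1024])
print(len(dataset), "pairs available")
```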
#### b) Training Loop

A standard PyTorch training loop using the defined dataset and model. Key features:

* **Loss:** Mean Squared Error (`MSELoss`), because we are comparing output embeddings to target embeddings (a regression problem).
* **Optimizer:** AdamW is a good default choice.
* **Learning Rate Scheduling:** `ReduceLROnPlateau` reduces the learning rate when the average training loss stagnates. A *warmup* phase is also added to start with a lower LR and gradually increase it, improving stability early in training.
* **Gradient Clipping:** Prevents exploding gradients, which is crucial for deep networks.
* **Checkpointing:** Saves the model state periodically, especially when the loss improves.
```python
# --- Utility to save checkpoints ---
def save_checkpoint(state, filename="checkpoint.pt"):
    """Saves model and optimizer state."""
    try:
        torch.save(state, filename)
        print(f"Checkpoint saved to {filename}")
    except Exception as e:
        print(f"Error saving checkpoint: {e}")

# --- Training Function ---
def train(model, dataloader, epochs=10, base_lr=1e-4, warmup_epochs=1, clip_value=5.0, device=None, checkpoint_dir="checkpoints"):
    """Trains the LatentManipulator model."""
    if device is None:
        if torch.cuda.is_available(): device = torch.device("cuda")
        elif torch.backends.mps.is_available(): device = torch.device("mps")
        else: device = torch.device("cpu")
    print(f"Training on device: {device}")
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.01)
    criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)  # More aggressive reduction factor
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float('inf')
    print(f"Starting training for {epochs} epochs...")

    for epoch in range(epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0
        # Learning Rate Warmup
        if epoch < warmup_epochs:
            lr = base_lr * (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            current_lr = lr
        else:
            # Get current LR from the optimizer after warmup, potentially adjusted by the scheduler
            current_lr = optimizer.param_groups[0]['lr']
        print(f"\n--- Epoch {epoch+1}/{epochs} --- LR: {current_lr:.6f}")

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} Training", leave=True)
        for batch_idx, (q_emb, a_emb) in enumerate(pbar):
            q_emb, a_emb = q_emb.to(device), a_emb.to(device)
            optimizer.zero_grad()
            outputs = model(q_emb)
            loss = criterion(outputs, a_emb)
            # Check for NaN loss
            if torch.isnan(loss):
                print(f"NaN loss detected at Epoch {epoch+1}, Batch {batch_idx}. Stopping training.")
                # Optionally save state before exiting
                save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'avg_loss': float('inf')}, os.path.join(checkpoint_dir, "checkpoint_error_nan.pt"))
                return  # Stop training
            loss.backward()
            # Gradient Clipping
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            running_loss += loss.item()
            # Update progress bar
            if (batch_idx + 1) % 100 == 0:  # Update less frequently
                pbar.set_postfix_str(f"Loss: {loss.item():.4f}, GradNorm: {grad_norm:.4f}")

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")
        # Step the scheduler based on the average loss for the epoch
        scheduler.step(avg_loss)
        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            print(f"Loss improved from {best_loss:.6f} to {avg_loss:.6f}. Saving checkpoint...")
            best_loss = avg_loss
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}_best.pt")
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'avg_loss': avg_loss
            }, filename=checkpoint_path)
        else:
            print(f"Loss did not improve from {best_loss:.6f}.")
        # Optional: Save checkpoint every epoch regardless
        # checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
        # save_checkpoint({ ... }, filename=checkpoint_path)

    print("\nTraining complete.")

# --- Main Training Execution Block ---
if __name__ == "__main__":
    npy_file = "/path/to/save/embeddings.npy"  # Make sure this path is correct
    checkpoint_dir = "latent_manipulator_checkpoints"  # Directory to save model checkpoints
    try:
        dataset = NPYEmbeddingDataset(npy_file)
        # Adjust batch_size and num_workers based on your system
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True if device == 'cuda' else False)
        model = LatentManipulator(dropout_rate=0.2)  # Instantiate the model
        model.apply(init_weights)  # Initialize weights
        print("Model Architecture:\n", model)
        total_params = count_parameters(model)
        print(f"\nTotal Trainable Parameters: {total_params:,}")
        # Start training
        train(model, dataloader, epochs=10, base_lr=1e-4, clip_value=5.0, device=device, checkpoint_dir=checkpoint_dir)
    except FileNotFoundError:
        print(f"Error: Embeddings file not found at {npy_file}. Please generate embeddings first.")
    except Exception as e:
        print(f"An error occurred during training setup or execution: {e}")
```
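Note that `train()` above always starts with a fresh optimizer. If you want to resume an interrupted run instead, a sketch along the following lines (the checkpoint path is a hypothetical example) reloads both the model and optimizer state saved by `save_checkpoint()` before continuing.

```python
# Sketch: resume from a saved checkpoint (illustrative path, not a file shipped with this guide).
resume_path = "latent_manipulator_checkpoints/checkpoint_epoch_5_best.pt"
checkpoint = torch.load(resume_path, map_location="cpu")

model = LatentManipulator(dropout_rate=0.2)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # State is moved to the params' device on load

start_epoch = checkpoint.get('epoch', 0)
print(f"Resuming after epoch {start_epoch} (avg loss {checkpoint.get('avg_loss', float('nan')):.6f})")
# From here, continue training, e.g. by adapting train() to accept an existing optimizer and a start epoch.
```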
### 5. Inference: Using the Trained Model

To get an answer for a new question:

1. Load the trained `LatentManipulator` model from a checkpoint.
2. Load the `BottleneckT5Autoencoder` (needed for encoding the question and decoding the answer).
3. Encode the input question text into its latent vector using the autoencoder.
4. Pass this latent vector through the loaded `LatentManipulator` to get the predicted latent vector for the answer.
5. Decode this answer latent vector back into text using the autoencoder.
```python
# --- Inference Script ---
# Make sure latent_manipulator.py (containing the model definition) is accessible
# and you have the BottleneckT5Autoencoder class defined/imported

# --- Function to load the Latent Manipulator model ---
def load_manipulator(checkpoint_path, device):
    """Loads the trained LatentManipulator from a checkpoint file."""
    # Instantiate the model architecture (ensure dropout_rate matches training)
    model = LatentManipulator(dropout_rate=0.2)
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()  # Set to evaluation mode
        avg_loss = checkpoint.get('avg_loss', float('nan'))
        print(f"Loaded LatentManipulator from epoch {checkpoint.get('epoch', 'N/A')} with loss {avg_loss:.6f}")
        return model
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        raise
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise

# --- Main Inference Execution ---
if __name__ == "__main__":
    # --- Configuration ---
    autoencoder_model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
    manipulator_checkpoint_path = "latent_manipulator_checkpoints/checkpoint_epoch_10_best.pt"  # Path to your best saved checkpoint

    # Determine device
    if torch.cuda.is_available(): device = 'cuda'
    elif torch.backends.mps.is_available(): device = 'mps'
    else: device = 'cpu'
    print(f"Using device for inference: {device}")

    # --- Load Models ---
    try:
        # Load the Autoencoder (needed for embed/generate)
        autoencoder = BottleneckT5Autoencoder(model_path=autoencoder_model_path, device=device)
        # Load the trained Latent Manipulator
        manipulator_model = load_manipulator(manipulator_checkpoint_path, device)
    except Exception as e:
        print(f"Failed to load models: {e}")
        exit()

    # --- Get Input and Generate ---
    while True:
        try:
            input_text = input("Enter your question (or type 'quit' to exit): ")
            if input_text.lower() == 'quit':
                break
            if not input_text:
                continue
            # 1. Encode the input question
            input_embedding_latent = autoencoder.embed(input_text)  # Shape [1, 1024]
            # Ensure it's on the correct device (embed should handle this, but double-check)
            input_embedding_latent = input_embedding_latent.to(device)
            # 2. Manipulate the latent vector to get the answer latent
            with torch.no_grad():
                output_embedding_latent = manipulator_model(input_embedding_latent)  # Shape [1, 1024]
            # 3. Decode the answer latent back to text
            # The autoencoder's generate_from_latent handles device placement internally.
            output_text = autoencoder.generate_from_latent(output_embedding_latent, temperature=0.5)  # Adjust temperature as needed
            print("\nInput: ", input_text)
            print("Output: ", output_text)
            print("-" * 30)
        except KeyboardInterrupt:
            print("\nExiting.")
            break
        except Exception as e:
            print(f"An error occurred during generation: {e}")
```
## Resources

* **Autoencoder Model:** [thesephist/contra-bottleneck-t5-large-wikipedia](https://huggingface.co/thesephist/contra-bottleneck-t5-large-wikipedia) on Hugging Face.
* **Latent Manipulator Checkpoint & Data:** [Gal-Lahat/LatentManipulator-checkpoint_epoch_10](https://huggingface.co/Gal-Lahat/LatentManipulator-checkpoint_epoch_10.pt) on Hugging Face (includes the trained manipulator checkpoint and the generated `embeddings.npy`).
* **Raw Training Data (Example):** [MBZUAI/LaMini-instruction](https://huggingface.co/datasets/MBZUAI/LaMini-instruction) dataset.
## Conclusion

The Latent Manipulator presents an intriguing alternative to standard sequential text generation. By separating the "thinking" (latent-space transformation) from the "speaking" (text decoding), it opens up possibilities for more efficient, controllable, and perhaps even language-agnostic reasoning in AI models. While still experimental, this approach highlights the ongoing exploration of different ways AI can process and generate information.
This article was compiled with the help of an LLM. |