Skip to content

Instantly share code, notes, and snippets.

@lastforkbender
Created May 21, 2026 15:40
Show Gist options
  • Select an option

  • Save lastforkbender/273dd46ac5168c98c25e854f21cccdb2 to your computer and use it in GitHub Desktop.

Select an option

Save lastforkbender/273dd46ac5168c98c25e854f21cccdb2 to your computer and use it in GitHub Desktop.
Stedman NN
# stedman_improved.py
# Enhanced production Stedman with adaptive polynomial composition
# via cross-recursive dynamics of rotational node strengths
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
# Run on the current CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): DTYPE is not referenced anywhere else in this file.
DTYPE = torch.float32
# -------------------------
# Utilities
# -------------------------
def sample_gumbel(shape, device, eps=1e-20):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``
    drawn from the global torch RNG. ``eps`` is added inside both logarithms
    so neither is evaluated at exactly zero.
    """
    uniform = torch.rand(shape, device=device)
    inner = torch.log(uniform + eps).neg().add(eps)
    return torch.log(inner).neg()
def gumbel_softmax_sample(logits, temp=1.0):
    """Relaxed (Gumbel-softmax) sample over the last dimension of ``logits``.

    Adds i.i.d. Gumbel noise to ``logits``, divides by ``temp`` (a Python
    number or a tensor), and applies a softmax along the last dimension.
    """
    perturbed = logits + sample_gumbel(logits.shape, logits.device)
    return F.softmax(perturbed / temp, dim=-1)
def cayley_orthogonal_from_skew(S, rcond=1e-5, max_iters=3, tol=1e-4):
    """
    Map skew-symmetric matrices to orthogonal matrices via the Cayley transform.

    S: [..., d, d] skew-symmetric
    Cayley: Q = (I - S)^{-1} (I + S)

    For an exactly skew-symmetric real S, I - S is always invertible (the
    eigenvalues of S are purely imaginary), so the direct solve is the normal
    path and its result is returned as-is. If the solve raises anyway, a
    least-squares solve (using ``rcond``) is used instead. That solution is
    polished with up to ``max_iters`` steps of iterative refinement, stopping
    once the max absolute residual is below ``tol``.

    CUDA autocast is disabled inside this function. Without that, residual
    matmuls would run in float16, whose rounding error (~1e-3) exceeds ``tol``.
    Refinement would then fit that rounding noise into the result.
    """
    import contextlib

    amp_off = (
        torch.autocast(device_type="cuda", enabled=False)
        if S.is_cuda
        else contextlib.nullcontext()
    )
    with amp_off:
        d = S.size(-1)
        eye = torch.eye(d, device=S.device, dtype=S.dtype).expand_as(S)
        I_minus_S = eye - S
        I_plus_S = eye + S
        try:
            # Primary: direct solve (exact up to working precision).
            return torch.linalg.solve(I_minus_S, I_plus_S)
        except RuntimeError:
            pass
        # Fallback for (numerically) singular systems: least squares + polish.
        X = torch.linalg.lstsq(I_minus_S, I_plus_S, rcond=rcond).solution
        for _ in range(max_iters):
            residual = I_minus_S @ X - I_plus_S
            if residual.abs().max() < tol:
                break
            delta = torch.linalg.lstsq(I_minus_S, residual, rcond=rcond).solution
            X = X - delta
        return X
def make_skew_from_params(P):
    """Return ``P - P^T`` over the last two dims (twice P's skew-symmetric part).

    P: [..., d, d] -> S with S == -S^T.
    """
    return torch.sub(P, torch.transpose(P, -2, -1))
def compute_spectral_norm_batch(M, n_power_iters=2):
    """
    Compute approximate spectral norm (largest singular value) via power iteration.

    M: [B, m, n] batch of matrices, or a single [m, n] matrix.
    n_power_iters: number of alternating M / M^T multiplications. With 0, the
        estimate is |M v| for a random unit vector v.
    Returns: [B] tensor for 3-D input (including B == 1); 0-d tensor for 2-D input.

    The start vector is drawn from the global torch RNG, so results vary
    between calls unless the RNG is seeded.
    """
    # Squeeze the output only when the caller passed a single 2-D matrix, so a
    # genuine batch of one keeps its [1] shape.
    single = M.dim() == 2
    if single:
        M = M.unsqueeze(0)
    B, _, n = M.shape
    v = torch.randn(B, n, 1, device=M.device, dtype=M.dtype)
    v = v / (v.norm(dim=1, keepdim=True) + 1e-8)
    u = None
    for _ in range(n_power_iters):
        u = torch.bmm(M, v)
        u = u / (u.norm(dim=1, keepdim=True) + 1e-8)
        v = torch.bmm(M.transpose(1, 2), u)
        v = v / (v.norm(dim=1, keepdim=True) + 1e-8)
    if u is None:
        # No iterations requested: use the normalised image of the start vector.
        u = torch.bmm(M, v)
        u = u / (u.norm(dim=1, keepdim=True) + 1e-8)
    # (M v) . u with unit u, v is bounded by the top singular value.
    sv = (torch.bmm(M, v) * u).sum(dim=(1, 2))
    return sv[0] if single else sv
# -------------------------
# Batched randomized SVD (GPU-friendly)
# -------------------------
def batched_randomized_range(A, rank, n_iter=1):
    """
    Randomized orthonormal basis for the dominant right-singular subspace of A.

    A: [N, d] single 2-D matrix (despite the name, no batch dimension).
    rank: number of basis vectors requested.
    n_iter: number of power (subspace) iterations. 0 orthonormalises the
        Gaussian test matrix directly.
    Returns Q: [d, min(d, rank)] with orthonormal columns, computed as
        QR((A^T A)^n_iter @ P) for a Gaussian P drawn from the global RNG.

    There is no re-orthonormalisation between power iterations, so large
    ``n_iter`` can lose accuracy in the weaker directions.

    CUDA autocast is disabled here, and float16/bfloat16 inputs are upcast to
    float32, because ``torch.linalg.qr`` does not accept half precision and
    the unnormalised power iterations overflow float16. The returned Q then
    has dtype float32.
    """
    import contextlib

    _, d = A.shape
    amp_off = (
        torch.autocast(device_type="cuda", enabled=False)
        if A.is_cuda
        else contextlib.nullcontext()
    )
    with amp_off:
        work = A.float() if A.dtype in (torch.float16, torch.bfloat16) else A
        P = torch.randn(d, rank, device=work.device, dtype=work.dtype)
        # Z_k = (A^T A)^k P; starting from Z_0 = P makes n_iter == 0 well defined.
        Z = P
        for _ in range(n_iter):
            Z = work.T @ (work @ Z)
        Q, _ = torch.linalg.qr(Z)
    return Q[:, :rank]
class BatchedRandSVD(nn.Module):
    """Parameter-free randomized projection of a 2-D activation matrix.

    forward(A), with A of shape [B, d], returns ``(z, sv, Q)``:
      Q  -- orthonormal basis from ``batched_randomized_range`` (``rank`` columns
            at most)
      z  -- projected coordinates ``A @ Q``
      sv -- per-column norm of z divided by sqrt(max(1, B)), i.e. the RMS of
            each coordinate. It is used downstream as a cheap stand-in for
            singular values, not an exact SVD.
    """

    def __init__(self, rank=64, n_iter=2):
        super().__init__()
        self.rank, self.n_iter = rank, n_iter

    def forward(self, A):
        rows, _cols = A.shape
        basis = batched_randomized_range(A, self.rank, self.n_iter)
        coords = A @ basis
        rms = coords.norm(dim=0) / math.sqrt(max(1, rows))
        return coords, rms, basis
# -------------------------
# Cross-Recursive Polynomial Compositor (CRPC)
# Leverages rotational node strengths for adaptive polynomial expansion
# -------------------------
class CrossRecursivePolyCompositor(nn.Module):
    """
    Compose per-child readouts and recursive polynomial features into a prediction.

    Recursively refines latent codes z using:
    1. Node strengths (approximate spectral norms of the children's transforms)
    2. Spawn probabilities (which children are active)
    3. Gated polynomial expansions of the recursively refined latent code

    The input to ``final_compose`` is laid out as
        [y_init] + [z_d, z_d**2, ..., z_d**n]  for d = 0 .. recurse_depth - 1
    where z_0 = z, z_{d+1} = refine_layers[d]([z_d, y_init]) and
    n = max(1, degree). That gives ``y_dim + recurse_depth * rank * max(1, degree)``
    features. Constant (all-ones) polynomial terms are omitted because the
    first Linear of ``final_compose`` already has a bias.

    NOTE: if every child transform is exactly orthogonal (as a Cayley
    transform produces), all spectral norms equal 1. The node-strength signal
    is then constant across children, up to power-iteration error.
    """

    def __init__(self, rank, y_dim, k_max, degree=2, hidden=None, recurse_depth=2):
        super().__init__()
        self.rank = rank
        self.y_dim = y_dim
        self.k_max = k_max
        self.degree = degree
        self.recurse_depth = recurse_depth
        # Non-constant polynomial terms (z, z**2, ...) per depth; degree <= 1
        # still contributes the linear term z.
        self.n_poly_terms = max(1, degree)
        # Per-child readouts: one small MLP per child, all applied to z.
        self.child_poly_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(rank, hidden or rank),
                nn.GELU(),
                nn.Linear(hidden or rank, y_dim),
            )
            for _ in range(k_max)
        ])
        # Maps each child's scalar node strength to a gate in (0, 1).
        self.node_strength_gate = nn.Sequential(
            nn.Linear(1, 16),
            nn.GELU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )
        # Recursive refinement layers: [z_d, y_init] -> z_{d+1}
        self.refine_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(rank + y_dim, rank),
                nn.LayerNorm(rank),
                nn.GELU(),
            )
            for _ in range(recurse_depth)
        ])
        # Final composition. in_feat must equal the width of the tensor that
        # forward() concatenates (see the class docstring for the layout).
        in_feat = y_dim + recurse_depth * rank * self.n_poly_terms
        self.final_compose = nn.Sequential(
            nn.Linear(in_feat, hidden or 128),
            nn.LayerNorm(hidden or 128),
            nn.GELU(),
            nn.Linear(hidden or 128, y_dim),
        )

    def forward(self, z, O, p_on, sv):
        """
        z: [B, r] latent codes from randomized SVD (r is expected to equal ``rank``)
        O: [k, d_h, d_h] child transforms (k is expected to equal ``k_max``)
        p_on: [B, k] spawn probabilities
        sv: [r] singular-value proxies; accepted for interface compatibility
            but not used by this implementation
        Returns: y_refined [B, y_dim], composition_stats dict
        """
        # Node strengths: approximate spectral norm per child, normalised so the
        # strongest child is ~1. reshape(-1) keeps a [k] vector even when k == 1.
        node_strengths = compute_spectral_norm_batch(O, n_power_iters=2).reshape(-1)
        node_strengths = node_strengths / (node_strengths.max() + 1e-8)
        gated_strengths = self.node_strength_gate(
            node_strengths.unsqueeze(-1)
        ).squeeze(-1)  # [k]

        # Per-sample child weights: spawn probability times gated strength,
        # normalised to sum to ~1 over children.
        child_weights = p_on * gated_strengths.unsqueeze(0)  # [B, k]
        child_weights = child_weights / (child_weights.sum(dim=1, keepdim=True) + 1e-8)

        child_outputs = torch.stack(
            [layer(z) for layer in self.child_poly_layers], dim=1
        )  # [B, k, y_dim]
        y_init = (child_weights.unsqueeze(-1) * child_outputs).sum(dim=1)  # [B, y_dim]

        # Recursive refinement: collect polynomial terms of z_d, then refine.
        # NOTE(review): features are taken before each refinement, so the last
        # refine layer's output only feeds the "z_refined_norm" stat and its
        # parameters get no gradient from y_refined.
        z_refined = z
        poly_features = []
        for refine in self.refine_layers:
            term = z_refined
            poly_features.append(term)
            for _ in range(1, self.n_poly_terms):
                term = term * z_refined
                poly_features.append(term)
            z_refined = refine(torch.cat([z_refined, y_init], dim=1))

        final_input = torch.cat([y_init, *poly_features], dim=1)
        y_refined = self.final_compose(final_input)

        stats = {
            "node_strengths": node_strengths,
            "gated_strengths": gated_strengths,
            "child_weights": child_weights,
            "y_init_norm": y_init.norm(dim=1).mean().item(),
            "z_refined_norm": z_refined.norm(dim=1).mean().item(),
        }
        return y_refined, stats
# -------------------------
# Adaptive Spawn Unit v2 (fixed & enhanced)
# -------------------------
class AdaptiveSpawnV2(nn.Module):
    """Adaptive spawn unit: gates k_max children, each a learned rotation of a projection.

    - ``gate_logits`` yields two logits per child. A Gumbel-softmax sample with
      the learnable temperature ``temp`` (clamped at >= 1e-3) gives ``p_on``
      as the probability of index 1 ("on").
    - With ``hard=True``, a straight-through estimator makes the forward value
      0/1 while gradients flow through the soft ``p_on``.
    - Each child transform is ``O_k = Cayley(P_k - P_k^T)`` built from
      ``child_params``. The children are ``base_proj(x) @ O_k * child_scale_k``.
    - The output is the ``p_on``-weighted sum of children, shape [B, d_hidden].

    NOTE: Gumbel noise is sampled on every call, including in eval mode (there
    is no ``self.training`` check).
    """

    def __init__(self, d_in, d_hidden, k_max=8, temp_init=0.5, hard=False):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.k_max = k_max
        self.temp = nn.Parameter(torch.tensor(float(temp_init)))
        self.hard = hard
        # Parameter creation order is fixed: it determines RNG consumption.
        self.gate_logits = nn.Linear(d_in, 2 * k_max)
        self.base_proj = nn.Linear(d_in, d_hidden, bias=True)
        self.child_params = nn.Parameter(0.01 * torch.randn(k_max, d_hidden, d_hidden))
        self.child_scale = nn.Parameter(torch.ones(k_max, 1))

    def forward(self, x):
        n = x.shape[0]
        two_way = self.gate_logits(x).view(n, self.k_max, 2)
        tau = torch.clamp(self.temp, min=1e-3)
        p_on = gumbel_softmax_sample(two_way, temp=tau)[..., 1]
        if self.hard:
            # Straight-through: forward uses the 0/1 mask, backward the soft p_on.
            hard_mask = (p_on > 0.5).float()
            p_on = (hard_mask - p_on).detach() + p_on

        base = self.base_proj(x)
        O = cayley_orthogonal_from_skew(make_skew_from_params(self.child_params))  # [k, d_h, d_h]
        children = torch.einsum("bd,kde->bke", base, O) * self.child_scale.unsqueeze(0)
        agg = (p_on.unsqueeze(-1) * children).sum(dim=1)

        stats = {
            "p_on_mean": p_on.mean(dim=0),
            "expected_spawns": p_on.sum(dim=1).mean(),
            "p_on": p_on,  # Expose for CRPC
            "O": O,  # Expose orthogonals for CRPC
        }
        return agg, stats
# -------------------------
# Metacritic (unchanged)
# -------------------------
class MetacriticProd(nn.Module):
    """Critic head scoring each prediction from the prediction plus batch-level context.

    Per-row input features: ``[y_hat (y_dim), sv (sv_dim), p_on_mean (k), expected_spawns (1)]``,
    so ``spawn_dim`` must equal k + 1. The sv and spawn statistics are
    broadcast identically to every row; they are not per-sample.
    ``spawn_stats["expected_spawns"]`` may be a tensor or a plain number.
    Returns a [B] tensor.
    """

    def __init__(self, y_dim, sv_dim, spawn_dim, hidden=256):
        super().__init__()
        half = hidden // 2
        self.net = nn.Sequential(
            nn.Linear(y_dim + sv_dim + spawn_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, half),
            nn.GELU(),
            nn.Linear(half, 1),
        )

    def forward(self, y_hat, sv, spawn_stats):
        rows = y_hat.shape[0]
        sv_rows = sv.unsqueeze(0).expand(rows, -1)
        mean_on = spawn_stats["p_on_mean"].to(y_hat.device)
        spawns = spawn_stats["expected_spawns"]
        if not isinstance(spawns, torch.Tensor):
            spawns = torch.tensor(spawns, device=y_hat.device, dtype=y_hat.dtype)
        summary = torch.cat([mean_on, spawns.unsqueeze(0)], dim=0)
        summary_rows = summary.unsqueeze(0).expand(rows, -1)
        features = torch.cat([y_hat, sv_rows, summary_rows], dim=1)
        scores = self.net(features)
        return scores.squeeze(-1)
# -------------------------
# Full Improved Model
# -------------------------
class StedmanImproved(nn.Module):
    """End-to-end model: backbone -> adaptive spawning -> randomized projection
    -> cross-recursive polynomial compositor -> metacritic.

    forward(x) returns ``(y_hat, c_hat, stats)``. ``stats`` is the spawn unit's
    stats dict updated in place with the compositor's stats.
    """

    def __init__(self, d_in, d_hidden, y_dim, k_max=8, rank=64, poly_deg=2, recurse_depth=2):
        super().__init__()
        # Submodule construction order is kept fixed: it determines how
        # parameter initialisation consumes the global RNG.
        self.backbone = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden),
            nn.LayerNorm(d_hidden),  # Added for stability
            nn.GELU(),
        )
        self.asu = AdaptiveSpawnV2(
            d_hidden, d_hidden, k_max=k_max, temp_init=0.8, hard=False
        )
        self.proj = BatchedRandSVD(rank=rank, n_iter=2)
        self.crpc = CrossRecursivePolyCompositor(
            rank=rank,
            y_dim=y_dim,
            k_max=k_max,
            degree=poly_deg,
            hidden=d_hidden // 2,
            recurse_depth=recurse_depth,
        )
        self.critic = MetacriticProd(
            y_dim=y_dim, sv_dim=rank, spawn_dim=k_max + 1, hidden=256
        )
        self.k_max = k_max

    def forward(self, x):
        hidden = self.backbone(x)
        spawned, stats = self.asu(hidden)
        z, sv, _basis = self.proj(spawned)
        y_hat, compositor_stats = self.crpc(z, stats["O"], stats["p_on"], sv)
        stats.update(compositor_stats)
        return y_hat, self.critic(y_hat, sv, stats), stats
# -------------------------
# Training utilities
# -------------------------
def spectral_penalty(module_list, coeff=1e-3):
    """Sum-of-squares penalty on the ``child_params`` of the given modules.

    Despite the name, this is ``coeff`` times the squared Frobenius norm of
    each module's ``child_params``, summed over modules; it is not a spectral
    norm. Modules without a (non-None) ``child_params`` attribute are skipped.

    Always returns a tensor, 0-d when no module contributes, so callers can
    use ``.item()`` / ``.backward()`` uniformly.
    """
    pen = 0.0
    for m in module_list:
        params = getattr(m, "child_params", None)
        if params is not None:
            pen = pen + (params.reshape(-1) ** 2).sum()
    if not torch.is_tensor(pen):
        pen = torch.tensor(pen)
    return coeff * pen
class LambdaScheduler:
    """Adaptive critic weight scheduling.

    - Linear warmup from ``lambda_init`` to ``lambda_max`` over ``warmup_epochs``.
    - Cosine annealing from ``lambda_max`` down to 0 at ``total_epochs``.
    - Epochs past ``total_epochs`` hold the final value (0).
    - When there is no annealing phase (``total_epochs <= warmup_epochs``),
      epochs after warmup hold ``lambda_max``.
    """

    def __init__(self, lambda_init=0.5, lambda_max=2.0, warmup_epochs=2, total_epochs=6):
        self.lambda_init = lambda_init
        self.lambda_max = lambda_max
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs

    def get_lambda(self, epoch):
        if epoch < self.warmup_epochs:
            # Linear warmup
            return self.lambda_init + (self.lambda_max - self.lambda_init) * epoch / self.warmup_epochs
        anneal_span = self.total_epochs - self.warmup_epochs
        if anneal_span <= 0:
            # No annealing phase: avoid dividing by zero and hold the peak.
            return self.lambda_max
        # Cosine annealing after warmup, clamped so it does not rise again
        # after total_epochs.
        progress = min(1.0, (epoch - self.warmup_epochs) / anneal_span)
        return self.lambda_max * (1 + math.cos(math.pi * progress)) / 2
# -------------------------
# Synthetic data
# -------------------------
def synthetic_data(n=20000, d=32, noise=0.08):
    """Generate a noisy regression problem.

    Draws, in order from the global torch RNG, X ~ N(0, 1) of shape [n, d],
    weights W of shape [d, 1], and Gaussian noise of shape [n, 1]. Returns
    ``(X, y)`` with ``y = X @ W + 0.6 * sin(0.7 * rowsum(X)) + noise * eps``
    of shape [n, 1]. W is redrawn on every call.
    """
    features = torch.randn(n, d)
    weights = torch.randn(d, 1)
    linear_part = features @ weights
    row_sums = features.sum(dim=1, keepdim=True)
    wave_part = 0.6 * torch.sin(row_sums * 0.7)
    targets = linear_part + wave_part + noise * torch.randn(n, 1)
    return features, targets
# -------------------------
# Training loop (improved)
# -------------------------
def _train_epoch(model, opt, scaler, X, y, *, bs, gamma, lambda_c, use_amp):
    """Run one shuffled pass over (X, y); return the per-sample mean of each loss term."""
    perm = torch.randperm(X.size(0))
    totals = {"task": 0.0, "critic": 0.0, "spawn": 0.0, "reg": 0.0, "total": 0.0}
    for start in range(0, X.size(0), bs):
        idx = perm[start:start + bs]
        xb, yb = X[idx], y[idx]
        with torch.cuda.amp.autocast(enabled=use_amp):
            y_hat, c_hat, stats = model(xb)
            # Task loss
            task_loss = F.mse_loss(y_hat, yb)
            # Critic loss: the critic predicts each sample's absolute error.
            # The target is detached so this term trains the critic only and
            # does not pull the predictions toward the critic's estimate.
            abs_err = (y_hat - yb).abs().squeeze(-1).detach()
            critic_loss = F.mse_loss(c_hat, abs_err)
            # Spawn cost (penalize high expected spawns)
            sp_cost = gamma * stats["expected_spawns"]
            # Regularization on the spawn unit's child parameters
            reg = spectral_penalty([model.asu], coeff=1e-6)
            # Total loss
            loss = task_loss + lambda_c * critic_loss + sp_cost + reg
        opt.zero_grad()
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
        scaler.step(opt)
        scaler.update()
        # Weight by batch size so the smaller final batch is averaged correctly.
        batch_size = xb.size(0)
        terms = {"task": task_loss, "critic": critic_loss, "spawn": sp_cost, "reg": reg, "total": loss}
        for key, value in terms.items():
            totals[key] += value.item() * batch_size
    n_samples = X.size(0)
    return {key: value / n_samples for key, value in totals.items()}


def _print_final_evaluation(model, X, y, k_max):
    """Print task/critic error and spawn/node statistics on the first 4096 samples."""
    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)
    with torch.no_grad():
        y_hat, c_hat, stats = model(X[:4096])
        abs_err = (y_hat - y[:4096]).abs().squeeze(-1)
        task_mse = F.mse_loss(y_hat, y[:4096]).item()
        critic_mse = F.mse_loss(c_hat, abs_err).item()
    print(f"Task MSE: {task_mse:.6f}")
    print(f"Critic MSE: {critic_mse:.6f}")
    print(f"Expected spawns: {stats['expected_spawns'].item():.2f} / {k_max}")
    print(f"Spawn probability: {stats['p_on_mean'].mean().item():.4f}")
    print(f"Node strengths: min={stats['node_strengths'].min():.4f}, "
          f"max={stats['node_strengths'].max():.4f}, "
          f"mean={stats['node_strengths'].mean():.4f}")
    print(f"Gated strengths: min={stats['gated_strengths'].min():.4f}, "
          f"max={stats['gated_strengths'].max():.4f}, "
          f"mean={stats['gated_strengths'].mean():.4f}")
    print(f"Latent norm (z): {stats['z_refined_norm']:.4f}")
    print(f"Output norm (y): {stats['y_init_norm']:.4f}")


def train_improved():
    """
    Train a StedmanImproved model on synthetic regression data and report metrics.

    Production training with:
    - Mixed precision (AMP), enabled only when DEVICE is CUDA
    - Gradient clipping
    - Learning rate scheduling (cosine annealing with warm restarts, stepped per epoch)
    - Dynamic critic weight (lambda scheduling)
    - Enhanced logging

    Returns:
        (model, epoch_losses, epoch_metrics)
        - epoch_losses maps "task"/"critic"/"spawn"/"reg"/"total" to lists of
          per-epoch, sample-weighted mean losses.
        - epoch_metrics maps "critic_mse"/"exp_spawns"/"temp" to lists of
          per-epoch values.
    """
    torch.manual_seed(1)
    d_in = 32
    d_hidden = 256
    y_dim = 1
    k_max = 8
    rank = 64
    model = StedmanImproved(
        d_in, d_hidden, y_dim,
        k_max=k_max, rank=rank, poly_deg=2, recurse_depth=2
    ).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
    # T_mult must be an integer >= 1 (PyTorch raises ValueError otherwise).
    # T_0=2, T_mult=2 gives cycles of 2 and 4 epochs, covering the 6 epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt, T_0=2, T_mult=2, eta_min=1e-5
    )
    # Enable AMP explicitly only on CUDA instead of relying on the
    # "CUDA not available" warning-and-disable path.
    use_amp = DEVICE.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    X, y = synthetic_data(n=40000, d=d_in)
    X, y = X.to(DEVICE), y.to(DEVICE)
    bs = 512
    epochs = 6
    gamma = 0.02  # spawn cost coefficient
    lambda_sched = LambdaScheduler(lambda_init=0.5, lambda_max=1.5, warmup_epochs=2, total_epochs=epochs)
    # Tracking metrics
    epoch_losses = {"task": [], "critic": [], "spawn": [], "reg": [], "total": []}
    epoch_metrics = {"critic_mse": [], "exp_spawns": [], "temp": []}
    print("=" * 80)
    print(f"Training StedmanImproved on {X.shape[0]} samples, d={d_in}, rank={rank}, k_max={k_max}")
    print("=" * 80)
    for ep in range(epochs):
        t0 = time.time()
        lambda_c = lambda_sched.get_lambda(ep)
        batch_losses = _train_epoch(
            model, opt, scaler, X, y,
            bs=bs, gamma=gamma, lambda_c=lambda_c, use_amp=use_amp,
        )
        scheduler.step()
        # Quick evaluation on the first 2048 training samples (not the full dataset).
        with torch.no_grad():
            y_hat_full, c_hat_full, stats_full = model(X[:2048])
            abs_err_full = (y_hat_full - y[:2048]).abs().squeeze(-1)
            critic_mse = F.mse_loss(c_hat_full, abs_err_full).item()
        # Log metrics
        for key in epoch_losses:
            epoch_losses[key].append(batch_losses[key])
        epoch_metrics["critic_mse"].append(critic_mse)
        epoch_metrics["exp_spawns"].append(stats_full["expected_spawns"].item())
        epoch_metrics["temp"].append(float(model.asu.temp))
        print(
            f"Epoch {ep+1:2d}/{epochs} | "
            f"Loss {batch_losses['total']:.6f} (task={batch_losses['task']:.4f}, "
            f"critic={batch_losses['critic']:.4f}, spawn={batch_losses['spawn']:.4f}) | "
            f"λ_c={lambda_c:.2f} | Temp={float(model.asu.temp):.3f} | "
            f"Spawns={stats_full['expected_spawns'].item():.2f} | "
            f"Time {time.time()-t0:.2f}s"
        )
    _print_final_evaluation(model, X, y, k_max)
    return model, epoch_losses, epoch_metrics
# -------------------------
# Analysis utilities
# -------------------------
def analyze_node_activation(model, X, y, device=DEVICE):
    """
    Print per-child (node) activation statistics for the first 1024 rows of X.

    For each of ``model.k_max`` children, reports the mean spawn probability,
    the normalised node strength, the gated strength, and the mean
    composition weight, followed by the overall mean spawn probability.

    ``y`` is accepted but not used. The model is switched to eval mode and
    left there.
    """
    model.eval()
    with torch.no_grad():
        _, _, stats = model(X[:1024].to(device))
    p_on = stats["p_on"]  # [B, k]
    strengths = stats["node_strengths"]
    gated = stats["gated_strengths"]
    weights = stats["child_weights"]

    rule = "=" * 80
    dash = "-" * 80
    print("\n" + rule)
    print("NODE ACTIVATION ANALYSIS")
    print(rule)
    print(f"{'Node':<6} {'P(On)':<8} {'Strength':<10} {'Gated':<8} {'Avg Weight':<12}")
    print(dash)
    for node in range(model.k_max):
        print(
            f"{node:<6} {p_on[:, node].mean().item():<8.4f} "
            f"{strengths[node].item():<10.4f} {gated[node].item():<8.4f} "
            f"{weights[:, node].mean().item():<12.6f}"
        )
    print(dash)
    print(f"{'TOTAL':<6} {p_on.mean().item():<8.4f}")
def visualize_training_curves(epoch_losses, epoch_metrics):
    """Print a first-vs-last summary of the recorded training curves.

    epoch_losses: dict with per-epoch sequences under "total" and "task".
    epoch_metrics: dict with per-epoch sequences under "critic_mse",
        "exp_spawns" and "temp".
    A curve with no recorded epochs is reported as " (no data)" instead of
    raising IndexError.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TRAINING CURVES SUMMARY")
    print(rule)
    sections = (
        ("Total Loss Progression (first → last):", epoch_losses, "total", ".6f"),
        ("Task Loss Progression:", epoch_losses, "task", ".6f"),
        ("Critic MSE Progression:", epoch_metrics, "critic_mse", ".6f"),
        ("Expected Spawns Progression:", epoch_metrics, "exp_spawns", ".2f"),
        ("Temperature Progression (annealing):", epoch_metrics, "temp", ".3f"),
    )
    for title, source, key, spec in sections:
        series = source[key]
        print("\n" + title)
        if len(series) == 0:
            print(" (no data)")
            continue
        print(f" {series[0]:{spec}} → {series[-1]:{spec}}")
# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    rule = "=" * 80
    print(f"Device: {DEVICE}")
    print(f"PyTorch version: {torch.__version__}")

    # Train
    model, epoch_losses, epoch_metrics = train_improved()

    # Analyze on a freshly generated dataset. synthetic_data draws a new weight
    # vector from the current global RNG state, so these targets differ from the
    # training targets; analyze_node_activation only uses X.
    X, y = synthetic_data(n=40000, d=32)
    analyze_node_activation(model, X, y)

    # Summarize the recorded curves
    visualize_training_curves(epoch_losses, epoch_metrics)

    print("\n" + rule)
    print("Training complete!")
    print(rule)
@lastforkbender
Copy link
Copy Markdown
Author

Never ever believe in nothing, especially if you are standing in front of absolutely no nothing at all and you are aware you are going to be allowed to go through the eye of the needle as a camel.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment