Created
May 21, 2026 15:40
-
-
Save lastforkbender/273dd46ac5168c98c25e854f21cccdb2 to your computer and use it in GitHub Desktop.
Stedman NN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # stedman_improved.py | |
| # Enhanced production Stedman with adaptive polynomial composition | |
| # via cross-recursive dynamics of rotational node strengths | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import time | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| DTYPE = torch.float32 | |
| # ------------------------- | |
| # Utilities | |
| # ------------------------- | |
def sample_gumbel(shape, device, eps=1e-20):
    """Draw Gumbel noise by transforming uniform samples.

    Args:
        shape: Shape of the noise tensor to create.
        device: Device on which the uniform samples are drawn.
        eps: Small constant added inside both logarithms so neither ever
            receives an exact zero.

    Returns:
        Tensor of the requested shape equal to ``-log(-log(U + eps) + eps)``
        where ``U`` comes from ``torch.rand``.
    """
    uniform = torch.rand(shape, device=device)
    neg_log_uniform = torch.log(uniform + eps).neg()
    return torch.log(neg_log_uniform + eps).neg()
def gumbel_softmax_sample(logits, temp=1.0):
    """Return a relaxed (soft) sample from the categorical given by ``logits``.

    Gumbel noise from ``sample_gumbel`` is added to the logits, the sum is
    divided by ``temp`` and a softmax is taken over the last dimension.

    Args:
        logits: Unnormalized scores; categories lie on the last dimension.
        temp: Temperature divisor (Python number or tensor).

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums to 1.
    """
    noise = sample_gumbel(logits.shape, logits.device)
    perturbed = logits + noise
    return F.softmax(perturbed / temp, dim=-1)
def cayley_orthogonal_from_skew(S, rcond=1e-5, max_iters=3):
    """Cayley transform of a skew-symmetric matrix (or batch of them).

    Solves ``(I - S) X = (I + S)``, i.e. ``X = (I - S)^{-1} (I + S)``.

    Args:
        S: ``[..., d, d]`` skew-symmetric input.
        rcond: Accepted for interface compatibility; not used by the body.
        max_iters: Maximum number of least-squares polish steps applied on
            the fallback path.

    Returns:
        Tensor ``X`` with the same shape as ``S``.

    The direct solve is tried first. If it raises ``RuntimeError`` the
    system is solved by least squares instead, followed by up to
    ``max_iters`` residual corrections that stop early once the largest
    absolute residual drops below ``1e-4``.
    """
    size = S.size(-1)
    eye = torch.eye(size, device=S.device, dtype=S.dtype).expand_as(S)
    lhs = eye - S
    rhs = eye + S

    try:
        # Primary path: exact linear solve.
        return torch.linalg.solve(lhs, rhs)
    except RuntimeError:
        # Fallback path: least squares, then a short residual-correction loop.
        solution = torch.linalg.lstsq(lhs, rhs).solution
        remaining = max_iters
        while remaining > 0:
            remaining -= 1
            residual = lhs @ solution - rhs
            if residual.abs().max() < 1e-4:
                break
            solution = solution - torch.linalg.lstsq(lhs, residual).solution
        return solution
def make_skew_from_params(P):
    """Build a skew-symmetric matrix ``P - P^T`` from ``P`` of shape ``[..., d, d]``.

    The transpose swaps the last two dimensions, so leading batch
    dimensions are preserved.
    """
    transposed = torch.transpose(P, -1, -2)
    return P - transposed
def compute_spectral_norm_batch(M, n_power_iters=2):
    """Approximate the largest singular value by power iteration.

    Args:
        M: Square matrix ``[d, d]`` or batch of square matrices ``[B, d, d]``.
        n_power_iters: Number of power-iteration rounds. With ``0`` the
            estimate is taken from the random start vector alone.

    Returns:
        A 0-d tensor when ``M`` was 2-D, otherwise a ``[B]`` tensor. A batch
        of one matrix returns shape ``[1]``, so callers can always index the
        result by matrix.

    The start vector is random (``torch.randn``), so the estimate is
    stochastic unless the caller seeds the generator.
    """
    single_matrix = M.dim() == 2
    if single_matrix:
        M = M.unsqueeze(0)
    B, d, _ = M.shape
    M_t = M.transpose(1, 2)

    def _unit(t):
        # Normalize each column vector in the batch; eps guards zero vectors.
        return t / (t.norm(dim=1, keepdim=True) + 1e-8)

    v = _unit(torch.randn(B, d, 1, device=M.device, dtype=M.dtype))
    u = None
    for _ in range(n_power_iters):
        u = _unit(torch.bmm(M, v))
        v = _unit(torch.bmm(M_t, u))
    if u is None:
        # No iterations requested: still need a left vector for the estimate.
        u = _unit(torch.bmm(M, v))

    # Rayleigh-style estimate u^T M v per batch element.
    sv = (torch.bmm(M, v) * u).sum(dim=(1, 2))
    return sv[0] if single_matrix else sv
| # ------------------------- | |
| # Batched randomized SVD (GPU-friendly) | |
| # ------------------------- | |
def batched_randomized_range(A, rank, n_iter=1):
    """Randomized orthonormal basis for the dominant row space of ``A``.

    Args:
        A: ``[N, d]`` data matrix.
        rank: Number of basis vectors requested.
        n_iter: Number of ``A^T A`` applications to the random probe. Values
            below 1 behave like 1, since at least one application is needed
            to obtain a data-dependent basis.

    Returns:
        ``Q`` with orthonormal columns, shape ``[d, min(d, rank)]``.

    Notes:
        * The probe is random (``torch.randn``); results vary between calls
          unless the generator is seeded.
        * No re-orthonormalization happens between iterations, so large
          ``n_iter`` can lose precision.
    """
    _, d = A.shape
    probe = torch.randn(d, rank, device=A.device, dtype=A.dtype)

    Y = A @ probe
    Z = A.T @ Y
    # Remaining power iterations; the final Z is what gets orthonormalized,
    # so no trailing `A @ Z` is computed.
    for _ in range(n_iter - 1):
        Y = A @ Z
        Z = A.T @ Y

    if Z.dtype in (torch.float16, torch.bfloat16):
        # Matmuls under autocast can yield half precision, which
        # torch.linalg.qr does not accept; factor in float32 and cast back.
        Q, _ = torch.linalg.qr(Z.float())
        Q = Q.to(Z.dtype)
    else:
        Q, _ = torch.linalg.qr(Z)
    return Q[:, :rank]
class BatchedRandSVD(nn.Module):
    """Project a batch onto a randomized low-rank basis.

    Holds no learnable parameters; ``rank`` and ``n_iter`` are forwarded to
    ``batched_randomized_range`` on every call.
    """

    def __init__(self, rank=64, n_iter=2):
        super().__init__()
        self.rank, self.n_iter = rank, n_iter

    def forward(self, A):
        """Return ``(z, sv, Q)`` for a 2-D input ``A``.

        ``Q`` is the basis from ``batched_randomized_range``, ``z = A @ Q``
        are the projected codes, and ``sv`` is the column-wise norm of ``z``
        divided by ``sqrt(max(1, number_of_rows))``.
        """
        n_rows, _ = A.shape
        basis = batched_randomized_range(A, self.rank, self.n_iter)
        codes = A @ basis
        denom = math.sqrt(max(1, n_rows))
        column_scale = codes.norm(dim=0) / denom
        return codes, column_scale, basis
| # ------------------------- | |
| # Cross-Recursive Polynomial Compositor (CRPC) | |
| # Leverages rotational node strengths for adaptive polynomial expansion | |
| # ------------------------- | |
class CrossRecursivePolyCompositor(nn.Module):
    """Compose per-child readouts and recursive polynomial features of ``z``.

    The forward pass:
      1. estimates a strength per child transform (spectral norm of ``O``),
         passes it through a small sigmoid gate and multiplies it with the
         spawn probabilities to get normalized per-child weights;
      2. mixes ``k_max`` per-child MLP readouts of ``z`` with those weights
         into an initial prediction ``y_init``;
      3. for each recursion depth, records the element-wise powers
         ``z, z**2, ..., z**degree`` of the current latent and then refines
         the latent from ``[latent, y_init]``;
      4. feeds ``y_init`` plus all recorded powers to a final MLP.

    Args:
        rank: Width ``r`` of the latent codes ``z``.
        y_dim: Output width.
        k_max: Number of children (one readout MLP each).
        degree: Highest element-wise power used as a feature (>= 1).
        hidden: Hidden width; defaults to ``rank`` for the child readouts and
            ``128`` for the final MLP when falsy.
        recurse_depth: Number of refinement rounds (>= 0).

    Raises:
        ValueError: If ``degree < 1`` or ``recurse_depth < 0``.
    """

    def __init__(self, rank, y_dim, k_max, degree=2, hidden=None, recurse_depth=2):
        super().__init__()
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")
        if recurse_depth < 0:
            raise ValueError(f"recurse_depth must be >= 0, got {recurse_depth}")
        self.rank = rank
        self.y_dim = y_dim
        self.k_max = k_max
        self.degree = degree
        self.recurse_depth = recurse_depth

        # One small readout MLP per child.
        child_hidden = hidden or rank
        self.child_poly_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(rank, child_hidden),
                nn.GELU(),
                nn.Linear(child_hidden, y_dim),
            )
            for _ in range(k_max)
        ])

        # Maps a scalar node strength to a (0, 1) gate value.
        self.node_strength_gate = nn.Sequential(
            nn.Linear(1, 16),
            nn.GELU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

        # One refinement layer per recursion depth.
        self.refine_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(rank + y_dim, rank),
                nn.LayerNorm(rank),
                nn.GELU(),
            )
            for _ in range(recurse_depth)
        ])

        # The final MLP sees y_init plus `degree` power blocks of width
        # `rank` for every recursion depth; this must match what
        # _refine_and_compose concatenates.
        in_feat = y_dim + recurse_depth * rank * degree
        compose_hidden = hidden or 128
        self.final_compose = nn.Sequential(
            nn.Linear(in_feat, compose_hidden),
            nn.LayerNorm(compose_hidden),
            nn.GELU(),
            nn.Linear(compose_hidden, y_dim),
        )

    def _mix_children(self, z, p_on, node_strengths):
        """Return ``(gated_strengths, child_weights, y_init)`` for the child mixture."""
        gated_strengths = self.node_strength_gate(
            node_strengths.unsqueeze(-1)
        ).squeeze(-1)  # [k]
        child_weights = p_on * gated_strengths.unsqueeze(0)  # [B, k]
        child_weights = child_weights / (child_weights.sum(dim=1, keepdim=True) + 1e-8)
        # All k_max readouts are evaluated, so p_on must have k_max columns.
        child_outputs = torch.stack(
            [readout(z) for readout in self.child_poly_layers], dim=1
        )  # [B, k_max, y_dim]
        y_init = (child_weights.unsqueeze(-1) * child_outputs).sum(dim=1)  # [B, y_dim]
        return gated_strengths, child_weights, y_init

    def _refine_and_compose(self, z, y_init):
        """Collect power features across depths and return ``(y_refined, z_refined)``."""
        z_refined = z
        poly_features = []
        for refine in self.refine_layers:
            # Element-wise powers z, z^2, ..., z^degree of the current latent.
            # No constant column is added: the Linear bias already covers it.
            powers = [z_refined]
            for _ in range(self.degree - 1):
                powers.append(powers[-1] * z_refined)
            poly_features.append(torch.cat(powers, dim=1))  # [B, rank * degree]
            z_refined = refine(torch.cat([z_refined, y_init], dim=1))
        # NOTE(review): the latent produced by the last refinement layer only
        # reaches the statistics, not y_refined - confirm this is intended.
        final_input = torch.cat([y_init, *poly_features], dim=1)
        return self.final_compose(final_input), z_refined

    def forward(self, z, O, p_on, sv):
        """Refine latent codes into a prediction.

        Args:
            z: ``[B, r]`` latent codes.
            O: ``[k, d_h, d_h]`` child transforms.
            p_on: ``[B, k]`` spawn probabilities.
            sv: Accepted for interface compatibility; not used here.

        Returns:
            ``(y_refined, stats)`` where ``y_refined`` is ``[B, y_dim]`` and
            ``stats`` holds ``node_strengths``, ``gated_strengths``,
            ``child_weights`` (tensors) and ``y_init_norm``,
            ``z_refined_norm`` (Python floats).
        """
        # reshape(-1) keeps a length-1 vector even if the helper returns a
        # 0-d tensor for a single child.
        # NOTE(review): if O is exactly orthogonal every spectral norm is ~1,
        # so the normalized strengths would be nearly uniform - confirm.
        node_strengths = compute_spectral_norm_batch(O, n_power_iters=2).reshape(-1)
        node_strengths = node_strengths / (node_strengths.max() + 1e-8)

        gated_strengths, child_weights, y_init = self._mix_children(
            z, p_on, node_strengths
        )
        y_refined, z_refined = self._refine_and_compose(z, y_init)

        stats = {
            "node_strengths": node_strengths,
            "gated_strengths": gated_strengths,
            "child_weights": child_weights,
            "y_init_norm": y_init.norm(dim=1).mean().item(),
            "z_refined_norm": z_refined.norm(dim=1).mean().item(),
        }
        return y_refined, stats
| # ------------------------- | |
| # Adaptive Spawn Unit v2 (fixed & enhanced) | |
| # ------------------------- | |
class AdaptiveSpawnV2(nn.Module):
    """Gate ``k_max`` Cayley-transformed copies of a base projection.

    Each child applies its own matrix (Cayley transform of a learned
    skew-symmetric parameter) and a learned scale to the shared base
    projection; a per-sample Gumbel-softmax gate decides how much each
    child contributes to the aggregated output.

    Args:
        d_in: Input feature width.
        d_hidden: Width of the base projection and of every child matrix.
        k_max: Number of children.
        temp_init: Initial value of the learnable gate temperature.
        hard: If true, gates are thresholded at 0.5 in the forward value
            while gradients follow the soft gates (straight-through).
    """

    def __init__(self, d_in, d_hidden, k_max=8, temp_init=0.5, hard=False):
        super().__init__()
        self.d_in, self.d_hidden, self.k_max = d_in, d_hidden, k_max
        self.temp = nn.Parameter(torch.tensor(float(temp_init)))
        self.hard = hard
        # Two logits (off / on) per child.
        self.gate_logits = nn.Linear(d_in, 2 * k_max)
        self.base_proj = nn.Linear(d_in, d_hidden, bias=True)
        self.child_params = nn.Parameter(0.01 * torch.randn(k_max, d_hidden, d_hidden))
        self.child_scale = nn.Parameter(torch.ones(k_max, 1))

    def forward(self, x):
        """Return ``(agg, stats)`` for a batch ``x`` of shape ``[B, d_in]``.

        ``agg`` is the gate-weighted sum of the children, ``[B, d_hidden]``.
        ``stats`` exposes ``p_on_mean`` (``[k_max]``), ``expected_spawns``
        (0-d), the per-sample gates ``p_on`` (``[B, k_max]``) and the child
        matrices ``O`` (``[k_max, d_hidden, d_hidden]``).
        """
        batch = x.shape[0]
        temperature = torch.clamp(self.temp, min=1e-3)
        # NOTE(review): Gumbel noise is drawn regardless of train/eval mode,
        # so evaluation outputs are stochastic - confirm this is intended.
        gates = gumbel_softmax_sample(
            self.gate_logits(x).view(batch, self.k_max, 2),
            temp=temperature,
        )
        p_on = gates[..., 1]  # probability mass on the "on" slot
        if self.hard:
            thresholded = (p_on > 0.5).float()
            # Straight-through: forward value is the hard mask, gradient is soft.
            p_on = (thresholded - p_on).detach() + p_on

        base = self.base_proj(x)
        O = cayley_orthogonal_from_skew(make_skew_from_params(self.child_params))
        children = torch.einsum("bd,kde->bke", base, O) * self.child_scale.unsqueeze(0)
        agg = (p_on.unsqueeze(-1) * children).sum(dim=1)

        return agg, {
            "p_on_mean": p_on.mean(dim=0),
            "expected_spawns": p_on.sum(dim=1).mean(),
            "p_on": p_on,
            "O": O,
        }
| # ------------------------- | |
| # Metacritic (unchanged) | |
| # ------------------------- | |
class MetacriticProd(nn.Module):
    """MLP critic scoring a prediction from its value and model statistics.

    Args:
        y_dim: Width of the prediction ``y_hat``.
        sv_dim: Length of the ``sv`` vector.
        spawn_dim: Length of the spawn summary, i.e. ``len(p_on_mean) + 1``.
        hidden: Width of the first hidden layer; the second uses ``hidden // 2``.
    """

    def __init__(self, y_dim, sv_dim, spawn_dim, hidden=256):
        super().__init__()
        half = hidden // 2
        stack = [
            nn.Linear(y_dim + sv_dim + spawn_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, half),
            nn.GELU(),
            nn.Linear(half, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, y_hat, sv, spawn_stats):
        """Return one score per row of ``y_hat`` (shape ``[B]``).

        ``sv`` and the spawn summary (``p_on_mean`` followed by
        ``expected_spawns``) are batch-level vectors; both are broadcast to
        every row before being concatenated with ``y_hat``.
        """
        batch = y_hat.shape[0]
        p_mean = spawn_stats["p_on_mean"].to(y_hat.device)
        expected = spawn_stats["expected_spawns"]
        if not isinstance(expected, torch.Tensor):
            # Plain numbers are promoted to a 0-d tensor matching y_hat.
            expected = torch.tensor(expected, device=y_hat.device, dtype=y_hat.dtype)

        spawn_row = torch.cat([p_mean, expected.unsqueeze(0)], dim=0)
        columns = [
            y_hat,
            sv.unsqueeze(0).expand(batch, -1),
            spawn_row.unsqueeze(0).expand(batch, -1),
        ]
        return self.net(torch.cat(columns, dim=1)).squeeze(-1)
| # ------------------------- | |
| # Full Improved Model | |
| # ------------------------- | |
class StedmanImproved(nn.Module):
    """Full model: backbone -> adaptive spawn -> low-rank projection -> CRPC -> critic.

    Args:
        d_in: Input feature width.
        d_hidden: Backbone / spawn width.
        y_dim: Output width.
        k_max: Number of spawn children.
        rank: Width of the low-rank projection.
        poly_deg: Polynomial degree passed to the compositor.
        recurse_depth: Refinement depth passed to the compositor.
    """

    def __init__(self, d_in, d_hidden, y_dim, k_max=8, rank=64, poly_deg=2, recurse_depth=2):
        super().__init__()
        backbone_layers = [
            nn.Linear(d_in, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.GELU(),
        ]
        self.backbone = nn.Sequential(*backbone_layers)
        self.asu = AdaptiveSpawnV2(
            d_hidden, d_hidden, k_max=k_max, temp_init=0.8, hard=False
        )
        self.proj = BatchedRandSVD(rank=rank, n_iter=2)
        self.crpc = CrossRecursivePolyCompositor(
            rank=rank,
            y_dim=y_dim,
            k_max=k_max,
            degree=poly_deg,
            hidden=d_hidden // 2,
            recurse_depth=recurse_depth,
        )
        # The critic sees the prediction, the `rank` projection scales and
        # k_max gate means plus the expected spawn count.
        self.critic = MetacriticProd(
            y_dim=y_dim, sv_dim=rank, spawn_dim=k_max + 1, hidden=256
        )
        self.k_max = k_max

    def forward(self, x):
        """Return ``(y_hat, c_hat, stats)`` for an input batch ``x``.

        ``stats`` is the spawn unit's dict extended in place with the
        compositor's statistics.
        """
        features = self.backbone(x)
        spawned, stats = self.asu(features)
        z, sv, _ = self.proj(spawned)
        y_hat, crpc_stats = self.crpc(z, stats["O"], stats["p_on"], sv)
        stats.update(crpc_stats)
        c_hat = self.critic(y_hat, sv, stats)
        return y_hat, c_hat, stats
| # ------------------------- | |
| # Training utilities | |
| # ------------------------- | |
def spectral_penalty(module_list, coeff=1e-3):
    """Return ``coeff`` times the summed squared entries of every ``child_params``.

    Modules without a ``child_params`` attribute are ignored. If none of the
    modules has one, the result is the plain float ``coeff * 0.0``.

    Note: despite the name this is a squared-entry (Frobenius-style) penalty
    on the raw parameters, not a spectral-norm computation.
    """
    squared_sums = (
        (module.child_params.view(-1) ** 2).sum()
        for module in module_list
        if hasattr(module, "child_params")
    )
    return coeff * sum(squared_sums, 0.0)
class LambdaScheduler:
    """Epoch-based schedule for the critic loss weight.

    The weight rises linearly from ``lambda_init`` to ``lambda_max`` over
    ``warmup_epochs`` and then follows a half cosine from ``lambda_max``
    down to 0 at ``total_epochs``.

    Args:
        lambda_init: Weight at epoch 0 (when a warmup phase exists).
        lambda_max: Peak weight reached at the end of warmup.
        warmup_epochs: Length of the linear warmup; 0 disables it.
        total_epochs: Epoch at which the cosine decay reaches 0.
    """

    def __init__(self, lambda_init=0.5, lambda_max=2.0, warmup_epochs=2, total_epochs=6):
        self.lambda_init = lambda_init
        self.lambda_max = lambda_max
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs

    def get_lambda(self, epoch):
        """Return the weight for ``epoch`` (0-based).

        Epochs at or beyond ``total_epochs`` stay at the fully annealed
        value instead of climbing back up the cosine. If no decay phase is
        configured (``total_epochs <= warmup_epochs``) the weight holds at
        ``lambda_max`` after warmup.
        """
        if self.warmup_epochs > 0 and epoch < self.warmup_epochs:
            # Linear warmup.
            return self.lambda_init + (self.lambda_max - self.lambda_init) * epoch / self.warmup_epochs

        decay_span = self.total_epochs - self.warmup_epochs
        if decay_span <= 0:
            # No room for annealing; avoid dividing by zero.
            return self.lambda_max

        # Cosine annealing after warmup, clamped so it never restarts.
        progress = min(1.0, (epoch - self.warmup_epochs) / decay_span)
        return self.lambda_max * (1 + math.cos(math.pi * progress)) / 2
| # ------------------------- | |
| # Synthetic data | |
| # ------------------------- | |
def synthetic_data(n=20000, d=32, noise=0.08):
    """Generate a noisy regression set with a linear and a sinusoidal part.

    Args:
        n: Number of samples.
        d: Input dimensionality.
        noise: Standard deviation multiplier for the additive Gaussian noise.

    Returns:
        ``(X, y)`` with ``X`` of shape ``[n, d]`` and ``y`` of shape ``[n, 1]``,
        where ``y = X @ W + 0.6 * sin(0.7 * sum(X, dim=1)) + noise * eps`` for
        a random ``W`` drawn inside the call.
    """
    X = torch.randn(n, d)
    W = torch.randn(d, 1)
    linear_part = X @ W
    row_sums = X.sum(dim=1, keepdim=True)
    periodic_part = 0.6 * torch.sin(row_sums * 0.7)
    jitter = noise * torch.randn(n, 1)
    return X, linear_part + periodic_part + jitter
| # ------------------------- | |
| # Training loop (improved) | |
| # ------------------------- | |
def train_improved():
    """Train ``StedmanImproved`` on synthetic data and report metrics.

    Features:
      - Mixed precision (AMP), enabled only when running on CUDA
      - Gradient clipping
      - Cosine learning-rate schedule with warm restarts (stepped per epoch)
      - Dynamic critic weight (lambda scheduling)
      - Per-epoch logging

    Returns:
        ``(model, epoch_losses, epoch_metrics)`` where the two dicts map a
        metric name to a list with one entry per epoch.
    """
    torch.manual_seed(1)
    d_in = 32
    d_hidden = 256
    y_dim = 1
    k_max = 8
    rank = 64

    model = StedmanImproved(
        d_in, d_hidden, y_dim,
        k_max=k_max, rank=rank, poly_deg=2, recurse_depth=2
    ).to(DEVICE)

    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
    # T_mult must be an integer >= 1; a fractional value is rejected by the
    # scheduler's constructor.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt, T_0=2, T_mult=2, eta_min=1e-5
    )
    # AMP only makes sense on CUDA; on CPU both helpers become no-ops.
    use_amp = DEVICE.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    X, y = synthetic_data(n=40000, d=d_in)
    X, y = X.to(DEVICE), y.to(DEVICE)

    bs = 512
    epochs = 6
    gamma = 0.02  # spawn cost coefficient
    lambda_sched = LambdaScheduler(lambda_init=0.5, lambda_max=1.5, warmup_epochs=2, total_epochs=epochs)

    # Tracking metrics
    epoch_losses = {"task": [], "critic": [], "spawn": [], "reg": [], "total": []}
    epoch_metrics = {"critic_mse": [], "exp_spawns": [], "temp": []}

    print("=" * 80)
    print(f"Training StedmanImproved on {X.shape[0]} samples, d={d_in}, rank={rank}, k_max={k_max}")
    print("=" * 80)

    for ep in range(epochs):
        model.train()
        perm = torch.randperm(X.size(0), device=X.device)
        t0 = time.time()
        batch_losses = {"task": 0.0, "critic": 0.0, "spawn": 0.0, "reg": 0.0, "total": 0.0}
        lambda_c = lambda_sched.get_lambda(ep)

        for i in range(0, X.size(0), bs):
            idx = perm[i:i+bs]
            xb, yb = X[idx], y[idx]

            with torch.cuda.amp.autocast(enabled=use_amp):
                y_hat, c_hat, stats = model(xb)
                # Task loss
                task_loss = F.mse_loss(y_hat, yb)
                # Critic loss (predict absolute error). The target is
                # detached so the critic term cannot lower itself by
                # steering the task error toward the critic's prediction.
                abs_err = (y_hat - yb).abs().squeeze(-1).detach()
                critic_loss = F.mse_loss(c_hat, abs_err)
                # Spawn cost (penalize high expected spawns)
                sp_cost = gamma * stats["expected_spawns"]
                # Squared-parameter regularization on the spawn unit
                reg = spectral_penalty([model.asu], coeff=1e-6)
                # Total loss
                loss = task_loss + lambda_c * critic_loss + sp_cost + reg

            opt.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
            scaler.step(opt)
            scaler.update()

            # Accumulate sample-weighted metrics
            batch_losses["task"] += task_loss.item() * xb.size(0)
            batch_losses["critic"] += critic_loss.item() * xb.size(0)
            batch_losses["spawn"] += sp_cost.item() * xb.size(0)
            batch_losses["reg"] += reg.item() * xb.size(0)
            batch_losses["total"] += loss.item() * xb.size(0)

        # Normalize by dataset size
        n_samples = X.size(0)
        for key in batch_losses:
            batch_losses[key] /= n_samples

        # Update scheduler
        scheduler.step()

        # Quick evaluation on the first 2048 training samples
        model.eval()
        with torch.no_grad():
            y_hat_full, c_hat_full, stats_full = model(X[:2048])
            abs_err_full = (y_hat_full - y[:2048]).abs().squeeze(-1)
            critic_mse = F.mse_loss(c_hat_full, abs_err_full).item()

        # Log metrics
        epoch_losses["task"].append(batch_losses["task"])
        epoch_losses["critic"].append(batch_losses["critic"])
        epoch_losses["spawn"].append(batch_losses["spawn"])
        epoch_losses["reg"].append(batch_losses["reg"])
        epoch_losses["total"].append(batch_losses["total"])
        epoch_metrics["critic_mse"].append(critic_mse)
        epoch_metrics["exp_spawns"].append(stats_full["expected_spawns"].item())
        epoch_metrics["temp"].append(float(model.asu.temp))

        print(
            f"Epoch {ep+1:2d}/{epochs} | "
            f"Loss {batch_losses['total']:.6f} (task={batch_losses['task']:.4f}, "
            f"critic={batch_losses['critic']:.4f}, spawn={batch_losses['spawn']:.4f}) | "
            f"λ_c={lambda_c:.2f} | Temp={float(model.asu.temp):.3f} | "
            f"Spawns={stats_full['expected_spawns'].item():.2f} | "
            f"Time {time.time()-t0:.2f}s"
        )

    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)

    model.eval()
    with torch.no_grad():
        y_hat, c_hat, stats = model(X[:4096])
        abs_err = (y_hat - y[:4096]).abs().squeeze(-1)
        task_mse = F.mse_loss(y_hat, y[:4096]).item()
        critic_mse = F.mse_loss(c_hat, abs_err).item()

    print(f"Task MSE: {task_mse:.6f}")
    print(f"Critic MSE: {critic_mse:.6f}")
    print(f"Expected spawns: {stats['expected_spawns'].item():.2f} / {k_max}")
    print(f"Spawn probability: {stats['p_on_mean'].mean().item():.4f}")
    print(f"Node strengths: min={stats['node_strengths'].min():.4f}, "
          f"max={stats['node_strengths'].max():.4f}, "
          f"mean={stats['node_strengths'].mean():.4f}")
    print(f"Gated strengths: min={stats['gated_strengths'].min():.4f}, "
          f"max={stats['gated_strengths'].max():.4f}, "
          f"mean={stats['gated_strengths'].mean():.4f}")
    print(f"Latent norm (z): {stats['z_refined_norm']:.4f}")
    print(f"Output norm (y): {stats['y_init_norm']:.4f}")

    return model, epoch_losses, epoch_metrics
| # ------------------------- | |
| # Analysis utilities | |
| # ------------------------- | |
def analyze_node_activation(model, X, y, device=DEVICE):
    """Print a per-child table of gate usage and strengths.

    The model is switched to eval mode and run once, without gradients, on
    the first 1024 rows of ``X``. For every child index below
    ``model.k_max`` the table shows the mean gate value, the node strength,
    the gated strength and the mean mixture weight. ``y`` is accepted but
    not used.
    """
    model.eval()
    with torch.no_grad():
        _, _, stats = model(X[:1024].to(device))

    p_on = stats["p_on"]  # [B, k]
    node_strengths = stats["node_strengths"]
    gated_strengths = stats["gated_strengths"]
    child_weights = stats["child_weights"]

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print("\n" + heavy_rule)
    print("NODE ACTIVATION ANALYSIS")
    print(heavy_rule)
    print(f"{'Node':<6} {'P(On)':<8} {'Strength':<10} {'Gated':<8} {'Avg Weight':<12}")
    print(light_rule)

    for k in range(model.k_max):
        mean_gate = p_on[:, k].mean().item()
        strength = node_strengths[k].item()
        gated = gated_strengths[k].item()
        mean_weight = child_weights[:, k].mean().item()
        print(f"{k:<6} {mean_gate:<8.4f} {strength:<10.4f} {gated:<8.4f} {mean_weight:<12.6f}")

    print(light_rule)
    print(f"{'TOTAL':<6} {p_on.mean().item():<8.4f}")
def visualize_training_curves(epoch_losses, epoch_metrics):
    """Print first -> last values of the tracked training curves.

    Reads ``total`` and ``task`` from ``epoch_losses`` and ``critic_mse``,
    ``exp_spawns`` and ``temp`` from ``epoch_metrics``; each must be a
    non-empty sequence of numbers.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TRAINING CURVES SUMMARY")
    print(rule)

    # (heading, source dict, key, format spec) - looked up lazily, in order.
    sections = (
        ("\nTotal Loss Progression (first → last):", epoch_losses, "total", ".6f"),
        ("\nTask Loss Progression:", epoch_losses, "task", ".6f"),
        ("\nCritic MSE Progression:", epoch_metrics, "critic_mse", ".6f"),
        ("\nExpected Spawns Progression:", epoch_metrics, "exp_spawns", ".2f"),
        ("\nTemperature Progression (annealing):", epoch_metrics, "temp", ".3f"),
    )
    for heading, source, key, spec in sections:
        print(heading)
        print(f" {source[key][0]:{spec}} → {source[key][-1]:{spec}}")
| # ------------------------- | |
| # Main | |
| # ------------------------- | |
if __name__ == "__main__":
    # Script entry point: report the environment, train, then print analyses.
    print(f"Device: {DEVICE}")
    print(f"PyTorch version: {torch.__version__}")

    # Train
    model, epoch_losses, epoch_metrics = train_improved()

    # Analyze gate usage on a freshly sampled synthetic set
    X, y = synthetic_data(n=40000, d=32)
    analyze_node_activation(model, X, y)

    # Summarize the recorded curves
    visualize_training_curves(epoch_losses, epoch_metrics)

    closing_rule = "=" * 80
    print("\n" + closing_rule)
    print("Training complete!")
    print(closing_rule)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Never ever believe in nothing, especially if you are standing in front of absolutely no nothing at all and you have awareness you are going to be allowed to go thru the Eye of Needle as a camel.