Created
April 25, 2026 05:12
-
-
Save akarnokd/91c7d1ffa0130435f59861e77dbf0812 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package hu.akarnokd.math; | |
| import java.util.Random; | |
| import org.apache.commons.math3.complex.*; | |
| import org.apache.commons.math3.linear.*; | |
| import org.apache.commons.math3.util.FastMath; | |
| public class EmlFieldMatrixRegression { | |
| private static final int DIM = 3; | |
| private static final int N_SAMPLES = 200; | |
| private static final int MAX_EPOCHS_PER_DEPTH = 400; | |
| private static final int MIN_DEPTH = 0; | |
| private static final int MAX_DEPTH = 12; | |
| private static final double LEARNING_RATE = 0.001; | |
| private static final double CLIP = 8.0; | |
| // ============== TOGGLE HERE ============== | |
| private static final boolean USE_ORIGINAL_EXOTIC_ATOM = false; // set true for e^X - ln(Y) | |
| private static double smoothClip(double x) { | |
| return CLIP * FastMath.tanh(x / CLIP * 5.0); | |
| } | |
| // ==================== COMPUTATION ATOM ==================== | |
| private static FieldMatrix<Complex> eml_M(FieldMatrix<Complex> Xc, FieldMatrix<Complex> Yc) { | |
| if (USE_ORIGINAL_EXOTIC_ATOM) { | |
| // Original: safe e^X - ln(Y) | |
| RealMatrix X = realPart(Xc); | |
| RealMatrix Y = realPart(Yc); | |
| double[][] data = new double[X.getRowDimension()][X.getColumnDimension()]; | |
| for (int i = 0; i < X.getRowDimension(); i++) { | |
| for (int j = 0; j < X.getColumnDimension(); j++) { | |
| double x = X.getEntry(i, j); | |
| double y = Y.getEntry(i, j); | |
| double expX = FastMath.exp(x); | |
| double lnY; | |
| if (y > 1e-8) { | |
| lnY = FastMath.log(y); | |
| } else if (y < -1e-8) { | |
| lnY = FastMath.log(-y); // real part only | |
| } else { | |
| lnY = -20.0; // safe large negative | |
| } | |
| data[i][j] = smoothClip(expX - lnY); | |
| } | |
| } | |
| return toComplex(MatrixUtils.createRealMatrix(data)); | |
| } else { | |
| // Simple stable atom (good default) | |
| RealMatrix X = realPart(Xc); | |
| RealMatrix Y = realPart(Yc); | |
| double[][] data = new double[X.getRowDimension()][X.getColumnDimension()]; | |
| for (int i = 0; i < X.getRowDimension(); i++) { | |
| for (int j = 0; j < X.getColumnDimension(); j++) { | |
| double val = X.getEntry(i, j) - 0.5 * Y.getEntry(i, j); | |
| data[i][j] = smoothClip(val); | |
| } | |
| } | |
| return toComplex(MatrixUtils.createRealMatrix(data)); | |
| } | |
| } | |
| private static RealMatrix realPart(FieldMatrix<Complex> c) { | |
| int rows = c.getRowDimension(); | |
| int cols = c.getColumnDimension(); | |
| double[][] data = new double[rows][cols]; | |
| for (int i = 0; i < rows; i++) for (int j = 0; j < cols; j++) { | |
| data[i][j] = c.getEntry(i, j).getReal(); | |
| } | |
| return MatrixUtils.createRealMatrix(data); | |
| } | |
| private static FieldMatrix<Complex> toComplex(RealMatrix real) { | |
| int rows = real.getRowDimension(); | |
| int cols = real.getColumnDimension(); | |
| Complex[][] data = new Complex[rows][cols]; | |
| for (int i = 0; i < rows; i++) for (int j = 0; j < cols; j++) { | |
| data[i][j] = new Complex(real.getEntry(i, j), 0.0); | |
| } | |
| return MatrixUtils.createFieldMatrix(data); | |
| } | |
| static class EMLNet { | |
| private final int dim; | |
| private final RealMatrix[] W; | |
| private final RealMatrix readout; | |
| private final RealMatrix[] velocityW; | |
| private final RealMatrix velocityReadout; | |
| private final Random rand = new Random(42); | |
| public EMLNet(int dim, int depth) { | |
| this.dim = dim; | |
| this.W = new RealMatrix[depth]; | |
| this.velocityW = new RealMatrix[depth]; | |
| this.velocityReadout = MatrixUtils.createRealIdentityMatrix(dim); | |
| for (int i = 0; i < depth; i++) { | |
| W[i] = MatrixUtils.createRealIdentityMatrix(dim); | |
| velocityW[i] = MatrixUtils.createRealIdentityMatrix(dim); | |
| for (int r = 0; r < dim; r++) { | |
| for (int c = 0; c < dim; c++) { | |
| W[i].setEntry(r, c, 0.015 * (rand.nextDouble() - 0.5)); | |
| } | |
| } | |
| } | |
| this.readout = MatrixUtils.createRealIdentityMatrix(dim); | |
| for (int r = 0; r < dim; r++) { | |
| for (int c = 0; c < dim; c++) { | |
| readout.setEntry(r, c, 0.05 * (rand.nextDouble() - 0.5)); | |
| } | |
| } | |
| } | |
| public RealMatrix forward(RealMatrix X_batch) { | |
| int batch = X_batch.getRowDimension(); | |
| RealMatrix outputs = MatrixUtils.createRealMatrix(batch, dim); | |
| for (int i = 0; i < batch; i++) { | |
| double[] row = X_batch.getRow(i); | |
| double[][] mData = new double[dim][dim]; | |
| for (int d = 0; d < dim; d++) mData[d][d] = row[d]; | |
| RealMatrix M = MatrixUtils.createRealMatrix(mData); | |
| for (RealMatrix w : W) { | |
| M = M.multiply(w); | |
| FieldMatrix<Complex> Mc = toComplex(M); | |
| M = realPart(eml_M(Mc, toComplex(MatrixUtils.createRealIdentityMatrix(dim)))); | |
| } | |
| for (int d = 0; d < dim; d++) { | |
| double val = M.getEntry(d, d); | |
| val = FastMath.max(-100.0, FastMath.min(100.0, val)); | |
| outputs.setEntry(i, d, val); | |
| } | |
| } | |
| outputs = outputs.multiply(readout); | |
| return outputs; | |
| } | |
| public double computeLoss(RealMatrix X, RealMatrix b) { | |
| return computeMSE(b, forward(X)); | |
| } | |
| public void updateParameters(RealMatrix X, RealMatrix b, double lr, int epoch) { | |
| final double h = 1e-4; | |
| final double momentum = 0.9; | |
| final double noiseScale = 0.003 * Math.exp(-epoch / 300.0); | |
| for (int wi = 0; wi < W.length; wi++) { | |
| RealMatrix w = W[wi]; | |
| RealMatrix vel = velocityW[wi]; | |
| for (int r = 0; r < dim; r++) { | |
| for (int c = 0; c < dim; c++) { | |
| double orig = w.getEntry(r, c); | |
| w.setEntry(r, c, orig + h); | |
| double lossPlus = computeLoss(X, b); | |
| w.setEntry(r, c, orig - h); | |
| double lossMinus = computeLoss(X, b); | |
| double grad = (lossPlus - lossMinus) / (2 * h); | |
| double newVel = momentum * vel.getEntry(r, c) - lr * grad; | |
| vel.setEntry(r, c, newVel); | |
| w.setEntry(r, c, orig + newVel); | |
| } | |
| } | |
| } | |
| for (int r = 0; r < dim; r++) { | |
| for (int c = 0; c < dim; c++) { | |
| double orig = readout.getEntry(r, c); | |
| readout.setEntry(r, c, orig + h); | |
| double lossPlus = computeLoss(X, b); | |
| readout.setEntry(r, c, orig - h); | |
| double lossMinus = computeLoss(X, b); | |
| double grad = (lossPlus - lossMinus) / (2 * h); | |
| double newVel = momentum * velocityReadout.getEntry(r, c) - lr * grad; | |
| velocityReadout.setEntry(r, c, newVel); | |
| readout.setEntry(r, c, orig + newVel); | |
| } | |
| } | |
| if (epoch < 150) { | |
| for (RealMatrix w : W) { | |
| for (int r = 0; r < dim; r++) for (int c = 0; c < dim; c++) { | |
| w.setEntry(r, c, w.getEntry(r, c) + noiseScale * rand.nextGaussian()); | |
| } | |
| } | |
| for (int r = 0; r < dim; r++) for (int c = 0; c < dim; c++) { | |
| readout.setEntry(r, c, readout.getEntry(r, c) + noiseScale * rand.nextGaussian()); | |
| } | |
| } | |
| } | |
| } | |
| public static void main(String[] args) { | |
| Random rand = new Random(42); | |
| // Data generation | |
| RealMatrix A_true = MatrixUtils.createRealMatrix(DIM, DIM); | |
| for (int i = 0; i < DIM; i++) { | |
| for (int j = 0; j < DIM; j++) A_true.setEntry(i, j, rand.nextGaussian() * 0.5); | |
| } | |
| RealMatrix X = MatrixUtils.createRealMatrix(N_SAMPLES, DIM); | |
| RealMatrix b = MatrixUtils.createRealMatrix(N_SAMPLES, DIM); | |
| for (int i = 0; i < N_SAMPLES; i++) { | |
| for (int j = 0; j < DIM; j++) X.setEntry(i, j, rand.nextGaussian()); | |
| RealMatrix xi = MatrixUtils.createRowRealMatrix(X.getRow(i)); | |
| RealMatrix bi = xi.multiply(A_true); | |
| for (int j = 0; j < DIM; j++) { | |
| b.setEntry(i, j, bi.getEntry(0, j) + rand.nextGaussian() * 0.05); | |
| } | |
| } | |
| // OLS baseline | |
| RealMatrix Xt = X.transpose(); | |
| RealMatrix XtX = Xt.multiply(X); | |
| RealMatrix XtX_inv = new LUDecomposition(XtX).getSolver().getInverse(); | |
| RealMatrix Xtb = Xt.multiply(b); | |
| RealMatrix A_ols = XtX_inv.multiply(Xtb); | |
| double mse_ols = computeMSE(b, X.multiply(A_ols)); | |
| System.out.printf("OLS MSE: %.6f%n", mse_ols); | |
| // Outer loop over depth (0 = OLS) | |
| double bestLoss = Double.POSITIVE_INFINITY; | |
| int bestDepth = 0; | |
| EMLNet bestModel = null; | |
| for (int depth = MIN_DEPTH; depth <= MAX_DEPTH; depth++) { | |
| System.out.printf("%n=== Trying depth = %d ===%n", depth); | |
| if (depth == 0) { | |
| System.out.printf("Depth 0 (pure OLS) loss: %.6f%n", mse_ols); | |
| if (mse_ols < bestLoss) { | |
| bestLoss = mse_ols; | |
| bestDepth = 0; | |
| System.out.println(" → New best! (OLS)"); | |
| } | |
| continue; | |
| } | |
| // Train mixin | |
| EMLNet model = new EMLNet(DIM, depth); | |
| for (int epoch = 0; epoch < MAX_EPOCHS_PER_DEPTH; epoch++) { | |
| double loss = model.computeLoss(X, b); | |
| if (epoch % 50 == 0 || epoch == MAX_EPOCHS_PER_DEPTH - 1) { | |
| System.out.printf(" Epoch %4d loss: %.6f%n", epoch, loss); | |
| } | |
| model.updateParameters(X, b, LEARNING_RATE, epoch); | |
| } | |
| double finalLoss = model.computeLoss(X, b); | |
| System.out.printf("Depth %d FINAL loss: %.6f%n", depth, finalLoss); | |
| if (finalLoss < bestLoss - 1e-5) { | |
| bestLoss = finalLoss; | |
| bestDepth = depth; | |
| bestModel = model; | |
| System.out.printf(" → New best! (depth %d)%n", bestDepth); | |
| } else { | |
| System.out.printf(" → No meaningful improvement. Stopping.%n"); | |
| break; | |
| } | |
| } | |
| System.out.printf("%n=== Best result ===%n"); | |
| System.out.printf("Best depth = %d with loss = %.6f%n", bestDepth, bestLoss); | |
| System.out.printf("OLS was %.6f%n", mse_ols); | |
| if (bestDepth == 0) { | |
| System.out.println("Winner: Pure OLS (as expected for linear data)"); | |
| } else if (USE_ORIGINAL_EXOTIC_ATOM) { | |
| System.out.println("Winner: Mixin with original e^X - ln(Y) atom"); | |
| } else { | |
| System.out.println("Winner: Mixin with simple stable atom"); | |
| } | |
| } | |
| private static double computeMSE(RealMatrix target, RealMatrix pred) { | |
| double sum = 0.0; | |
| for (int i = 0; i < target.getRowDimension(); i++) { | |
| for (int j = 0; j < target.getColumnDimension(); j++) { | |
| double diff = target.getEntry(i, j) - pred.getEntry(i, j); | |
| sum += diff * diff; | |
| } | |
| } | |
| return sum / (target.getRowDimension() * target.getColumnDimension()); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment