Skip to content

Instantly share code, notes, and snippets.

@akarnokd
Created April 25, 2026 05:12
Show Gist options
  • Select an option

  • Save akarnokd/91c7d1ffa0130435f59861e77dbf0812 to your computer and use it in GitHub Desktop.

Select an option

Save akarnokd/91c7d1ffa0130435f59861e77dbf0812 to your computer and use it in GitHub Desktop.
package hu.akarnokd.math;
import java.util.Random;
import org.apache.commons.math3.complex.*;
import org.apache.commons.math3.linear.*;
import org.apache.commons.math3.util.FastMath;
/**
 * Toy experiment comparing ordinary least squares (OLS) with a stack of "EML" layers on a noisy
 * linear regression problem {@code b = X * A_true + noise}.
 *
 * <p>For each sample, the network starts from the diagonal matrix {@code diag(x)}. Each layer
 * multiplies it by a learned weight matrix and applies an element-wise atom (see {@code eml_M})
 * against the identity matrix. The diagonal of the result is clamped and mapped through a learned
 * readout matrix. Depths from {@link #MIN_DEPTH} to {@link #MAX_DEPTH} are tried in order (depth 0
 * is the OLS baseline). The search stops at the first depth that does not improve on the best
 * loss so far.
 */
public class EmlFieldMatrixRegression {
    /** Number of features and targets. */
    private static final int DIM = 3;
    private static final int N_SAMPLES = 200;
    private static final int MAX_EPOCHS_PER_DEPTH = 400;
    private static final int MIN_DEPTH = 0;
    private static final int MAX_DEPTH = 12;
    private static final double LEARNING_RATE = 0.001;
    /** Saturation bound used by {@link #smoothClip(double)}. */
    private static final double CLIP = 8.0;
    // ============== TOGGLE HERE ==============
    private static final boolean USE_ORIGINAL_EXOTIC_ATOM = false; // set true for e^X - ln(Y)
    /** Values with {@code |y| <= LN_EPS} are treated as zero by the exotic atom. */
    private static final double LN_EPS = 1e-8;
    /** Stand-in for {@code ln|y|} when {@code y} is (numerically) zero. */
    private static final double LN_Y_NEAR_ZERO = -20.0;

    /**
     * Smoothly saturates {@code x} into the open interval {@code (-CLIP, CLIP)} with a scaled tanh.
     * The slope at the origin is 5, not 1, so small inputs are amplified five-fold before
     * saturating.
     */
    private static double smoothClip(double x) {
        return CLIP * FastMath.tanh(x / CLIP * 5.0);
    }

    // ==================== COMPUTATION ATOM ====================

    /**
     * Applies the element-wise computation atom to two equally sized matrices.
     *
     * <p>Only the real parts of the inputs are read. Every result entry is
     * {@code smoothClip(atom(x, y))} with a zero imaginary part.
     *
     * @param Xc first operand
     * @param Yc second operand; must have the same shape as {@code Xc}
     * @return a new matrix of the same shape
     * @throws IllegalArgumentException if the shapes differ
     */
    private static FieldMatrix<Complex> eml_M(FieldMatrix<Complex> Xc, FieldMatrix<Complex> Yc) {
        int rows = Xc.getRowDimension();
        int cols = Xc.getColumnDimension();
        if (Yc.getRowDimension() != rows || Yc.getColumnDimension() != cols) {
            throw new IllegalArgumentException(String.format(
                    "Operand shapes differ: %dx%d vs %dx%d",
                    rows, cols, Yc.getRowDimension(), Yc.getColumnDimension()));
        }
        Complex[][] out = new Complex[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double x = Xc.getEntry(i, j).getReal();
                double y = Yc.getEntry(i, j).getReal();
                out[i][j] = new Complex(smoothClip(atom(x, y)), 0.0);
            }
        }
        return MatrixUtils.createFieldMatrix(out);
    }

    /**
     * Scalar atom applied before clipping.
     *
     * <ul>
     *   <li>Exotic mode: {@code e^x - ln|y|}. When {@code |y| <= LN_EPS} (or {@code y} is NaN),
     *       {@code ln|y|} is replaced by {@code LN_Y_NEAR_ZERO}. Only the real logarithm of the
     *       magnitude is used.
     *   <li>Default mode: {@code x - 0.5 * y}.
     * </ul>
     */
    private static double atom(double x, double y) {
        if (USE_ORIGINAL_EXOTIC_ATOM) {
            double magnitude = FastMath.abs(y);
            double lnY = magnitude > LN_EPS ? FastMath.log(magnitude) : LN_Y_NEAR_ZERO;
            // exp may overflow to +Infinity; smoothClip then saturates it to CLIP.
            return FastMath.exp(x) - lnY;
        }
        return x - 0.5 * y;
    }

    /** Copies the real parts of a complex matrix into a new real matrix. */
    private static RealMatrix realPart(FieldMatrix<Complex> c) {
        int rows = c.getRowDimension();
        int cols = c.getColumnDimension();
        double[][] data = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                data[i][j] = c.getEntry(i, j).getReal();
            }
        }
        return MatrixUtils.createRealMatrix(data);
    }

    /** Lifts a real matrix into a complex matrix with zero imaginary parts. */
    private static FieldMatrix<Complex> toComplex(RealMatrix real) {
        int rows = real.getRowDimension();
        int cols = real.getColumnDimension();
        Complex[][] data = new Complex[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                data[i][j] = new Complex(real.getEntry(i, j), 0.0);
            }
        }
        return MatrixUtils.createFieldMatrix(data);
    }

    /**
     * A stack of {@code depth} EML layers followed by a {@code dim x dim} linear readout.
     *
     * <p>Training uses central finite-difference gradients with heavy-ball momentum, updated one
     * coordinate at a time. During the first {@code NOISE_EPOCHS} epochs, Gaussian noise with a
     * decaying scale is also added to all weights. Instances are mutable and are not synchronized.
     */
    static class EMLNet {
        /** Central finite-difference step. */
        private static final double FD_STEP = 1e-4;
        private static final double MOMENTUM = 0.9;
        /** Noise is injected only while {@code epoch < NOISE_EPOCHS}. */
        private static final int NOISE_EPOCHS = 150;
        /** Bound applied to each diagonal output before the readout. */
        private static final double OUTPUT_LIMIT = 100.0;

        private final int dim;
        private final RealMatrix[] W;
        private final RealMatrix readout;
        private final RealMatrix[] velocityW;
        private final RealMatrix velocityReadout;
        private final Random rand = new Random(42);

        /**
         * Creates a network with small random weights.
         *
         * <p>Layer weights are uniform in {@code [-0.0075, 0.0075)} and readout weights are uniform
         * in {@code [-0.025, 0.025)}. All momentum buffers start at zero.
         *
         * @param dim   side length of every weight matrix
         * @param depth number of EML layers (0 allowed)
         */
        public EMLNet(int dim, int depth) {
            this.dim = dim;
            this.W = new RealMatrix[depth];
            this.velocityW = new RealMatrix[depth];
            // Momentum buffers must start at rest so the first step is driven only by the gradient.
            this.velocityReadout = MatrixUtils.createRealMatrix(dim, dim);
            for (int i = 0; i < depth; i++) {
                W[i] = randomMatrix(0.015);
                velocityW[i] = MatrixUtils.createRealMatrix(dim, dim);
            }
            this.readout = randomMatrix(0.05);
        }

        /** Returns a {@code dim x dim} matrix with entries {@code scale * U[-0.5, 0.5)}, row-major. */
        private RealMatrix randomMatrix(double scale) {
            RealMatrix m = MatrixUtils.createRealMatrix(dim, dim);
            for (int r = 0; r < dim; r++) {
                for (int c = 0; c < dim; c++) {
                    m.setEntry(r, c, scale * (rand.nextDouble() - 0.5));
                }
            }
            return m;
        }

        /**
         * Runs the network on every row of {@code X_batch}.
         *
         * <p>For each row, the network starts from {@code diag(row[0..dim-1])}. Each layer
         * multiplies by its weight matrix and applies the atom against the identity matrix. The
         * diagonal of the result is clamped to {@code [-OUTPUT_LIMIT, OUTPUT_LIMIT]}. The clamped
         * rows are then multiplied by the readout matrix.
         *
         * @param X_batch input matrix with at least {@code dim} columns
         * @return a {@code batch x dim} prediction matrix
         */
        public RealMatrix forward(RealMatrix X_batch) {
            int batch = X_batch.getRowDimension();
            RealMatrix outputs = MatrixUtils.createRealMatrix(batch, dim);
            // The atom's second operand is always the identity; build it once per call.
            FieldMatrix<Complex> identity = toComplex(MatrixUtils.createRealIdentityMatrix(dim));
            for (int i = 0; i < batch; i++) {
                double[] row = X_batch.getRow(i);
                double[][] mData = new double[dim][dim];
                for (int d = 0; d < dim; d++) {
                    mData[d][d] = row[d];
                }
                RealMatrix M = MatrixUtils.createRealMatrix(mData);
                for (RealMatrix w : W) {
                    M = realPart(eml_M(toComplex(M.multiply(w)), identity));
                }
                for (int d = 0; d < dim; d++) {
                    double val = M.getEntry(d, d);
                    outputs.setEntry(i, d, FastMath.max(-OUTPUT_LIMIT, FastMath.min(OUTPUT_LIMIT, val)));
                }
            }
            return outputs.multiply(readout);
        }

        /** Mean squared error of {@link #forward(RealMatrix)} on {@code X} against {@code b}. */
        public double computeLoss(RealMatrix X, RealMatrix b) {
            return computeMSE(b, forward(X));
        }

        /**
         * Performs one training epoch.
         *
         * <p>First, a momentum step is taken for every layer weight in layer order, then for the
         * readout. Then, while {@code epoch < NOISE_EPOCHS}, Gaussian noise with standard
         * deviation {@code 0.003 * exp(-epoch / 300)} is added to all weights.
         */
        public void updateParameters(RealMatrix X, RealMatrix b, double lr, int epoch) {
            for (int wi = 0; wi < W.length; wi++) {
                momentumStep(W[wi], velocityW[wi], X, b, lr);
            }
            momentumStep(readout, velocityReadout, X, b, lr);
            if (epoch < NOISE_EPOCHS) {
                double noiseScale = 0.003 * FastMath.exp(-epoch / 300.0);
                for (RealMatrix w : W) {
                    perturb(w, noiseScale);
                }
                perturb(readout, noiseScale);
            }
        }

        /**
         * Updates every entry of {@code param} in place, one coordinate at a time.
         *
         * <p>Each coordinate gets a central finite-difference gradient followed by a momentum step.
         * Each entry is committed before the next one is probed, so later gradients already see
         * the earlier updates (a Gauss-Seidel-style sweep, not a simultaneous gradient step).
         */
        private void momentumStep(RealMatrix param, RealMatrix vel, RealMatrix X, RealMatrix b, double lr) {
            for (int r = 0; r < dim; r++) {
                for (int c = 0; c < dim; c++) {
                    double orig = param.getEntry(r, c);
                    param.setEntry(r, c, orig + FD_STEP);
                    double lossPlus = computeLoss(X, b);
                    param.setEntry(r, c, orig - FD_STEP);
                    double lossMinus = computeLoss(X, b);
                    double grad = (lossPlus - lossMinus) / (2 * FD_STEP);
                    double newVel = MOMENTUM * vel.getEntry(r, c) - lr * grad;
                    vel.setEntry(r, c, newVel);
                    param.setEntry(r, c, orig + newVel);
                }
            }
        }

        /** Adds {@code scale * N(0, 1)} to every entry of {@code param}, row-major. */
        private void perturb(RealMatrix param, double scale) {
            for (int r = 0; r < dim; r++) {
                for (int c = 0; c < dim; c++) {
                    param.addToEntry(r, c, scale * rand.nextGaussian());
                }
            }
        }
    }

    /**
     * Generates a seeded synthetic regression problem, then reports the OLS baseline and the
     * results of the depth search.
     */
    public static void main(String[] args) {
        Random rand = new Random(42);
        // Data generation: b = X * A_true + N(0, 0.05^2) noise
        RealMatrix A_true = MatrixUtils.createRealMatrix(DIM, DIM);
        for (int i = 0; i < DIM; i++) {
            for (int j = 0; j < DIM; j++) {
                A_true.setEntry(i, j, rand.nextGaussian() * 0.5);
            }
        }
        RealMatrix X = MatrixUtils.createRealMatrix(N_SAMPLES, DIM);
        RealMatrix b = MatrixUtils.createRealMatrix(N_SAMPLES, DIM);
        for (int i = 0; i < N_SAMPLES; i++) {
            for (int j = 0; j < DIM; j++) {
                X.setEntry(i, j, rand.nextGaussian());
            }
            RealMatrix bi = MatrixUtils.createRowRealMatrix(X.getRow(i)).multiply(A_true);
            for (int j = 0; j < DIM; j++) {
                b.setEntry(i, j, bi.getEntry(0, j) + rand.nextGaussian() * 0.05);
            }
        }
        // OLS baseline: least-squares solve via QR. This avoids explicitly inverting X^T X,
        // which would square the condition number of X.
        RealMatrix A_ols = new QRDecomposition(X).getSolver().solve(b);
        double mse_ols = computeMSE(b, X.multiply(A_ols));
        System.out.printf("OLS MSE: %.6f%n", mse_ols);
        // Outer loop over depth (0 = OLS)
        double bestLoss = Double.POSITIVE_INFINITY;
        int bestDepth = 0;
        for (int depth = MIN_DEPTH; depth <= MAX_DEPTH; depth++) {
            System.out.printf("%n=== Trying depth = %d ===%n", depth);
            if (depth == 0) {
                System.out.printf("Depth 0 (pure OLS) loss: %.6f%n", mse_ols);
                if (mse_ols < bestLoss) {
                    bestLoss = mse_ols;
                    bestDepth = 0;
                    System.out.println(" → New best! (OLS)");
                }
                continue;
            }
            // Train mixin
            EMLNet model = new EMLNet(DIM, depth);
            for (int epoch = 0; epoch < MAX_EPOCHS_PER_DEPTH; epoch++) {
                if (epoch % 50 == 0 || epoch == MAX_EPOCHS_PER_DEPTH - 1) {
                    // The loss is only needed for logging; skip the extra forward pass otherwise.
                    System.out.printf(" Epoch %4d loss: %.6f%n", epoch, model.computeLoss(X, b));
                }
                model.updateParameters(X, b, LEARNING_RATE, epoch);
            }
            double finalLoss = model.computeLoss(X, b);
            System.out.printf("Depth %d FINAL loss: %.6f%n", depth, finalLoss);
            // Require a small margin so numerical jitter does not count as an improvement.
            if (finalLoss < bestLoss - 1e-5) {
                bestLoss = finalLoss;
                bestDepth = depth;
                System.out.printf(" → New best! (depth %d)%n", bestDepth);
            } else {
                System.out.printf(" → No meaningful improvement. Stopping.%n");
                break;
            }
        }
        System.out.printf("%n=== Best result ===%n");
        System.out.printf("Best depth = %d with loss = %.6f%n", bestDepth, bestLoss);
        System.out.printf("OLS was %.6f%n", mse_ols);
        if (bestDepth == 0) {
            System.out.println("Winner: Pure OLS (as expected for linear data)");
        } else if (USE_ORIGINAL_EXOTIC_ATOM) {
            System.out.println("Winner: Mixin with original e^X - ln(Y) atom");
        } else {
            System.out.println("Winner: Mixin with simple stable atom");
        }
    }

    /**
     * Mean squared error over all entries of two equally sized matrices.
     *
     * @throws IllegalArgumentException if {@code pred} and {@code target} differ in shape
     */
    private static double computeMSE(RealMatrix target, RealMatrix pred) {
        int rows = target.getRowDimension();
        int cols = target.getColumnDimension();
        if (pred.getRowDimension() != rows || pred.getColumnDimension() != cols) {
            throw new IllegalArgumentException(String.format(
                    "pred is %dx%d but target is %dx%d",
                    pred.getRowDimension(), pred.getColumnDimension(), rows, cols));
        }
        double sum = 0.0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double diff = target.getEntry(i, j) - pred.getEntry(i, j);
                sum += diff * diff;
            }
        }
        // Widen before multiplying so a large matrix cannot overflow the int entry count.
        return sum / ((double) rows * cols);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment