A Java implementation of 12 RISC-style primitive operations and a 2-2-1 neural network to learn XOR using backpropagation, inspired by LuminalAI, for educational purposes.
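Below is the full file. To try it locally, save it as Main.java (the file name must match the public class), then compile with javac Main.java and run with java Main.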
import java.util.Arrays;
import java.util.Random;

/**
 * A small RISC-style primitive operation framework (inspired by LuminalAI),
 * implementing 12 primitive ops and a demo network.
 *
 * Inspiration: https://github.com/luminal-ai/luminal
 *
 * Ops:
 * ---
 * Unary - Log2, Exp2, Sin, Sqrt, Recip
 * Binary - Add, Mul, Mod, LessThan
 * Other - SumReduce, MaxReduce, Contiguous
 *
 * Demo:
 * ----
 * Trains a small 2-2-1 network to learn XOR and executes it.
 */
public class Main {
    public static class Tensor {
        public final double[] data;
        public Tensor(int n) { this.data = new double[n]; }
        public Tensor(double... arr) { this.data = arr.clone(); }
        public int len() { return data.length; }
        public double get(int i) { return data[i]; }
        public void set(int i, double v) { data[i] = v; }
        @Override public String toString() { return Arrays.toString(data); }
    }
    public static class Ops {
        // Unary (elementwise)
        public static Tensor log2(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.log(a.data[i]) / Math.log(2.0);
            return out;
        }
        public static Tensor exp2(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.pow(2.0, a.data[i]);
            return out;
        }
        public static Tensor sin(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.sin(a.data[i]);
            return out;
        }
        public static Tensor sqrt(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.sqrt(a.data[i]);
            return out;
        }
        public static Tensor recip(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = 1.0 / a.data[i];
            return out;
        }
        // Binary (elementwise) - supports tensor-tensor and tensor-scalar
        public static Tensor add(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("add: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] + b.data[i];
            return out;
        }
        public static Tensor add(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] + scalar;
            return out;
        }
        public static Tensor mul(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("mul: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] * b.data[i];
            return out;
        }
        public static Tensor mul(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] * scalar;
            return out;
        }
        public static Tensor mod(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("mod: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] % b.data[i];
            return out;
        }
        public static Tensor mod(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] % scalar;
            return out;
        }
        // LessThan returns 1.0 where a < b, else 0.0
        public static Tensor lessThan(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("lessThan: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = (a.data[i] < b.data[i]) ? 1.0 : 0.0;
            return out;
        }
        public static Tensor lessThan(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = (a.data[i] < scalar) ? 1.0 : 0.0;
            return out;
        }
        // Other
        // SumReduce: sum of a tensor -> scalar
        public static double sumReduce(Tensor a) {
            double s=0.0; for (double v : a.data) s+=v; return s;
        }
        // MaxReduce: max element
        public static double maxReduce(Tensor a) {
            double m = Double.NEGATIVE_INFINITY; for (double v : a.data) if (v>m) m=v; return m;
        }
        // Contiguous: return a contiguous copy of the tensor
        public static Tensor contiguous(Tensor a) {
            return new Tensor(a.data);
        }
        // Convenience helpers built from primitives
        public static double dot(Tensor a, Tensor b) {
            return sumReduce(mul(a,b));
        }
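        // Illustrative check (added note, not used by the demo):
        // dot([1,2], [3,4]) = sumReduce(mul([1,2],[3,4])) = sumReduce([3,8]) = 11.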
        // Elementwise sigmoid built from primitives (uses exp2 and recip)
        // sigmoid(x) = 1 / (1 + e^{-x})
        // e^{-x} = 2^{(-x)/ln2} so we use exp2 and mul
        public static Tensor sigmoid(Tensor x) {
            // sigmoid implemented from exp2 + recip primitives
            final double invLn2 = 1.0 / Math.log(2.0);
            Tensor neg = mul(x, -1.0);
            Tensor scaled = mul(neg, invLn2);
            Tensor eMinusX = exp2(scaled);
            Tensor denom = add(eMinusX, 1.0);
            return recip(denom);
        }
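        // Worked check (added illustration): for x = 0, exp2(0) = 1, so sigmoid(0) = 1/(1+1) = 0.5;
        // for x = ln 2, (-x)/ln 2 = -1 and exp2(-1) = 0.5 = e^{-ln 2}, so sigmoid(ln 2) = 1/1.5 ≈ 0.6667.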
        // helper to compute elementwise sigmoid derivative: s*(1-s)
        public static Tensor sigmoidGrad(Tensor s) {
            Tensor oneMinus = new Tensor(s.len());
            for (int i=0;i<s.len();++i) oneMinus.data[i] = 1.0 - s.data[i];
            return mul(s, oneMinus);
        }
        // Simple ReLU built from lessThan, mul and add: relu(x) = x * (x > 0)
        public static Tensor relu(Tensor x) {
            // negMask = lessThan(x, 0) -> 1.0 for negative entries, 0.0 otherwise
            Tensor negMask = lessThan(x, 0.0);
            // posMask = 1 - negMask, built from the primitives as ones + (-1)*negMask
            Tensor ones = new Tensor(new double[x.len()]);
            for (int i=0;i<ones.len();++i) ones.data[i] = 1.0;
            Tensor posMask = add(ones, mul(negMask, -1.0));
            return mul(x, posMask);
        }
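        // Example (added illustration): x = [-2, 0, 3] -> negMask = [1, 0, 0] (strict <),
        // posMask = [0, 1, 1], so relu(x) = [0, 0, 3]; x = 0 maps to 0 either way.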
    }
    // Dense layer with a few utils for forward/backprop updates
    public static class DenseLayer {
        // weights => outDim x inDim
        public final double[][] W;
        public final double[] b;
        public final int inDim, outDim;
        public DenseLayer(int inDim, int outDim, Random rng) {
            this.inDim = inDim; this.outDim = outDim;
            W = new double[outDim][inDim];
            b = new double[outDim];
            for (int i=0;i<outDim;++i) {
                for (int j=0;j<inDim;++j) W[i][j] = rng.nextGaussian() * 0.5;
                b[i] = 0.0;
            }
        }
        public Tensor forward(Tensor x) {
            Tensor out = new Tensor(outDim);
            for (int i=0;i<outDim;++i) {
                double s=0.0;
                for (int j=0;j<inDim;++j) s += W[i][j]*x.data[j];
                out.data[i] = s + b[i];
            }
            return out;
        }
        // Apply gradient update given upstream delta for this layer's outputs.
        // `deltaOut` is length outDim, and `prevAct` is the activations of the previous layer (length inDim)
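        // Worked update (added illustration): with lr = 0.5, deltaOut = {0.2} and prevAct = {1.0, 0.5},
        // W[0][0] decreases by 0.5*0.2*1.0 = 0.1, W[0][1] by 0.5*0.2*0.5 = 0.05, and b[0] by 0.5*0.2 = 0.1,
        // i.e. plain SGD with dL/dW[i][j] = deltaOut[i]*prevAct[j] and dL/db[i] = deltaOut[i].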
        public void applyGradients(double[] deltaOut, double[] prevAct, double lr) {
            for (int i=0;i<outDim;++i) {
                for (int j=0;j<inDim;++j) {
                    W[i][j] -= lr * deltaOut[i] * prevAct[j];
                }
                b[i] -= lr * deltaOut[i];
            }
        }
    }
    public static void main(String[] args) {
        Random rng = new Random(1234);
        // XOR dataset (2 inputs -> 1 output)
        Tensor[] inputs = {
            new Tensor(0.0,0.0), new Tensor(0.0,1.0), new Tensor(1.0,0.0), new Tensor(1.0,1.0)
        };
        double[] targets = {0.0, 1.0, 1.0, 0.0};
        // network `2 -> 2 -> 1`
        DenseLayer l1 = new DenseLayer(2, 2, rng);
        DenseLayer l2 = new DenseLayer(2, 1, rng);
        double lr = 0.5;
        int epochs = 20000;
        for (int epoch=0; epoch<epochs; ++epoch) {
            double loss = 0.0;
            // Simple SGD over dataset
            for (int k=0;k<inputs.length;++k) {
                Tensor x = inputs[k];
                double t = targets[k];
                // forward pass
                Tensor z1 = l1.forward(x);
                Tensor a1 = Ops.sigmoid(z1);
                Tensor z2 = l2.forward(a1);
                Tensor a2 = Ops.sigmoid(z2); // network output (length 1)
                double y = a2.data[0];
                double err = y - t;
                loss += 0.5 * err * err;
                // backward: output delta: dL/dz2 = (y - t) * sigmoid'(z2) ; sigmoid'(z2)=y*(1-y)
                double delta2 = err * (y * (1.0 - y)); // scalar
                // hidden deltas: dL/dz1 = (W2^T * delta2) * sigmoid'(z1)
                double[] delta1 = new double[2];
                for (int i=0;i<2;++i) {
                    double w = l2.W[0][i]; // weight from hidden i to output
                    double s = a1.data[i];
                    delta1[i] = (w * delta2) * (s * (1.0 - s));
                }
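                // Chain rule spelled out (added note): delta1[i] = dL/dz1_i
                //   = (dL/dz2) * (dz2/da1_i) * (da1_i/dz1_i) = delta2 * W2[0][i] * a1_i*(1 - a1_i),
                // which is exactly what the loop above computes.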
                // apply gradients: l2 gradients: W2 -= lr * delta2 * a1
                double[] a1arr = a1.data;
                l2.applyGradients(new double[]{delta2}, a1arr, lr);
                // l1 gradients: W1 -= lr * delta1 * x
                l1.applyGradients(delta1, x.data, lr);
            }
            if (epoch % 2000 == 0) System.out.println("Epoch " + epoch + " loss=" + loss);
        }
        // test the trained network
        System.out.println("Trained XOR network outputs:");
        for (int k=0;k<inputs.length;++k) {
            Tensor x = inputs[k];
            Tensor z1 = l1.forward(x);
            Tensor a1 = Ops.sigmoid(z1);
            Tensor z2 = l2.forward(a1);
            Tensor a2 = Ops.sigmoid(z2);
            System.out.printf("%s -> %.6f\n", Arrays.toString(x.data), a2.data[0]);
        }
        // Print final weights just for fun
        System.out.println("Final weights l1:");
        for (int i=0;i<l1.outDim;++i) System.out.println(Arrays.toString(l1.W[i]) + " b=" + l1.b[i]);
        System.out.println("Final weights l2:");
        for (int i=0;i<l2.outDim;++i) System.out.println(Arrays.toString(l2.W[i]) + " b=" + l2.b[i]);
    }
}