@freakynit · Created September 5, 2025
A Java implementation of 12 RISC-style primitive operations and a 2-2-1 neural network to learn XOR using backpropagation, inspired by LuminalAI, for educational purposes.
import java.util.Arrays;
import java.util.Random;
/**
 * A small RISC-style primitive operation framework (inspired by LuminalAI),
 * implementing 12 primitive ops and a demo network.
 *
 * Inspiration: https://github.com/luminal-ai/luminal
 *
 * Ops:
 * ---
 * Unary  - Log2, Exp2, Sin, Sqrt, Recip
 * Binary - Add, Mul, Mod, LessThan
 * Other  - SumReduce, MaxReduce, Contiguous
 *
 * Demo:
 * ----
 * Trains a small 2-2-1 network to learn XOR via backpropagation and runs it
 * on all four inputs.
 */
public class Main {
    public static class Tensor {
        public final double[] data;
        public Tensor(int n) { this.data = new double[n]; }
        public Tensor(double... arr) { this.data = arr.clone(); }
        public int len() { return data.length; }
        public double get(int i) { return data[i]; }
        public void set(int i, double v) { data[i] = v; }
        @Override public String toString() { return Arrays.toString(data); }
    }
    public static class Ops {
        // Unary (elementwise)
        public static Tensor log2(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.log(a.data[i]) / Math.log(2.0);
            return out;
        }
        public static Tensor exp2(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.pow(2.0, a.data[i]);
            return out;
        }
        public static Tensor sin(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.sin(a.data[i]);
            return out;
        }
        public static Tensor sqrt(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = Math.sqrt(a.data[i]);
            return out;
        }
        public static Tensor recip(Tensor a) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = 1.0 / a.data[i];
            return out;
        }
        // Binary (elementwise) - supports tensor-tensor and tensor-scalar
        public static Tensor add(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("add: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] + b.data[i];
            return out;
        }
        public static Tensor add(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] + scalar;
            return out;
        }
        public static Tensor mul(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("mul: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] * b.data[i];
            return out;
        }
        public static Tensor mul(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] * scalar;
            return out;
        }
        public static Tensor mod(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("mod: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] % b.data[i];
            return out;
        }
        public static Tensor mod(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = a.data[i] % scalar;
            return out;
        }
        // LessThan returns 1.0 where a < b, else 0.0
        public static Tensor lessThan(Tensor a, Tensor b) {
            if (a.len() != b.len()) throw new IllegalArgumentException("lessThan: length mismatch");
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = (a.data[i] < b.data[i]) ? 1.0 : 0.0;
            return out;
        }
        public static Tensor lessThan(Tensor a, double scalar) {
            Tensor out = new Tensor(a.len());
            for (int i=0;i<a.len();++i) out.data[i] = (a.data[i] < scalar) ? 1.0 : 0.0;
            return out;
        }
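        // Worked example (illustrative values, not from the original gist):
        // lessThan(new Tensor(-1.0, 2.0), 0.0) -> [1.0, 0.0], the mask relu() uses below.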
        // Other
        // SumReduce: sum of a tensor -> scalar
        public static double sumReduce(Tensor a) {
            double s=0.0; for (double v : a.data) s+=v; return s;
        }
        // MaxReduce: max element
        public static double maxReduce(Tensor a) {
            double m = Double.NEGATIVE_INFINITY; for (double v : a.data) if (v>m) m=v; return m;
        }
        // Contiguous: return a contiguous copy of the tensor
        public static Tensor contiguous(Tensor a) {
            return new Tensor(a.data);
        }
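        // Note: the Tensor(double...) constructor clones its argument, so the
        // copy returned by contiguous() shares no storage with the input.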
        // Convenience helpers built from primitives
        public static double dot(Tensor a, Tensor b) {
            return sumReduce(mul(a,b));
        }
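        // Worked example (illustrative values): dot([1,2],[3,4]) = sumReduce([3,8]) = 11.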
        // Elementwise sigmoid built from primitives (uses exp2 and recip)
        // sigmoid(x) = 1 / (1 + e^{-x})
        // e^{-x} = 2^{(-x)/ln2} so we use exp2 and mul
        public static Tensor sigmoid(Tensor x) {
            // sigmoid implemented from exp2 + recip primitives
            final double invLn2 = 1.0 / Math.log(2.0);
            Tensor neg = mul(x, -1.0);
            Tensor scaled = mul(neg, invLn2);
            Tensor eMinusX = exp2(scaled);
            Tensor denom = add(eMinusX, 1.0);
            return recip(denom);
        }
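        // Sanity check (hand-computed): for x = 0, scaled = 0, exp2(0) = 1,
        // denom = 2, recip(2) = 0.5 -- matching sigmoid(0) = 0.5.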
        // helper to compute elementwise sigmoid derivative: s*(1-s)
        public static Tensor sigmoidGrad(Tensor s) {
            Tensor oneMinus = new Tensor(s.len());
            for (int i=0;i<s.len();++i) oneMinus.data[i] = 1.0 - s.data[i];
            return mul(s, oneMinus);
        }
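        // Derivation: with s = sigmoid(x) = (1 + e^{-x})^{-1},
        // ds/dx = e^{-x} / (1 + e^{-x})^2 = s * (1 - s),
        // which is why the backprop code in main() multiplies deltas by y*(1-y).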
        // Simple ReLU built from lessThan, mul, and add: relu(x) = x * (x > 0)
        // lessThan(x, 0) gives a mask of the negative entries; the positive
        // mask is then 1 - negMask, computed with the add/mul primitives.
        public static Tensor relu(Tensor x) {
            // negativeMask = lessThan(x, 0) -> 1.0 for negatives
            Tensor negMask = lessThan(x, 0.0);
            // ones tensor, so posMask can be formed as ones + (-1)*negMask
            Tensor ones = new Tensor(x.len());
            for (int i=0;i<ones.len();++i) ones.data[i] = 1.0;
            Tensor posMask = add(ones, mul(negMask, -1.0));
            return mul(x, posMask);
        }
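        // Worked example (illustrative values): relu([-2, 3]) -> negMask = [1, 0],
        // posMask = [0, 1], result = [-2*0, 3*1] = [0, 3].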
    }
    // Dense layer with a few utils for forward/backprop updates
    public static class DenseLayer {
        // weights => outDim x inDim
        public final double[][] W;
        public final double[] b;
        public final int inDim, outDim;
        public DenseLayer(int inDim, int outDim, Random rng) {
            this.inDim = inDim; this.outDim = outDim;
            W = new double[outDim][inDim];
            b = new double[outDim];
            for (int i=0;i<outDim;++i) {
                for (int j=0;j<inDim;++j) W[i][j] = rng.nextGaussian() * 0.5;
                b[i] = 0.0;
            }
        }
        public Tensor forward(Tensor x) {
            Tensor out = new Tensor(outDim);
            for (int i=0;i<outDim;++i) {
                double s=0.0;
                for (int j=0;j<inDim;++j) s += W[i][j]*x.data[j];
                out.data[i] = s + b[i];
            }
            return out;
        }
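        // Shape example (matches the output layer of the demo below): for
        // inDim=2, outDim=1, forward([x0, x1]) returns [W[0][0]*x0 + W[0][1]*x1 + b[0]].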
        // Apply gradient update given upstream delta for this layer's outputs.
        // `deltaOut` is length outDim, and `prevAct` is the activations of the
        // previous layer (length inDim).
        public void applyGradients(double[] deltaOut, double[] prevAct, double lr) {
            for (int i=0;i<outDim;++i) {
                for (int j=0;j<inDim;++j) {
                    W[i][j] -= lr * deltaOut[i] * prevAct[j];
                }
                b[i] -= lr * deltaOut[i];
            }
        }
    }
    public static void main(String[] args) {
        Random rng = new Random(1234);
        // XOR dataset (2 inputs -> 1 output)
        Tensor[] inputs = {
            new Tensor(0.0,0.0), new Tensor(0.0,1.0), new Tensor(1.0,0.0), new Tensor(1.0,1.0)
        };
        double[] targets = {0.0, 1.0, 1.0, 0.0};
        // network 2 -> 2 -> 1
        DenseLayer l1 = new DenseLayer(2, 2, rng);
        DenseLayer l2 = new DenseLayer(2, 1, rng);
        double lr = 0.5;
        int epochs = 20000;
        for (int epoch=0; epoch<epochs; ++epoch) {
            double loss = 0.0;
            // Simple SGD over dataset
            for (int k=0;k<inputs.length;++k) {
                Tensor x = inputs[k];
                double t = targets[k];
                // forward pass
                Tensor z1 = l1.forward(x);
                Tensor a1 = Ops.sigmoid(z1);
                Tensor z2 = l2.forward(a1);
                Tensor a2 = Ops.sigmoid(z2); // network output (length 1)
                double y = a2.data[0];
                double err = y - t;
                loss += 0.5 * err * err;
                // backward: output delta: dL/dz2 = (y - t) * sigmoid'(z2); sigmoid'(z2) = y*(1-y)
                double delta2 = err * (y * (1.0 - y)); // scalar
                // hidden deltas: dL/dz1 = (W2^T * delta2) * sigmoid'(z1)
                double[] delta1 = new double[2];
                for (int i=0;i<2;++i) {
                    double w = l2.W[0][i]; // weight from hidden i to output
                    double s = a1.data[i];
                    delta1[i] = (w * delta2) * (s * (1.0 - s));
                }
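                // Chain rule behind the updates below (standard backprop, spelled out):
                // dL/dW2[0][i] = delta2 * a1[i],  dL/db2[0] = delta2,
                // dL/dW1[i][j] = delta1[i] * x[j], dL/db1[i] = delta1[i].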
                // apply gradients: l2 gradients: W2 -= lr * delta2 * a1
                double[] a1arr = a1.data;
                l2.applyGradients(new double[]{delta2}, a1arr, lr);
                // l1 gradients: W1 -= lr * delta1 * x
                l1.applyGradients(delta1, x.data, lr);
            }
            if (epoch % 2000 == 0) System.out.println("Epoch " + epoch + " loss=" + loss);
        }
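        // If training succeeded, the four printed outputs should be close to the
        // XOR targets (near 0, 1, 1, 0); exact values depend on the RNG seed.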
        // test trained network
        System.out.println("Trained XOR network outputs:");
        for (int k=0;k<inputs.length;++k) {
            Tensor x = inputs[k];
            Tensor z1 = l1.forward(x);
            Tensor a1 = Ops.sigmoid(z1);
            Tensor z2 = l2.forward(a1);
            Tensor a2 = Ops.sigmoid(z2);
            System.out.printf("%s -> %.6f\n", Arrays.toString(x.data), a2.data[0]);
        }
        // Print final weights just for fun
        System.out.println("Final weights l1:");
        for (int i=0;i<l1.outDim;++i) System.out.println(Arrays.toString(l1.W[i]) + " b=" + l1.b[i]);
        System.out.println("Final weights l2:");
        for (int i=0;i<l2.outDim;++i) System.out.println(Arrays.toString(l2.W[i]) + " b=" + l2.b[i]);
    }
}