Created
May 22, 2025 13:04
-
-
Save kbastani/eb4594afad40edc6a610657535793484 to your computer and use it in GitHub Desktop.
Java-based implementation of an MLP for advanced ML use cases, targeting operations on CPU/GPU
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package redacted | |
import ai.phylos.neural.models.DeepPerceptron; | |
import ai.phylos.neural.util.PerceptronUtils; | |
import org.apache.commons.lang3.tuple.Pair; | |
import java.text.MessageFormat; | |
import java.util.Arrays; | |
import java.util.List; | |
import java.util.Random; | |
import static ai.phylos.neural.models.activation.ActivationFunction.Functions.*; | |
/** | |
* A classifier demonstrating the use of layer-specific activation functions. | |
* Goal: Classify synthetic time-series segments based on extracted features like | |
* dominant frequency, amplitude, and variability. | |
* - Class 0: Low Frequency, Low Amplitude Sine Wave | |
* - Class 1: Medium Frequency, High Amplitude Sine Wave | |
* - Class 2: High Frequency, Variable Amplitude Noise | |
* | |
* The network takes extracted features as input and uses different activations | |
* in its hidden layers (e.g., ReLU, LeakyReLU, Tanh) to process them. | |
*/ | |
public class TimeSeriesFeatureClassifierDemo { | |
// Simulated Feature Names (Inputs) | |
private static final List<String> FEATURE_KEYS = List.of( | |
"dominant_frequency", // Hz (simulated) | |
"peak_amplitude", // Units | |
"standard_deviation", // Units | |
"zero_crossing_rate" // Crossings per second (simulated) | |
); | |
private static final int INPUT_SIZE = FEATURE_KEYS.size(); // 4 | |
private static final int NUM_CLASSES = 3; | |
public static void main(String[] args) { | |
int epochs = 18000; | |
int numSamples = 2000; | |
int outputSize = NUM_CLASSES; | |
// --- Create the Multi-Layer Neural Network with Layer-Specific Activations --- | |
DeepPerceptron network = DeepPerceptron.factory() | |
.inputSize(INPUT_SIZE) | |
.outputSize(outputSize) | |
.hiddenLayers( // Specify size and activation for each hidden layer | |
Pair.of(16, ReLUActiviation), // Layer 1: ReLU | |
Pair.of(12, LeakyReLUActivation), // Layer 2: Leaky ReLU | |
Pair.of(8, TanhActivation) // Layer 3: Tanh | |
) | |
.outputActivation(SoftmaxActivation) // Output layer activation | |
.learningRate(0.002) | |
.build(); | |
// --- Generate Synthetic Time-Series Feature Data --- | |
double[][] inputs = new double[numSamples][INPUT_SIZE]; | |
double[][] targets = new double[numSamples][outputSize]; | |
Random random = new Random(); | |
int[] classCounts = new int[NUM_CLASSES]; | |
for (int i = 0; i < numSamples; i++) { | |
int assignedClass = random.nextInt(NUM_CLASSES); | |
double freq, amp, stdDev, zcr; | |
switch (assignedClass) { | |
case 0: // Low Freq, Low Amp Sine | |
freq = 1.0 + random.nextDouble() * 2.0; // 1-3 Hz | |
amp = 0.5 + random.nextDouble() * 0.5; // 0.5-1.0 Amplitude | |
stdDev = amp / Math.sqrt(2) + random.nextGaussian() * 0.05; // Approx std dev of sine + noise | |
zcr = 2 * freq + random.nextGaussian() * 0.2; // Approx ZCR | |
break; | |
case 1: // Med Freq, High Amp Sine | |
freq = 5.0 + random.nextDouble() * 5.0; // 5-10 Hz | |
amp = 2.0 + random.nextDouble() * 2.0; // 2.0-4.0 Amplitude | |
stdDev = amp / Math.sqrt(2) + random.nextGaussian() * 0.1; | |
zcr = 2 * freq + random.nextGaussian() * 0.5; | |
break; | |
case 2: // High Freq, Variable Amp Noise (simulate features) | |
default: | |
freq = 15.0 + random.nextDouble() * 10.0; // 15-25 Hz | |
amp = 1.0 + random.nextDouble() * 3.0; // Variable amplitude (peak) | |
stdDev = 0.8 + random.nextDouble() * 1.5; // Higher relative std dev for noise | |
zcr = freq * (1.5 + random.nextDouble()); // Higher ZCR for noisy signal | |
break; | |
} | |
inputs[i][0] = freq; | |
inputs[i][1] = amp; | |
inputs[i][2] = stdDev; | |
inputs[i][3] = zcr; | |
targets[i] = new double[outputSize]; | |
targets[i][assignedClass] = 1.0; | |
classCounts[assignedClass]++; | |
} | |
System.out.println("Generated sample class distribution: " + Arrays.toString(classCounts)); | |
// --- Shuffle, Split, Train, Evaluate --- (Using PerceptronUtils) | |
int[] indices = PerceptronUtils.shuffleIndices(numSamples, random); | |
double[][] shuffledInputs = PerceptronUtils.shuffleArray(inputs, indices); | |
double[][] shuffledTargets = PerceptronUtils.shuffleArray(targets, indices); | |
PerceptronUtils.TrainTestSplit split = PerceptronUtils.splitData(shuffledInputs, shuffledTargets, 0.8); | |
double[][] trainInputs = split.trainInputs(); | |
double[][] trainTargets = split.trainTargets(); | |
double[][] testInputs = split.testInputs(); | |
double[][] testTargets = split.testTargets(); | |
int trainSize = split.trainSize(); | |
// Train | |
System.out.println("--- Training Time-Series Feature Classifier ---"); | |
for (int epoch = 0; epoch < epochs; epoch++) { | |
double totalLoss = PerceptronUtils.trainEpoch(network, trainInputs, trainTargets, random); | |
if (epoch % 2000 == 0) { | |
double averageLoss = totalLoss / trainSize; | |
// Optional: Evaluate test accuracy during training | |
double currentTestAccuracy = PerceptronUtils.evaluateAccuracy(network, testInputs, testTargets); | |
System.out.printf("Epoch %d, Avg Loss: %.5f, Test Acc: %.3f%n", | |
epoch, averageLoss, currentTestAccuracy); | |
} | |
} | |
// Evaluate | |
System.out.println("--- Evaluating Time-Series Feature Classifier ---"); | |
double trainAccuracy = PerceptronUtils.evaluateAccuracy(network, trainInputs, trainTargets); | |
double testAccuracy = PerceptronUtils.evaluateAccuracy(network, testInputs, testTargets); | |
System.out.println("Final Training Accuracy: " + trainAccuracy); | |
System.out.println("Final Testing Accuracy: " + testAccuracy); | |
// Test | |
System.out.println("\n--- Testing Time-Series Feature Classifier ---"); | |
// Input order: freq, amp, stdDev, zcr | |
double[][] testVectors = { | |
{ 2.0, 0.8, 0.6, 4.1}, // Expected Class 0 (LowFq, LowAmp) | |
{ 8.0, 3.5, 2.5, 16.2}, // Expected Class 1 (MedFq, HighAmp) | |
{20.0, 2.5, 1.8, 45.0}, // Expected Class 2 (HighFq, Noisy) | |
{ 6.0, 1.0, 0.7, 12.5}, // Ambiguous? Closer to LowAmp (Exp 0 or 1?) -> Let network decide | |
{12.0, 4.0, 3.0, 25.0} // Ambiguous? Closer to MedFq/HighAmp (Exp 1 or 2?) -> Let network decide | |
}; | |
int[] expectedClasses = {0, 1, 2, -1, -1}; // -1 for ambiguous cases | |
String[] classNames = {"LowFq_LowAmp", "MedFq_HighAmp", "HighFq_Noise"}; | |
for (int i = 0; i < testVectors.length; i++) { | |
double[] vector = testVectors[i]; | |
double[] output = network.forward(vector); | |
int pred = PerceptronUtils.argmax(output); | |
String expectedStr = (expectedClasses[i] == -1) ? "Ambiguous" : String.valueOf(expectedClasses[i]); | |
System.out.println(MessageFormat.format("Input Features: {0}, Softmax={1}, Predicted: {2} ({3}), Expected: {4}", | |
Arrays.toString(vector), PerceptronUtils.formatArray(output,3), pred, classNames[pred], expectedStr)); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment