Last active
September 15, 2018 19:06
-
-
Save tehZevo/0f4ec45fb900c207491a7032b9d770ba to your computer and use it in GitHub Desktop.
An example of solving the XOR problem with TensorFlow.js (tfjs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const tf = require("@tensorflow/tfjs");

//code modified from:
//https://medium.com/tensorflow/a-gentle-introduction-to-tensorflow-js-dba2e5257702

//define our inputs: every combination of 2 bits, represented as 0s and 1s (shape [4, 2])
//https://js.tensorflow.org/api/0.12.0/#tensor
const xs = tf.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]);

//define our outputs: the XOR of each input row (shape [4, 1]).
//XOR is a simple non-linear problem, which is why the model below needs a hidden layer.
//[0, 0] -> [0], [0, 1] -> [1], [1, 0] -> [1], [1, 1] -> [0]
const ys = tf.tensor([[0], [1], [1], [0]]);
//create a "sequential" model (layers are applied one after another, in the order added)
//https://js.tensorflow.org/api/0.12.0/#sequential
const model = tf.sequential();

//hidden layer: a dense/"fully-connected" layer that takes inputs of size 2,
//contains 8 neurons, and uses the "tanh" activation function
//https://js.tensorflow.org/api/0.12.0/#layers.dense
//see https://js.tensorflow.org/api/0.12.0/#layers.activation for other activations
model.add(tf.layers.dense({units: 8, inputDim: 2, activation: 'tanh'}));

//output layer: a single neuron with "sigmoid" activation, so each prediction lies in (0, 1)
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));

//compile our model using the Adam optimizer with a learning rate of 0.1
//and binary crossentropy loss (suited to a single sigmoid output).
//NOTE: compile() has no `lr` option, so an `lr` field is silently ignored.
//Passing the string 'adam' gives Adam's default learning rate (0.001).
//The rate must therefore be set on the optimizer object itself.
//https://js.tensorflow.org/api/0.12.0/#class:Model
//https://js.tensorflow.org/api/0.12.0/#train.adam
const LEARNING_RATE = 0.1;
model.compile({optimizer: tf.train.adam(LEARNING_RATE), loss: 'binaryCrossentropy'});
//train the model for 1000 epochs (complete passes over the dataset),
//using a batch size of 4 (all 4 input/output pairs in a single batch per epoch)
//https://towardsdatascience.com/epoch-vs-iterations-vs-batch-size-4dfb9c7ce9c9
model.fit(xs, ys, {
  batchSize: 4,
  epochs: 1000,
  callbacks: {
    //after every epoch, print the current loss
    onEpochEnd: (epoch, log) => {
      console.log(`Epoch ${epoch}: loss = ${log.loss}`);
    },
  },
}).then(() => {
  //after training, predict answers for our original input data
  //output should be close to [[0], [1], [1], [0]]
  const prediction = model.predict(xs);
  prediction.print();
  //free the tensor's backing memory; tfjs tensors are not garbage-collected automatically
  prediction.dispose();
}).catch((err) => {
  //surface training/prediction failures instead of leaving an unhandled rejection
  console.error("Training failed:", err);
  process.exitCode = 1;
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment