Skip to content

Instantly share code, notes, and snippets.

@antoineMoPa
Last active January 22, 2025 14:41
Show Gist options
  • Save antoineMoPa/3b7f501d926d1f2648475949b0ccffc7 to your computer and use it in GitHub Desktop.
Save antoineMoPa/3b7f501d926d1f2648475949b0ccffc7 to your computer and use it in GitHub Desktop.
Minimal example of a neural net using candle, Hugging Face's Rust neural-network library

Rust neural networks with candle

This is a complete example of using candle to train a simple neural network on an XOR dataset.

Why build this?

Hopefully, this helps some folks trying candle for the first time!

Breaking it down into baby steps

Varmaps

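A `VarMap` is an essential candle concept.

It records every trainable variable your layers create. Without one, the weights are never registered anywhere the optimizer can reach, so they never get updated and your neural net will always give the same output.

So we create a `VarMap`, wrap it in a `VarBuilder`, and pass the builder to each layer as we create it.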

// A VarMap collects every trainable variable; the VarBuilder hands them out to layers.
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F64, &Device::Cpu);

// Building a layer through `vb` registers its weights in `varmap`.
// `pp` (push_prefix) namespaces this layer's variables under "fc1".
let fc1 = nn::linear(8, 16, vb.pp("fc1"))?;

Defining our model and its layers

/// A two-layer perceptron: 8 inputs -> 16 hidden units -> 4 outputs.
struct Mlp {
    /// First linear layer (8 -> 16).
    fc1: nn::Linear,
    /// Activation applied between the two linear layers.
    act: candle_nn::Activation,
    /// Second linear layer (16 -> 4).
    fc2: nn::Linear,
}

impl Mlp {
    /// Creates the network, registering both linear layers' weights through `vb`
    /// under the `fc1` and `fc2` prefixes.
    fn new(vb: VarBuilder) -> Result<Self, candle_core::Error> {
        Ok(Self {
            // Layer 1: 8 inputs -> 16 hidden units
            fc1: nn::linear(8, 16, vb.pp("fc1"))?,
            act: candle_nn::activation::Activation::Relu,
            // Layer 2: 16 hidden units -> 4 outputs
            fc2: nn::linear(16, 4, vb.pp("fc2"))?,
        })
    }
}

Forward pass

The forward pass defines how input data flows through our layers:

/// Runs `input` through fc1 -> ReLU -> fc2 -> sigmoid.
fn forward(&self, input: &Tensor) -> Result<Tensor, candle_core::Error> {
    let hidden = input.apply(&self.fc1)?.apply(&self.act)?;
    let logits = hidden.apply(&self.fc2)?;
    // Squash every output into (0, 1); adding this sigmoid helped training a lot.
    logits.apply(&nn::activation::Activation::Sigmoid)
}

Preparing test data

Now we prepare an XOR dataset of 200 random samples. Each input row holds 8 values (each 0 or 1) forming 4 pairs of numbers to XOR, so each target row holds the 4 corresponding results.

    let mut inputs = Vec::new();
    let mut targets = Vec::new();

    for _ in 0..200 {
        // Eight random 0/1 values, read as four (a, b) pairs
        let bits: Vec<f64> = (0..8)
            .map(|_| rand::thread_rng().gen_range(0..2) as f64)
            .collect();

        // XOR each adjacent pair, giving one target value per pair
        let xors: Vec<f64> = bits
            .chunks_exact(2)
            .map(|pair| ((pair[0] as i32) ^ (pair[1] as i32)) as f64)
            .collect();

        inputs.push(Tensor::new(bits.as_slice(), &device)?);
        targets.push(Tensor::new(xors.as_slice(), &device)?);
    }

    // Stack the per-sample rows into (200, 8) inputs and (200, 4) targets
    let inputs = Tensor::stack(&inputs, 0)?;
    let targets = Tensor::stack(&targets, 0)?;

Creating the optimizer

// AdamW over every variable registered in the VarMap, with learning rate 0.2
// and all other settings at their defaults.
let mut optimizer = candle_nn::AdamW::new(
    varmap.all_vars(),
    ParamsAdamW {
        lr: 0.2,
        ..Default::default()
    },
)?;

Training loop

for epoch in 0..200 {
    // Forward pass: one sigmoid output in (0, 1) per XOR pair
    let predictions = model.forward(&inputs)?;

    // Mean squared error against the 0/1 targets.
    // `nn::loss::binary_cross_entropy_with_logit` is a fun alternative, but it expects
    // raw logits and applies its own sigmoid: remove the final sigmoid from `forward`
    // before switching to it, otherwise the sigmoid is applied twice.
    let loss = (&predictions - &targets)?.sqr()?.mean_all()?;

    // Backpropagation followed by one AdamW update
    optimizer.backward_step(&loss)?;

    if epoch % 10 == 0 {
        // `loss` is a rank-0 tensor; print its value rather than the tensor dump
        println!("Epoch {}: Loss = {}", epoch, loss.to_vec0::<f64>()?);
    }
}

Complete code

Putting it all together, plus a couple of tests.

use rand::Rng;

use candle_core::{Device, Tensor, DType};
use candle_nn as nn;
use nn::{VarMap, Optimizer, VarBuilder, ParamsAdamW};

/// A two-layer perceptron: 8 inputs -> 16 hidden units -> 4 outputs.
struct Mlp {
    /// First linear layer (8 -> 16).
    fc1: nn::Linear,
    /// Activation applied between the two linear layers.
    act: candle_nn::Activation,
    /// Second linear layer (16 -> 4).
    fc2: nn::Linear,
}

impl Mlp {
    /// Creates the network, registering both linear layers' weights through `vb`
    /// under the `fc1` and `fc2` prefixes.
    fn new(vb: VarBuilder) -> Result<Self, candle_core::Error> {
        Ok(Self {
            // 8 inputs -> 16 hidden units
            fc1: nn::linear(8, 16, vb.pp("fc1"))?,
            act: candle_nn::activation::Activation::Relu,
            // 16 hidden units -> 4 outputs
            fc2: nn::linear(16, 4, vb.pp("fc2"))?,
        })
    }

    /// Runs `input` through fc1 -> ReLU -> fc2 -> sigmoid.
    ///
    /// The final sigmoid maps every output into (0, 1). The last dimension of
    /// `input` must be 8, the input width of `fc1`.
    fn forward(&self, input: &Tensor) -> Result<Tensor, candle_core::Error> {
        let hidden = input.apply(&self.fc1)?.apply(&self.act)?;
        let logits = hidden.apply(&self.fc2)?;
        logits.apply(&nn::activation::Activation::Sigmoid)
    }
}

/// Builds a random XOR training set and trains a fresh [`Mlp`] on it.
///
/// Each of the `NUM_SAMPLES` samples is 8 random bits read as four `(a, b)`
/// pairs; its target row holds the four values `a ^ b`. Everything runs on the
/// CPU in `f64`. The loss is printed every 10 epochs.
///
/// # Errors
///
/// Propagates any `candle_core::Error` raised while building tensors, creating
/// the model or optimizer, or running a training step.
fn build_and_train_model() -> Result<Mlp, candle_core::Error> {
    // Dataset size, network I/O widths and training hyper-parameters.
    const NUM_SAMPLES: usize = 200;
    const INPUT_BITS: usize = 8;
    const OUTPUT_BITS: usize = INPUT_BITS / 2;
    const NUM_EPOCHS: usize = 200;
    const LEARNING_RATE: f64 = 0.2;

    // Use the default device (CPU in this case)
    let device = Device::Cpu;

    // Define training data for XOR as flat row-major buffers, turned into
    // (NUM_SAMPLES, 8) and (NUM_SAMPLES, 4) tensors in one step each.
    let mut rng = rand::thread_rng();
    let mut input_data: Vec<f64> = Vec::with_capacity(NUM_SAMPLES * INPUT_BITS);
    let mut target_data: Vec<f64> = Vec::with_capacity(NUM_SAMPLES * OUTPUT_BITS);
    for _ in 0..NUM_SAMPLES {
        let bits: Vec<u8> = (0..INPUT_BITS).map(|_| rng.gen_range(0..2)).collect();
        input_data.extend(bits.iter().map(|&bit| f64::from(bit)));
        // XOR each adjacent pair of bits
        target_data.extend(bits.chunks_exact(2).map(|pair| f64::from(pair[0] ^ pair[1])));
    }
    let inputs = Tensor::from_vec(input_data, (NUM_SAMPLES, INPUT_BITS), &device)?;
    let targets = Tensor::from_vec(target_data, (NUM_SAMPLES, OUTPUT_BITS), &device)?;

    // Create the VarBuilder on the same device as the data
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F64, &device);

    // Create the Mlp model; its weights are registered in `varmap`
    let model = Mlp::new(vb)?;

    // Optimizer settings: AdamW over every variable in `varmap`
    let params = ParamsAdamW {
        lr: LEARNING_RATE,
        ..Default::default()
    };
    let mut optimizer = candle_nn::AdamW::new(varmap.all_vars(), params)?;

    // Training loop
    for epoch in 0..NUM_EPOCHS {
        // Forward pass: one sigmoid output in (0, 1) per XOR pair
        let predictions = model.forward(&inputs)?;

        // Mean squared error against the 0/1 targets.
        // `nn::loss::binary_cross_entropy_with_logit` is an alternative, but it expects
        // raw logits and applies its own sigmoid: remove the final sigmoid from
        // `Mlp::forward` before switching, otherwise the sigmoid is applied twice.
        let loss = (&predictions - &targets)?.sqr()?.mean_all()?;

        // Backpropagation followed by one AdamW update
        optimizer.backward_step(&loss)?;

        if epoch % 10 == 0 {
            // `loss` is a rank-0 tensor; print its value rather than the tensor dump
            println!("Epoch {}: Loss = {}", epoch, loss.to_vec0::<f64>()?);
        }
    }

    Ok(model)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity check: a (2, 3) x (3, 4) matmul yields a (2, 4) tensor.
    #[test]
    fn test_basic_candle_tensor_stuff() -> Result<(), Box<dyn std::error::Error>> {
        let device = Device::Cpu;

        let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
        let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

        let c = a.matmul(&b)?;

        let dims = c.shape().dims();

        assert_eq!(dims[0], 2);
        assert_eq!(dims[1], 4);

        Ok(())
    }

    /// Trains a fresh model and checks that it gets all four XOR pairs right.
    ///
    /// The training data is random, so this test is not fully deterministic.
    #[test]
    fn test_candle_xor() -> Result<(), Box<dyn std::error::Error>> {
        let device = Device::Cpu;
        let model = build_and_train_model()?;

        // Pairs (0,0), (0,1), (1,0), (1,1) -> expected 0, 1, 1, 0
        let inputs = Tensor::new(&[[0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]], &device)?;
        let targets = Tensor::new(&[[0.0, 1.0, 1.0, 0.0]], &device)?;
        let test_preds = model.forward(&inputs)?;

        // Check each output's absolute error. A signed sum would let positive and
        // negative errors cancel out, and would pass trivially whenever it is negative.
        let abs_errors = (test_preds - &targets)?
            .abs()?
            .flatten_all()?
            .to_vec1::<f64>()?;
        let max_error = abs_errors.iter().copied().fold(0.0_f64, f64::max);
        assert!(
            max_error < 0.1,
            "prediction too far from XOR target, abs errors: {abs_errors:?}"
        );

        Ok(())
    }
}

Cargo.toml

You can run `cargo add candle-core candle-nn rand` (the crates.io names use hyphens; in code they are imported as `candle_core` and `candle_nn`), or use this Cargo.toml.

[package]
name = "candle_test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = "0.8.2"
candle-nn = "0.8.2"
rand = "0.8.5"

Running

cargo test

Running (verbose)

cargo test -- --nocapture
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment