Skip to content

Instantly share code, notes, and snippets.

@kujirahand
Created April 22, 2025 23:50
Show Gist options
  • Save kujirahand/b3c6a5b40310fb88964116fbe2665d9d to your computer and use it in GitHub Desktop.
Rustでランダムフォレストを実装しよう --- MNISTの判定
use std::collections::HashMap;
use lazyrand;
use mnist_reader::MnistReader;
/// A node in the decision tree: either a terminal leaf carrying a class
/// prediction, or an internal decision that routes a sample by comparing
/// one feature value against a threshold.
enum Node {
    /// Terminal node: every sample reaching it is predicted as `prediction`.
    Leaf { prediction: u8 },
    /// Internal node: samples with `sample[feature_index] < threshold`
    /// descend into `left`, all others into `right`.
    Decision {
        feature_index: usize,
        threshold: f32,
        left: Box<Node>,
        right: Box<Node>,
    },
}
/// A simple decision-tree classifier (binary splits, Gini criterion).
pub struct DecisionTree {
    // Maximum depth; growth stops and a leaf is emitted at this depth.
    max_depth: usize,
    // Minimum number of samples required to attempt a split.
    min_samples_split: usize,
    // Number of randomly sampled candidate features examined per split.
    max_features: usize,
    // Root of the trained tree; `None` until `train` has been called.
    root: Option<Box<Node>>,
}
impl DecisionTree {
    /// Creates an untrained tree with the given hyperparameters.
    pub fn new(max_depth: usize, min_samples_split: usize, max_features: usize) -> Self {
        DecisionTree { max_depth, min_samples_split, max_features, root: None }
    }

    /// Fits the tree to `data` (one feature vector per sample) and `labels`.
    pub fn train(&mut self, data: &[Vec<f32>], labels: &[u8]) {
        self.root = Some(Box::new(self.build_tree(data, labels, 0)));
    }

    /// Recursively grows the tree from one partition of the training set.
    fn build_tree(&self, data: &[Vec<f32>], labels: &[u8], depth: usize) -> Node {
        // Degenerate call with no samples (possible when min_samples_split == 0):
        // the original code would panic here, both in the purity check below
        // (`labels[0]`) and inside `majority_label`. Emit a default leaf instead.
        if labels.is_empty() {
            return Node::Leaf { prediction: 0 };
        }
        // Stopping conditions: maximum depth reached or too few samples.
        if depth >= self.max_depth || labels.len() < self.min_samples_split {
            return Node::Leaf { prediction: majority_label(labels) };
        }
        // Pure node — all samples share one class, so no split is needed.
        if labels.iter().all(|&x| x == labels[0]) {
            return Node::Leaf { prediction: labels[0] };
        }
        // Search for the best split; fall back to a leaf when none exists.
        if let Some((feat, thr)) = best_split(data, labels, self.max_features, self.min_samples_split) {
            let (left_data, left_labels, right_data, right_labels) = split(data, labels, feat, thr);
            Node::Decision {
                feature_index: feat,
                threshold: thr,
                left: Box::new(self.build_tree(&left_data, &left_labels, depth + 1)),
                right: Box::new(self.build_tree(&right_data, &right_labels, depth + 1)),
            }
        } else {
            Node::Leaf { prediction: majority_label(labels) }
        }
    }

    /// Predicts the class of a single sample by walking the tree.
    ///
    /// # Panics
    /// Panics if called before `train`, or if `sample` is shorter than a
    /// feature index stored in the tree.
    pub fn predict(&self, sample: &[f32]) -> u8 {
        let mut node = self
            .root
            .as_ref()
            .expect("DecisionTree::predict called before train()");
        loop {
            match **node {
                Node::Leaf { prediction } => return prediction,
                Node::Decision { feature_index, threshold, ref left, ref right } => {
                    node = if sample[feature_index] < threshold { left } else { right };
                }
            }
        }
    }
}
/// A random forest: an ensemble of decision trees, each trained on a
/// bootstrap sample of the training data, predicting by majority vote.
pub struct RandomForest {
    trees: Vec<DecisionTree>,
}
impl RandomForest {
    /// Trains `n_trees` decision trees on bootstrap samples drawn (with
    /// replacement) from `reader.train_data` / `reader.train_labels`.
    ///
    /// # Panics
    /// Panics if `reader.train_data` is empty.
    pub fn train(reader: &MnistReader, n_trees: usize, max_depth: usize, min_samples_split: usize) -> Self {
        let n_features = reader.train_data[0].len();
        // Common heuristic: consider sqrt(#features) candidate features per split.
        let max_features = (n_features as f32).sqrt() as usize;
        let n_samples = reader.train_data.len();
        let mut trees = Vec::with_capacity(n_trees);
        for i in 0..n_trees {
            println!("+ training tree...{}/{}", i + 1, n_trees);
            // Bootstrap sample: draw n_samples indices with replacement.
            let indices: Vec<usize> = (0..n_samples)
                .map(|_| lazyrand::rand_usize() % n_samples)
                .collect();
            let data: Vec<Vec<f32>> = indices.iter().map(|&i| reader.train_data[i].clone()).collect();
            let labels: Vec<u8> = indices.iter().map(|&i| reader.train_labels[i]).collect();
            let mut tree = DecisionTree::new(max_depth, min_samples_split, max_features);
            tree.train(&data, &labels);
            trees.push(tree);
        }
        RandomForest { trees }
    }

    /// Predicts one sample by majority vote over all trees.
    /// Returns 0 when the forest has no trees.
    pub fn predict(&self, sample: &[f32]) -> u8 {
        let mut votes = HashMap::new();
        for tree in &self.trees {
            *votes.entry(tree.predict(sample)).or_insert(0) += 1;
        }
        votes.into_iter().max_by_key(|&(_, c)| c).map(|(cls, _)| cls).unwrap_or(0)
    }

    /// Computes classification accuracy on `reader`'s test split.
    /// Returns 0.0 for an empty test set (the original divided 0 by 0,
    /// producing NaN).
    pub fn evaluate(&self, reader: &MnistReader) -> f32 {
        let total = reader.test_labels.len();
        if total == 0 {
            return 0.0;
        }
        let correct = reader
            .test_data
            .iter()
            .zip(reader.test_labels.iter())
            .filter(|(sample, &label)| self.predict(sample) == label)
            .count();
        correct as f32 / total as f32
    }
}
/// Computes the Gini impurity of a label set: the sum over classes of
/// p * (1 - p), where p is each class's relative frequency.
/// Returns 0.0 for an empty slice.
fn gini_impurity(labels: &[u8]) -> f32 {
    let total = labels.len() as f32;
    let mut histogram: HashMap<u8, i32> = HashMap::new();
    for &label in labels {
        *histogram.entry(label).or_insert(0) += 1;
    }
    histogram
        .into_values()
        .map(|count| {
            let p = count as f32 / total;
            p * (1.0 - p)
        })
        .sum()
}
/// Returns the most frequent label in `labels`.
///
/// Fixes over the original: an empty slice returns 0 instead of panicking
/// on `unwrap()`, and ties are broken deterministically in favor of the
/// smaller label value (the original's result for ties depended on
/// HashMap iteration order).
fn majority_label(labels: &[u8]) -> u8 {
    let mut counts: HashMap<u8, usize> = HashMap::new();
    for &lbl in labels {
        *counts.entry(lbl).or_insert(0) += 1;
    }
    counts
        .into_iter()
        // Primary key: count. Secondary: Reverse(label) so that among
        // equal counts the smallest label wins.
        .max_by_key(|&(lbl, c)| (c, std::cmp::Reverse(lbl)))
        .map(|(lbl, _)| lbl)
        .unwrap_or(0)
}
/// Partitions samples and their labels by one feature and threshold:
/// samples with `sample[feat] < thr` go to the left pair, all others to
/// the right pair. Returns (left_data, left_labels, right_data, right_labels).
fn split(
    data: &[Vec<f32>],
    labels: &[u8],
    feat: usize,
    thr: f32,
) -> (Vec<Vec<f32>>, Vec<u8>, Vec<Vec<f32>>, Vec<u8>) {
    let mut parts = (Vec::new(), Vec::new(), Vec::new(), Vec::new());
    for (sample, &label) in data.iter().zip(labels) {
        if sample[feat] < thr {
            parts.0.push(sample.clone());
            parts.1.push(label);
        } else {
            parts.2.push(sample.clone());
            parts.3.push(label);
        }
    }
    parts
}
/// Shuffles `indices` in place and returns up to `n` of them.
/// If `n` exceeds the slice length, all elements are returned.
fn choose_multiple(indices: &mut Vec<usize>, n: usize) -> Vec<usize> {
    lazyrand::shuffle(indices);
    let count = n.min(indices.len());
    indices[..count].to_vec()
}
/// Searches for the best (feature, threshold) split over a random subset
/// of `max_features` features, drawing a single random threshold per
/// feature and scoring candidates by weighted Gini impurity.
/// Returns `None` when no candidate produces two non-empty partitions.
///
/// Fix over the original: an empty `data` slice with min_samples_split == 0
/// slipped past the size guard and panicked at `data[0]`.
fn best_split(
    data: &[Vec<f32>],
    labels: &[u8],
    max_features: usize,
    min_samples_split: usize,
) -> Option<(usize, f32)> {
    let n = data.len();
    if n == 0 || n < min_samples_split {
        return None;
    }
    let n_features = data[0].len();
    let mut feats: Vec<usize> = (0..n_features).collect();
    let mut best = None;
    let mut best_gini = f32::INFINITY;
    for feat in choose_multiple(&mut feats, max_features) {
        let vals: Vec<f32> = data.iter().map(|s| s[feat]).collect();
        // NOTE(review): partial_cmp().unwrap() panics on NaN — assumes all
        // feature values are finite (true for MNIST pixel data).
        let &min = vals.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
        let &max = vals.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
        // Constant feature: no threshold can separate the samples.
        if (max - min).abs() < f32::EPSILON {
            continue;
        }
        // One random threshold uniformly in [min, max).
        let thr = (lazyrand::rand_f64() as f32) * (max - min) + min;
        let mut left = Vec::new();
        let mut right = Vec::new();
        for (v, &lbl) in vals.iter().zip(labels.iter()) {
            if *v < thr { left.push(lbl); } else { right.push(lbl); }
        }
        // Degenerate split: every sample landed on one side.
        if left.is_empty() || right.is_empty() {
            continue;
        }
        // Weighted Gini impurity of the two children.
        let gini = (left.len() as f32 / n as f32) * gini_impurity(&left)
            + (right.len() as f32 / n as f32) * gini_impurity(&right);
        if gini < best_gini {
            best_gini = gini;
            best = Some((feat, thr));
        }
    }
    best
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment