Created
April 22, 2025 23:50
-
-
Save kujirahand/b3c6a5b40310fb88964116fbe2665d9d to your computer and use it in GitHub Desktop.
Rustでランダムフォレストを実装しよう --- MNISTの判定
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap;

use lazyrand;
use mnist_reader::MnistReader;
/// A node of a binary decision tree.
///
/// `Debug`, `Clone` and `PartialEq` are derived so a fitted tree can be
/// printed, duplicated and compared (for example in tests).
#[derive(Debug, Clone, PartialEq)]
enum Node {
    /// Terminal node carrying the class label that is returned for any
    /// sample that reaches it.
    Leaf { prediction: u8 },
    /// Internal node that routes a sample to one of two children by
    /// comparing a single feature against a threshold.
    Decision {
        /// Index into the sample's feature vector.
        feature_index: usize,
        /// Samples whose feature value is strictly below this go `left`;
        /// all others go `right`.
        threshold: f32,
        /// Subtree for samples with `feature < threshold`.
        left: Box<Node>,
        /// Subtree for the remaining samples.
        right: Box<Node>,
    },
}
/// 単純な決定木クラス | |
pub struct DecisionTree { | |
max_depth: usize, | |
min_samples_split: usize, | |
max_features: usize, | |
root: Option<Box<Node>>, | |
} | |
impl DecisionTree { | |
/// 新しい決定木を作成 | |
pub fn new(max_depth: usize, min_samples_split: usize, max_features: usize) -> Self { | |
DecisionTree { max_depth, min_samples_split, max_features, root: None } | |
} | |
/// 学習 | |
pub fn train(&mut self, data: &[Vec<f32>], labels: &[u8]) { | |
self.root = Some(Box::new(self.build_tree(data, labels, 0))); | |
} | |
/// 再帰的に決定木を構築 | |
fn build_tree(&self, data: &[Vec<f32>], labels: &[u8], depth: usize) -> Node { | |
// 終了条件:深さ or サンプル数 | |
if depth >= self.max_depth || labels.len() < self.min_samples_split { | |
return Node::Leaf { prediction: majority_label(labels) }; | |
} | |
// 全て同じクラスなら分割不要 | |
if labels.iter().all(|&x| x == labels[0]) { | |
return Node::Leaf { prediction: labels[0] }; | |
} | |
// 最良分割を探す | |
if let Some((feat, thr)) = best_split(data, labels, self.max_features, self.min_samples_split) { | |
let (left_data, left_labels, right_data, right_labels) = split(data, labels, feat, thr); | |
let left_node = self.build_tree(&left_data, &left_labels, depth + 1); | |
let right_node = self.build_tree(&right_data, &right_labels, depth + 1); | |
Node::Decision { | |
feature_index: feat, | |
threshold: thr, | |
left: Box::new(left_node), | |
right: Box::new(right_node), | |
} | |
} else { | |
Node::Leaf { prediction: majority_label(labels) } | |
} | |
} | |
/// 1サンプルを予測 | |
pub fn predict(&self, sample: &[f32]) -> u8 { | |
let mut node = self.root.as_ref().unwrap(); | |
loop { | |
match **node { | |
Node::Leaf { prediction } => return prediction, | |
Node::Decision { feature_index, threshold, ref left, ref right } => { | |
if sample[feature_index] < threshold { node = left; } else { node = right; } | |
} | |
} | |
} | |
} | |
} | |
/// ランダムフォレスト本体 | |
pub struct RandomForest { | |
trees: Vec<DecisionTree>, | |
} | |
impl RandomForest { | |
/// readerのtrain_dataで学習し、テストデータで評価できるようにします | |
pub fn train(reader: &MnistReader, n_trees: usize, max_depth: usize, min_samples_split: usize) -> Self { | |
let n_features = reader.train_data[0].len(); | |
// sqrt(特徴量数)を使うのが一般的 | |
let max_features = (n_features as f32).sqrt() as usize; | |
let mut trees = Vec::with_capacity(n_trees); | |
for i in 0..n_trees { | |
println!("+ training tree...{}/{}", i+1, n_trees); | |
// ブートストラップサンプル | |
let indices: Vec<usize> = (0..reader.train_data.len()) | |
.map(|_| lazyrand::rand_usize() % reader.train_data.len()) | |
.collect(); | |
let data: Vec<Vec<f32>> = indices.iter().map(|&i| reader.train_data[i].clone()).collect(); | |
let labels: Vec<u8> = indices.iter().map(|&i| reader.train_labels[i]).collect(); | |
let mut tree = DecisionTree::new(max_depth, min_samples_split, max_features); | |
tree.train(&data, &labels); | |
trees.push(tree); | |
} | |
RandomForest { trees } | |
} | |
/// 1サンプルを予測 | |
pub fn predict(&self, sample: &[f32]) -> u8 { | |
let mut votes = HashMap::new(); | |
for tree in &self.trees { | |
let pred = tree.predict(sample); | |
*votes.entry(pred).or_insert(0) += 1; | |
} | |
votes.into_iter().max_by_key(|&(_, c)| c).map(|(cls, _)| cls).unwrap_or(0) | |
} | |
/// テストデータで精度を計算 | |
pub fn evaluate(&self, reader: &MnistReader) -> f32 { | |
let mut correct = 0; | |
for (sample, &label) in reader.test_data.iter().zip(reader.test_labels.iter()) { | |
if self.predict(sample) == label { correct += 1; } | |
} | |
correct as f32 / reader.test_labels.len() as f32 | |
} | |
} | |
/// Computes the Gini impurity of a set of labels: the sum over classes of
/// `p * (1 - p)`, where `p` is the fraction of labels in that class.
///
/// Returns `0.0` for an empty slice and for a slice with a single class.
/// Classes are accumulated in ascending label order, so the floating-point
/// result is reproducible from run to run.
fn gini_impurity(labels: &[u8]) -> f32 {
    if labels.is_empty() {
        return 0.0;
    }

    // One slot per possible `u8` label; no allocation or hashing needed.
    let mut counts = [0usize; 256];
    for &label in labels {
        counts[usize::from(label)] += 1;
    }

    let total = labels.len() as f32;
    counts
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let p = count as f32 / total;
            p * (1.0 - p)
        })
        .sum()
}
/// Returns the most frequent label in `labels`.
///
/// When several labels share the highest count, the smallest of them is
/// returned, so the result does not depend on hash-map iteration order.
/// An empty slice yields `0`.
fn majority_label(labels: &[u8]) -> u8 {
    // One slot per possible `u8` label.
    let mut counts = [0usize; 256];
    for &label in labels {
        counts[usize::from(label)] += 1;
    }

    let mut winner = 0u8;
    let mut winner_count = 0usize;
    for (label, &count) in (0..=u8::MAX).zip(counts.iter()) {
        // Strict `>` keeps the first (smallest) label among equal counts.
        if count > winner_count {
            winner_count = count;
            winner = label;
        }
    }
    winner
}
/// Partitions samples and their labels around a threshold on one feature.
///
/// Samples with `sample[feat] < thr` go to the left pair, all others to the
/// right pair. Returns `(left_data, left_labels, right_data, right_labels)`
/// with the input order preserved inside each side.
fn split(
    data: &[Vec<f32>],
    labels: &[u8],
    feat: usize,
    thr: f32,
) -> (Vec<Vec<f32>>, Vec<u8>, Vec<Vec<f32>>, Vec<u8>) {
    let mut below: (Vec<Vec<f32>>, Vec<u8>) = (Vec::new(), Vec::new());
    let mut at_or_above: (Vec<Vec<f32>>, Vec<u8>) = (Vec::new(), Vec::new());

    for (row, label) in data.iter().zip(labels.iter().copied()) {
        // Pick the destination bucket once, then append row and label together.
        let bucket = if row[feat] < thr {
            &mut below
        } else {
            &mut at_or_above
        };
        bucket.0.push(row.clone());
        bucket.1.push(label);
    }

    (below.0, below.1, at_or_above.0, at_or_above.1)
}
fn choose_multiple(indices: &mut Vec<usize>, n: usize) -> Vec<usize> { | |
lazyrand::shuffle(indices); | |
indices.iter().take(n).cloned().collect() | |
} | |
/// 最良分割をランダムに探索 | |
fn best_split( | |
data: &[Vec<f32>], | |
labels: &[u8], | |
max_features: usize, | |
min_samples_split: usize, | |
) -> Option<(usize, f32)> { | |
let n = data.len(); | |
if n < min_samples_split { return None; } | |
let n_features = data[0].len(); | |
let mut feats: Vec<usize> = (0..n_features).collect(); | |
let mut best = None; | |
let mut best_gini = f32::MAX; | |
for feat in choose_multiple(&mut feats, max_features) { | |
let vals: Vec<f32> = data.iter().map(|s| s[feat]).collect(); | |
let &min = vals.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); | |
let &max = vals.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); | |
if (max - min).abs() < f32::EPSILON { continue; } | |
let thr = (lazyrand::rand_f64() as f32) * (max - min) + min; | |
let mut left = Vec::new(); | |
let mut right = Vec::new(); | |
for (v, &lbl) in vals.iter().zip(labels.iter()) { | |
if *v < thr { left.push(lbl); } else { right.push(lbl); } | |
} | |
if left.is_empty() || right.is_empty() { continue; } | |
let gini = (left.len() as f32 / n as f32) * gini_impurity(&left) | |
+ (right.len() as f32 / n as f32) * gini_impurity(&right); | |
if gini < best_gini { | |
best_gini = gini; | |
best = Some((feat, thr)); | |
} | |
} | |
best | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment