Skip to content

Instantly share code, notes, and snippets.

@JosephCatrambone
Created August 11, 2024 00:07
Show Gist options
  • Save JosephCatrambone/04c18138c560e6588cfc76ba2eee597f to your computer and use it in GitHub Desktop.
Save JosephCatrambone/04c18138c560e6588cfc76ba2eee597f to your computer and use it in GitHub Desktop.
RTree - Single-File n-Dimensional Spatial Index Tree in Rust
/*
RTree.rs
Author: Joseph Catrambone <me at josephcatrambone.com>
License: MIT / GPL at User's Discretion
Description: A simple high-level n-dimensional tree that's made to work with any indexable data type.
*/
use std::mem;
use std::ops::Index;
const SPLIT_POINT: usize = 16;
const BIG_NUMBER: f32 = 1e32;
/*
Quick plot for average search time by split point:
Search Dist: 0.01
Num Points: 100,000
List Search Time: 164641ns
|split_point|avg insert time|avg search time in nanoseconds|
|-----------|---|---|
|4|9827|22|
|8|8872|22|
|16|8191|89|
|32|7516|309|
|64|7064|1253|
|128|6463|3545|
|256|5703|6944|
|512|5042|13594|
|1024|4518|27402|
|2048|3928|48262|
Search Dist: None
Num Points: 100,000
List Search Time: 182076ns
|split_point|avg insert time|avg search time in nanoseconds|
|-----------|---|---|
|4|9858|1681|
|8|8925|1444|
|16|8078|795|
|32|7539|1217|
|64|6861|2820|
*/
//trait Pointlike: Clone + Index<usize, Output=f32> + IntoIterator<Item=f32> {}
/// A simple n-dimensional bounding-box tree used for nearest-point queries.
///
/// A node is either a leaf (it owns `points` and has no `children`) or an
/// interior node (it owns `children` and its `points` vector is empty).
///
/// The type itself carries no trait bounds; the bounds a point type must
/// satisfy are stated on the `impl` block that provides the operations.
pub struct RTree<T> {
    /// Number of points stored in this node's subtree (not the number of child nodes).
    num_children: usize,
    /// Per-axis lower corner of this node's bounding box.
    lower_bounds: Vec<f32>,
    /// Per-axis upper corner of this node's bounding box.
    upper_bounds: Vec<f32>,
    /// Child nodes. Exists only at non-leaf nodes.
    children: Vec<RTree<T>>,
    /// Stored points. Exists only at leaf nodes.
    points: Vec<T>,
    /// Number of coordinates per point.
    dimensions: usize,
    /// Distance metric, called as `(stored_point, query_point, dimensions)`.
    distance_fn: fn(&T, &T, usize) -> f32,
}
/// Straight-line (L2) distance between `a` and `b`, using only the first
/// `dims` coordinates of each.
///
/// Both arguments are read through `Index<usize>`, so indexing past the end of
/// either one behaves however that type's indexing does (a panic for `Vec`).
fn euclidean_distance<T>(a: &T, b: &T, dims: usize) -> f32
where
    T: Index<usize, Output = f32>,
{
    // Fold from an explicit 0.0 so the squared gaps are accumulated axis by axis, in order.
    let sum_of_squares = (0..dims).fold(0.0f32, |total, axis| {
        let gap = a[axis] - b[axis];
        total + gap * gap
    });
    sum_of_squares.sqrt()
}
/// Splits `points` into two non-empty groups along the axis with the largest variance.
///
/// Points whose coordinate on the chosen axis is below that axis's mean go to
/// the first group; everything else goes to the second. Relative order within
/// each group follows the input order.
///
/// The number of axes examined is taken from the first point (by iterating a
/// clone of it), which is why `T: Clone` is required here.
///
/// If the mean-based partition would leave one side empty (for example when
/// every point is identical, so no axis has any variance), the points are
/// split in half by count instead, so both groups are always non-empty.
///
/// # Panics
///
/// Panics if fewer than two points are supplied, since two non-empty groups
/// cannot be formed.
fn _split_points<T>(points: Vec<T>) -> (Vec<T>, Vec<T>)
where
    T: Clone + Index<usize, Output = f32> + IntoIterator<Item = f32>,
{
    assert!(
        points.len() >= 2,
        "_split_points needs at least two points to form two non-empty groups"
    );

    // Iterating consumes a point, so count the coordinates on a clone of the first one.
    let dims = points[0].clone().into_iter().count();
    let count = points.len() as f32;

    // Pick the axis with the most variance.
    // TODO: Maybe consider another split criterion.
    let mut best_split_axis = 0;
    let mut best_split_variance = 0.0f32;
    let mut best_axis_mean = 0.0f32;
    for axis in 0..dims {
        let mut mean = 0.0f32;
        for p in points.iter() {
            mean += p[axis];
        }
        mean /= count;

        // Sum of squared deviations; the shared 1/n factor is irrelevant for comparing axes.
        let mut variance = 0.0f32;
        for p in points.iter() {
            let delta = p[axis] - mean;
            variance += delta * delta;
        }

        if variance > best_split_variance {
            best_split_axis = axis;
            best_split_variance = variance;
            best_axis_mean = mean;
        }
    }

    // Now we do the actual split.
    let mut lower = Vec::with_capacity(points.len());
    let mut upper = Vec::with_capacity(points.len());
    for p in points {
        if p[best_split_axis] < best_axis_mean {
            lower.push(p);
        } else {
            upper.push(p);
        }
    }

    // Degenerate input (no variance on any axis, or rounding put the mean outside
    // the data range): fall back to an even split by count rather than panicking.
    if lower.is_empty() || upper.is_empty() {
        lower.append(&mut upper);
        let second_half = lower.split_off(lower.len() / 2);
        return (lower, second_half);
    }

    (lower, upper)
}
impl<T> RTree<T>
where
    T: Clone + Index<usize, Output = f32> + IntoIterator<Item = f32>,
{
    /// Creates a single-leaf tree containing only `origin`.
    ///
    /// The dimensionality of the tree is the number of values `origin` yields
    /// when iterated, and the initial bounding box is the degenerate box at
    /// `origin`. Distances are measured with `euclidean_distance`.
    pub fn new(origin: T) -> Self {
        // Iterating consumes the point, so read the coordinates from a clone.
        let coords: Vec<f32> = origin.clone().into_iter().collect();
        let dims = coords.len();
        RTree {
            num_children: 1,
            lower_bounds: coords.clone(),
            upper_bounds: coords,
            dimensions: dims,
            children: vec![],
            points: vec![origin],
            distance_fn: euclidean_distance,
        }
    }

    /// Returns the number of levels below this node: 0 for a leaf, otherwise
    /// one more than the deepest child.
    pub fn get_max_depth(&self) -> usize {
        self.children
            .iter()
            .map(|child| child.get_max_depth() + 1)
            .max()
            .unwrap_or(0)
    }

    /// Returns `true` when no coordinate of `point` lies strictly outside this
    /// node's bounding box (the box is closed: points on its faces count as inside).
    pub fn in_bounds(&self, point: &T) -> bool {
        (0..self.dimensions).all(|axis| {
            let value = point[axis];
            !(value < self.lower_bounds[axis] || value > self.upper_bounds[axis])
        })
    }

    /// Euclidean distance from `point` to this node's bounding box.
    ///
    /// An axis contributes only by how far the coordinate lies outside the
    /// `[lower, upper]` interval on that axis, so a point inside (or on) the
    /// box has distance 0. No stored point can therefore be closer to `point`
    /// than this value, which is what makes it safe to use for pruning.
    pub fn distance_to_bounds(&self, point: &T) -> f32 {
        let mut sum_of_squares: f32 = 0.0;
        for axis in 0..self.dimensions {
            let value = point[axis];
            let below_by = self.lower_bounds[axis] - value;
            let above_by = value - self.upper_bounds[axis];
            // At most one of the two is positive; inside the interval both are <= 0.
            let gap = below_by.max(above_by).max(0.0);
            sum_of_squares += gap * gap;
        }
        sum_of_squares.sqrt()
    }

    /// Finds the stored point nearest to `point`, together with its distance.
    ///
    /// Only points strictly closer than `max_distance` are considered; with
    /// `None` the limit is `BIG_NUMBER`. Returns `None` when no stored point
    /// is within the limit.
    pub fn find_nearest_with_distance(
        &self,
        point: &T,
        max_distance: Option<f32>,
    ) -> Option<(&T, f32)> {
        let mut min_distance_found = max_distance.unwrap_or(BIG_NUMBER);

        // Nothing in this subtree can be closer than the box itself, so skip the
        // whole node when even the box is out of range. (This relies on
        // `distance_fn` being the Euclidean metric that `new` installs.)
        if self.distance_to_bounds(point) > min_distance_found {
            return None;
        }

        let mut nearest_point_and_distance: Option<(&T, f32)> = None;
        if self.children.is_empty() {
            // Leaf: linear scan over the stored points.
            for candidate in &self.points {
                let dist = (self.distance_fn)(candidate, point, self.dimensions);
                if dist < min_distance_found {
                    min_distance_found = dist;
                    nearest_point_and_distance = Some((candidate, dist));
                }
            }
        } else {
            // Interior node: search each child, tightening the limit as better
            // candidates turn up so later children can be pruned earlier.
            for child in &self.children {
                if let Some((candidate, dist)) =
                    child.find_nearest_with_distance(point, Some(min_distance_found))
                {
                    if dist < min_distance_found {
                        min_distance_found = dist;
                        nearest_point_and_distance = Some((candidate, dist));
                    }
                }
            }
        }
        nearest_point_and_distance
    }

    /// Like [`RTree::find_nearest_with_distance`], but returns only the point.
    pub fn find_nearest(&self, point: &T, max_distance: Option<f32>) -> Option<&T> {
        self.find_nearest_with_distance(point, max_distance)
            .map(|(nearest, _)| nearest)
    }

    /// Inserts `point` into the tree.
    ///
    /// A leaf stores the point directly and, once it holds more than
    /// `SPLIT_POINT` points, hands them to `_split_points` and becomes an
    /// interior node with two leaf children. An interior node forwards the
    /// point to the child whose bounding box is nearest to it.
    pub fn insert(&mut self, point: T) {
        // Every node on the insertion path must grow its box to cover the new
        // point; otherwise searches would prune subtrees that contain it.
        for axis in 0..self.dimensions {
            let value = point[axis];
            self.lower_bounds[axis] = self.lower_bounds[axis].min(value);
            self.upper_bounds[axis] = self.upper_bounds[axis].max(value);
        }
        self.num_children += 1;

        if self.children.is_empty() {
            // Plain insert (leaf):
            self.points.push(point);
            if self.points.len() > SPLIT_POINT {
                let pts = mem::take(&mut self.points);
                let (mut left, mut right) = _split_points(pts);
                // Each half seeds a new leaf with one of its points; the rest follow.
                let mut new_left =
                    RTree::new(left.pop().expect("_split_points returned an empty lower half"));
                let mut new_right =
                    RTree::new(right.pop().expect("_split_points returned an empty upper half"));
                for p in left {
                    new_left.insert(p);
                }
                for p in right {
                    new_right.insert(p);
                }
                self.children.push(new_left);
                self.children.push(new_right);
            }
        } else {
            // Plain insert (non-leaf): pick the child whose box is closest.
            // TODO: A heuristic with the best child to give the point.
            // We want to balance the sizes and also the number of points in each.
            let mut child_idx = 0;
            let mut nearest_bounds_distance = BIG_NUMBER;
            for (idx, child) in self.children.iter().enumerate() {
                let dist = child.distance_to_bounds(&point);
                if dist < nearest_bounds_distance {
                    child_idx = idx;
                    nearest_bounds_distance = dist;
                }
            }
            self.children[child_idx].insert(point);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};
    use super::*;
    use rand::{Rng, thread_rng};
    use rand::prelude::ThreadRng;

    type Point = Vec<f32>;

    /// Builds a point with `dims` coordinates drawn from `low..=high`.
    fn make_point(rng: &mut ThreadRng, low: f32, high: f32, dims: usize) -> Point {
        (0..dims).map(|_| rng.gen_range(low..=high)).collect()
    }

    /// A `Vec<f32>` point type works and the closer of two points is returned.
    #[test]
    fn sanity_vec() {
        let a = vec![0.0, 0.0, 0.0];
        let mut tree = RTree::new(a);
        let b = vec![1.0, 2.0, 3.0];
        tree.insert(b);
        let p = tree.find_nearest(&vec![0.0, 0.1, 0.0], Some(1.0));
        assert!(p.is_some());
        assert_eq!(*p.unwrap(), vec![0.0, 0.0, 0.0]);
    }

    /// A fixed-size array point type works, and an unbounded search finds the only point.
    #[test]
    fn sanity_slice() {
        let a = [0.0, 0.0, 0.0f32];
        let tree = RTree::new(a);
        let b = [1.0, 0.0, 0.0];
        let maybe_a = tree.find_nearest(&b, None);
        assert_eq!(maybe_a, Some(&a));
    }

    /// Timing comparison between a linear scan and the tree over the same points.
    #[test]
    fn stress() {
        const MAX_DIST: Option<f32> = Some(0.01);
        const DIMS: usize = 3;
        const NUM_POINTS: u32 = 100_000;
        const NUM_SEARCHES: u32 = 1000;

        let mut rng = thread_rng();

        // Seed the list and the tree with the same first point so both hold identical data.
        let first = make_point(&mut rng, -1e5, 1e5, DIMS);
        let mut big_list = vec![first.clone()];
        let mut tree = RTree::new(first);

        let mut insert_time = Duration::new(0, 0);
        for _ in 0..NUM_POINTS {
            let p = make_point(&mut rng, -1e5, 1e5, DIMS);
            big_list.push(p.clone());
            let start_insert = Instant::now();
            tree.insert(p);
            insert_time += start_insert.elapsed();
        }
        println!(
            "Average insert time: {} nanoseconds",
            insert_time.as_nanos() / u128::from(NUM_POINTS)
        );

        let mut list_search_time = Duration::new(0, 0);
        let mut tree_search_time = Duration::new(0, 0);
        for _ in 0..NUM_SEARCHES {
            let p = make_point(&mut rng, -1e5, 1e5, DIMS);

            // Linear list search:
            let start_search = Instant::now();
            let mut nearest_dist = f32::INFINITY;
            for list_p in &big_list {
                let dist = euclidean_distance(&p, list_p, DIMS);
                if dist < nearest_dist {
                    nearest_dist = dist;
                }
            }
            list_search_time += start_search.elapsed();

            // Tree search:
            let start_search = Instant::now();
            let found = tree.find_nearest_with_distance(&p, MAX_DIST);
            tree_search_time += start_search.elapsed();

            // Anything the tree reports is a stored point, so it can never beat the
            // brute-force minimum, and it must respect the search radius.
            if let (Some((_, tree_dist)), Some(limit)) = (found, MAX_DIST) {
                assert!(tree_dist >= nearest_dist);
                assert!(tree_dist <= limit);
            }
        }
        // Search totals are averaged over the number of searches, not the number of points.
        println!(
            "Average list search time: {} nanoseconds",
            list_search_time.as_nanos() / u128::from(NUM_SEARCHES)
        );
        println!(
            "Average tree search time: {} nanoseconds",
            tree_search_time.as_nanos() / u128::from(NUM_SEARCHES)
        );
        println!("Max tree depth: {}", tree.get_max_depth());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment