Created
December 13, 2014 23:17
-
-
Save omaskery/a69ff5fec665cbcf6aa5 to your computer and use it in GitHub Desktop.
An attempt at a matrix implementation in Rust.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::default::Default; | |
use std::num::Zero; | |
/// A two-dimensional grid of `T` values kept in one flat `Vec`.
///
/// `Matrix::index` maps the coordinate `(r, c)` to `columns * r + c`, so the
/// elements of one row sit next to each other in `values`.
pub struct Matrix<T> { | |
/// Flat element storage, addressed through `Matrix::index`.
values: Vec<T>, | |
/// Row count supplied at construction.
rows: uint, | |
/// Column count supplied at construction; also the stride used by `Matrix::index`.
columns: uint, | |
} | |
/// Cursor that produces bare `(row, column)` coordinate pairs.
///
/// Created by `Matrix::coord_iter`, which starts it at `r = 0`, `c = 0`.
struct BaseMatrixIter { | |
/// Row count handed to `coord_iter`.
rows: uint, | |
/// Column count handed to `coord_iter`.
columns: uint, | |
/// Row of the coordinate that will be yielded next.
r: uint, | |
/// Column of the coordinate that will be yielded next.
c: uint, | |
} | |
/// Read-only cursor over a borrowed matrix; its `Iterator` impl yields
/// `(row, column, &T)` triples.
///
/// Created by `Matrix::iter`, which starts it at `r = 0`, `c = 0`.
struct MatrixIter<'a, T> where T: 'a { | |
/// The matrix being read.
m: &'a Matrix<T>, | |
/// Row of the element that will be yielded next.
r: uint, | |
/// Column of the element that will be yielded next.
c: uint, | |
} | |
/// Mutable cursor over a borrowed matrix; its `Iterator` impl yields
/// `(row, column, &mut T)` triples.
///
/// Created by `Matrix::iter_mut`, which starts it at `r = 0`, `c = 0`.
struct MatrixMutIter<'a, T> where T: 'a { | |
/// The matrix whose elements are handed out mutably.
m: &'a mut Matrix<T>, | |
/// Row of the element that will be yielded next.
r: uint, | |
/// Column of the element that will be yielded next.
c: uint, | |
} | |
impl<T> Matrix<T> where T: Default { | |
pub fn new(r: uint, c: uint) -> Matrix<T> { | |
Matrix::from_fn(r, c, |_,_| Default::default()) | |
} | |
} | |
impl<T> Matrix<T> { | |
pub fn from_fn(rows: uint, columns: uint, f: |uint,uint| -> T) -> Matrix<T> { | |
Matrix { | |
rows: rows, | |
columns: columns, | |
values: Matrix::<T>::coord_iter(rows, columns).map(|(r, c)| f(r, c)).collect(), | |
} | |
} | |
fn coord_iter(rows: uint, columns: uint) -> BaseMatrixIter { | |
BaseMatrixIter { | |
rows: rows, | |
columns: columns, | |
r: 0, | |
c: 0, | |
} | |
} | |
} | |
impl<T> Matrix<T> { | |
pub fn set(&mut self, r: uint, c: uint, v: T) { | |
let index = self.index(r, c); | |
self.values[index] = v; | |
} | |
pub fn get<'a>(&'a self, r: uint, c: uint) -> &'a T { | |
let index = self.index(r, c); | |
&self.values[index] | |
} | |
pub fn access<'a>(&'a mut self, r: uint, c: uint) -> &'a mut T { | |
let index = self.index(r, c); | |
&mut self.values[index] | |
} | |
fn index(&self, r: uint, c: uint) -> uint { | |
self.columns * r + c | |
} | |
fn binop<'a>(&'a self, other: &'a Matrix<T>, f: |&'a T,&'a T| -> T) -> Matrix<T> { | |
if self.rows != other.rows || self.columns != other.columns { | |
panic!("mismatched matrix sizes for binop ({}x{}) versus ({}x{})", | |
self.rows, self.columns, other.rows, other.columns | |
); | |
} | |
Matrix::from_fn(self.rows, self.columns, | |
|r, c| { | |
f(self.get(r, c), other.get(r, c)) | |
} | |
) | |
} | |
fn iter<'a>(&'a self) -> MatrixIter<'a, T> { | |
MatrixIter { | |
m: self, | |
r: 0, | |
c: 0, | |
} | |
} | |
fn iter_mut<'a>(&'a mut self) -> MatrixMutIter<'a, T> { | |
MatrixMutIter { | |
m: self, | |
r: 0, | |
c: 0, | |
} | |
} | |
} | |
impl Iterator<(uint, uint)> for BaseMatrixIter { | |
fn next(&mut self) -> Option<(uint, uint)> { | |
match (self.r, self.c) { | |
(r, c) if r < self.columns => { | |
self.c += 1; | |
if self.c >= self.columns { | |
self.c = 0; | |
self.columns += 1; | |
} | |
Some((r, c)) | |
}, | |
_ => None, | |
} | |
} | |
} | |
impl<'a, T> Iterator<(uint, uint, &'a T)> for MatrixIter<'a, T> where T: Copy { | |
fn next(&mut self) -> Option<(uint, uint, &'a T)> { | |
if self.r < self.m.rows { | |
let result = (self.r, self.c, self.m.get(self.r, self.c)); | |
self.c += 1; | |
if self.c >= self.m.columns { | |
self.c = 0; | |
self.r += 1; | |
} | |
Some(result) | |
} else { | |
None | |
} | |
} | |
} | |
/// Yields `(row, column, &mut element)` for the cell at the cursor, then
/// advances the cursor one column, wrapping to column 0 of the next row when
/// the column reaches `columns`. Returns `None` once `r` reaches `rows`.
impl<'a, T> Iterator<(uint, uint, &'a mut T)> for MatrixMutIter<'a, T> where T: Copy { | |
fn next(&mut self) -> Option<(uint, uint, &'a mut T)> { | |
if self.r < self.m.rows { | |
// NOTE(review): `access` is reached through `&mut self`, so the reference it
// returns looks tied to this call's borrow rather than to `'a` -- confirm the
// borrow checker accepts the `&'a mut T` annotation; if not, the iterator
// probably needs a different internal representation.
let value: &'a mut T = self.m.access(self.r, self.c); | |
let result = (self.r, self.c, value); | |
self.c += 1; | |
// Past the last column: move to the start of the next row.
if self.c >= self.m.columns { | |
self.c = 0; | |
self.r += 1; | |
} | |
Some(result) | |
} else { | |
None | |
} | |
} | |
} | |
impl<T> Add<Matrix<T>, Matrix<T>> for Matrix<T> where T: Add<T, T> { | |
fn add(&self, other: &Matrix<T>) -> Matrix<T> { | |
self.binop(other, |a,b| *a + *b) | |
} | |
} | |
impl<T> Sub<Matrix<T>, Matrix<T>> for Matrix<T> where T: Sub<T, T> { | |
fn sub(&self, other: &Matrix<T>) -> Matrix<T> { | |
self.binop(other, |a,b| *a - *b) | |
} | |
} | |
impl<T> Mul<Matrix<T>, Matrix<T>> for Matrix<T> where T: Mul<T, T> + Zero { | |
fn mul(&self, other: &Matrix<T>) -> Matrix<T> { | |
if self.columns != other.rows { | |
panic!("mismatched matrix sizes for multiply ({}x{}) versus ({}x{})", | |
self.rows, self.columns, other.rows, other.columns | |
); | |
} | |
Matrix::from_fn(self.rows, other.columns, | |
|r, c| { | |
let result: T = Zero::zero(); | |
for i in range(0, self.columns) { | |
result = result + (*self.get(r, c+i) * *self.get(r+i, c)); | |
} | |
result | |
} | |
) | |
} | |
} | |
impl<T> Mul<T, Matrix<T>> for Matrix<T> where T: Mul<T, T> { | |
fn mul(&self, other: &T) -> Matrix<T> { | |
Matrix::from_fn(self.rows, self.columns, |r, c| *self.get(r, c) * *other) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment