Last active
May 16, 2023 20:45
-
-
Save FredTheDino/7c348f75b2f761532f72549be2ac07bd to your computer and use it in GitHub Desktop.
A simple Hindley-Milner type checker implemented for the lambda calculus
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
/// Syntax tree of the untyped lambda calculus extended with a unit value.
///
/// Names are `&'static str`. Every bound name is expected to be unique in a
/// program: bindings are inserted into shared maps and never removed.
#[derive(Debug, Clone)]
enum Ast {
    /// The unit literal `()`.
    Unit,
    /// A reference to a variable (all names are assumed to be unique).
    Var(&'static str),
    /// A lambda abstraction: parameter name, then body.
    Fun(&'static str, Box<Ast>),
    /// An application, stored as `(argument, function)`: the argument comes first.
    Call(Box<Ast>, Box<Ast>),
}
fn unit() -> Ast { | |
Ast::Unit | |
} | |
fn var(v: &'static str) -> Ast { | |
Ast::Var(v) | |
} | |
fn fun(x: &'static str, body: Ast) -> Ast { | |
Ast::Fun(x, Box::new(body)) | |
} | |
fn call(x: Ast, fun: Ast) -> Ast { | |
Ast::Call(Box::new(x), Box::new(fun)) | |
} | |
fn interpret(ast: Ast, vars: &mut HashMap<&'static str, Ast>) -> Option<Ast> { | |
match ast { | |
Ast::Unit => Some(Ast::Unit), | |
Ast::Var(var) => vars.get(var).cloned(), | |
Ast::Fun(var, body) => Some(Ast::Fun(var, body)), | |
Ast::Call(var, fun) => { | |
if let Ast::Fun(arg, body) = interpret(*fun, vars)? { | |
let var = interpret(*var, vars)?; | |
vars.insert(arg, var); | |
interpret(*body, vars) | |
} else { | |
panic!("CANNOT CALL NON FUNCTION!"); | |
} | |
} | |
} | |
} | |
/// A type as seen by the inference engine.
#[derive(Debug, Clone)]
enum Type {
    /// An unconstrained slot in the context's type table.
    Unknown,
    /// A reference to slot `id` in the context's type table.
    Node(usize),
    /// The type of `()`.
    Unit,
    /// A function type: argument type, then return type.
    Fun(Box<Type>, Box<Type>),
}
#[derive(Clone, Debug)] | |
struct Ctx { | |
tys: Vec<Type>, | |
names: HashMap<&'static str, usize>, | |
} | |
impl Ctx { | |
fn new() -> Self { | |
Ctx { | |
tys: Vec::new(), | |
names: HashMap::new(), | |
} | |
} | |
fn generic_for_var(&mut self, var: &'static str) -> Type { | |
match self.names.entry(var) { | |
std::collections::hash_map::Entry::Occupied(x) => Type::Node(*x.get()), | |
std::collections::hash_map::Entry::Vacant(n) => { | |
let id = self.tys.len(); | |
self.tys.push(Type::Unknown); | |
Type::Node(id) | |
} | |
} | |
} | |
fn new_generic(&mut self) -> Type { | |
let id = self.tys.len(); | |
self.tys.push(Type::Unknown); | |
Type::Node(id) | |
} | |
fn replace(&mut self, a: usize, other: Type) -> Result<usize, &'static str> { | |
if let Type::Node(aa) = self.tys[a] { | |
let inner_a = self.replace(aa, other)?; | |
self.tys[a] = Type::Node(inner_a); | |
Ok(inner_a) | |
} else { | |
let ty_a = self.tys[a].clone(); | |
let ty_imp = unify(self, ty_a, other)?; | |
self.tys[a] = ty_imp; | |
Ok(a) | |
} | |
} | |
} | |
fn expr( | |
ctx: &mut Ctx, | |
ast: Ast, | |
vars: &mut HashMap<&'static str, Type>, | |
) -> Result<Type, &'static str> { | |
match ast { | |
Ast::Unit => Ok(Type::Unit), | |
Ast::Var(var) => { | |
if let Some(ty) = vars.get(var) { | |
Ok(ty.clone()) | |
} else { | |
Err("Unknown variable!") | |
} | |
} | |
Ast::Fun(var, body) => { | |
let ty = ctx.generic_for_var(var); | |
vars.insert(var, ty.clone()); | |
Ok(Type::Fun(Box::new(ty), Box::new(expr(ctx, *body, vars)?))) | |
} | |
Ast::Call(arg, fun) => { | |
let fun_ty = expr(ctx, *fun, vars)?; | |
let ret_ty = ctx.new_generic(); | |
let infered_ty = Type::Fun(Box::new(expr(ctx, *arg, vars)?), Box::new(ret_ty.clone())); | |
unify(ctx, fun_ty, infered_ty)?; | |
Ok(ret_ty) | |
} | |
} | |
} | |
fn unify(ctx: &mut Ctx, a: Type, b: Type) -> Result<Type, &'static str> { | |
match (a, b) { | |
(Type::Unknown, guess) | (guess, Type::Unknown) => Ok(guess), | |
(Type::Node(a), other) | (other, Type::Node(a)) => { | |
Ok(Type::Node(ctx.replace(a, other)?)) | |
} | |
(Type::Unit, Type::Unit) => Ok(Type::Unit), | |
(Type::Fun(aa, ab), Type::Fun(ba, bb)) => { | |
let a = unify(ctx, *aa, *ba)?; | |
let b = unify(ctx, *ab, *bb)?; | |
Ok(Type::Fun(Box::new(a), Box::new(b))) | |
} | |
(a, b) => { | |
println!(":( {:?} - {:?}", a, b); | |
Err("Failed to unify!") | |
} | |
} | |
} | |
fn id(v: &'static str) -> Ast { | |
fun(v, var(v)) | |
} | |
fn main() { | |
let program = call(id("y"), id("x")); | |
println!("PROGRAM: {:?}", program.clone()); | |
let mut ctx =Ctx::new(); | |
let typ = expr(&mut ctx, program.clone(), &mut HashMap::new()); | |
println!("TYPE: {:?}", typ); | |
println!("CTX: {:?}", ctx); | |
println!("EVAL: {:?}", interpret(program, &mut HashMap::new())); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment