/// Playground - noun: a place where people can play
/// I am the very model of a modern Judgement General

//: # Algorithm W

//: In this playground we develop a complete implementation of the classic
//: algorithm W for Hindley-Milner polymorphic type inference in Swift.

//: ## Introduction

//: Type inference is a tricky business, and it is even harder to learn
//: the basics, because most publications are about very advanced topics
//: like rank-N polymorphism, predicative/impredicative type systems,
//: universal and existential types and so on.  Since I learn best by
//: actually developing the solution to a problem, I decided to write a
//: basic tutorial on type inference, implementing one of the most basic
//: type inference algorithms which has nevertheless practical uses as the
//: basis of the type checkers of languages like ML or Haskell.
//:
//: The type inference algorithm studied here is the classic Algorithm W
//: proposed [by Milner](http://web.cs.wpi.edu/~cs4536/c12/milner-type-poly.pdf).
//: For a very readable presentation of this algorithm and possible variations and extensions
//: [read also Heeren et al.](https://pdfs.semanticscholar.org/8983/233b3dff2c5b94efb31235f62bddc22dc899.pdf).
//: Several aspects of this tutorial are also inspired by [Jones](http://web.cecs.pdx.edu/~mpj/thih/thih.pdf).

//: ## Preliminaries

//: We start by defining the abstract syntax for *expressions*
//: (of type `Expression`), *types* (`Type`) and *type schemes*
//: (`Scheme`).
/// The abstract syntax of the term language being type-checked.
indirect enum Expression : CustomStringConvertible {
  case evar(String)
  case lit(Literal)
  case app(Expression, Expression)
  case abs(String, Expression)
  case elet(String, Expression, Expression)
  case fix(String, String, Expression)

  /// A human-readable rendering of the expression.
  var description : String {
    switch self {
    case let .evar(identifier):
      return identifier
    case let .lit(value):
      return value.description
    case let .app(function, argument):
      // Compound arguments are wrapped in parentheses; atomic ones are
      // simply juxtaposed with the function.
      switch argument {
      case .elet, .abs, .app:
        return "\(function)(\(argument))"
      case .evar, .lit, .fix:
        return "\(function) \(argument)"
      }
    case let .abs(parameter, body):
      return "λ \(parameter) -> \(body)"
    case let .elet(binder, bound, body):
      return "let \(binder) = \(bound) in \(body)"
    case let .fix(recursor, parameter, body):
      // A fixpoint prints as `fix f` followed by the lambda it wraps.
      let lambda = Expression.abs(parameter, body)
      return "fix \(recursor) \(lambda)"
    }
  }
}

/// The literal values of the term language.
enum Literal : CustomStringConvertible {
  case int(Int)
  case bool(Bool)

  /// Integers print in decimal; booleans print as `True` / `False`.
  var description : String {
    switch self {
    case let .int(value):
      return value.description
    case let .bool(flag):
      if flag {
        return "True"
      }
      return "False"
    }
  }
}

/// Monomorphic types: type variables, the two base types, and functions.
indirect enum Type : Equatable, CustomStringConvertible {
  case typeVar(String)
  case int
  case bool
  case arrow(Type, Type)

  /// Structural equality: variables compare by name, arrows component-wise.
  static func ==(l : Type, r : Type) -> Bool {
    switch (l, r) {
    case (.int, .int), (.bool, .bool):
      return true
    case let (.typeVar(lhsName), .typeVar(rhsName)):
      return lhsName == rhsName
    case let (.arrow(lhsIn, lhsOut), .arrow(rhsIn, rhsOut)):
      return lhsIn == rhsIn && lhsOut == rhsOut
    default:
      return false
    }
  }

  /// A human-readable rendering of the type.
  var description : String {
    switch self {
    case .int:
      return "Int"
    case .bool:
      return "Bool"
    case let .typeVar(name):
      return name
    case let .arrow(.arrow(innerIn, innerOut), output):
      // A function in argument position is parenthesized.
      return "(\(innerIn) -> \(innerOut)) -> \(output)"
    case let .arrow(input, output):
      return "\(input) -> \(output)"
    }
  }
}

/// A type scheme: a type together with the variables it quantifies over.
struct Scheme : CustomStringConvertible {
  let vars : [String]
  let type : Type

  /// Renders the scheme as `∀ a, b . type`.
  var description : String {
    let quantified = self.vars.joined(separator: ", ")
    return "∀ \(quantified) . \(self.type)"
  }
}

//: We will need to determine the free type variables of a type.
//: The `freeTypeVariables` accessor
//: implements this operation, which appears as a requirement of the `TypeCarrier` protocol because
//: it will also be needed for type environments (to be defined below).  Another useful operation on
//: types, type schemes and the like is that of applying a substitution.
protocol TypeCarrier {
  /// The names of the type variables that occur free in this value.
  var freeTypeVariables : Set<String> { get }
  /// Returns a copy of this value with the given substitution applied.
  func apply(_ : Substitution) -> Self
}

extension Type : TypeCarrier {
  var freeTypeVariables : Set<String> {
    switch self {
    case .int, .bool:
      // Base types mention no variables.
      return []
    case let .typeVar(name):
      return [ name ]
    case let .arrow(input, output):
      return input.freeTypeVariables.union(output.freeTypeVariables)
    }
  }

  func apply(_ s : Substitution) -> Type {
    switch self {
    case .int, .bool:
      // Base types are never rewritten.
      return self
    case let .arrow(input, output):
      // Rewrite both sides of the function type.
      return .arrow(input.apply(s), output.apply(s))
    case let .typeVar(name):
      // Variables without an entry in the substitution stay as they are.
      guard let replacement = s[name] else {
        return self
      }
      // A variable mapped to itself is already a fixpoint.
      guard replacement != self else {
        return replacement
      }
      // Otherwise chase the replacement through the substitution again.
      return replacement.apply(s)
    }
  }
}

extension Scheme : TypeCarrier {
  var freeTypeVariables : Set<String> {
    // Quantified variables are bound, hence not free.
    return self.type.freeTypeVariables.subtracting(self.vars)
  }

  func apply(_ s : Substitution) -> Scheme {
    // Entries for the scheme's own quantified variables must not fire, so
    // drop them before rewriting the underlying type.
    var visible = s
    for bound in self.vars {
      visible.removeValue(forKey: bound)
    }
    return Scheme(vars: self.vars, type: self.type.apply(visible))
  }
}

//: Now we define substitutions, which are finite mappings from type
//: variables to types.
typealias Substitution = Dictionary<String, Type>

//: Type environments, called Γ in the text, are mappings from term
//: variables to their respective type schemes.
/// A type environment mapping term variable names to their type schemes.
struct Γ {
  let unEnv : Dictionary<String, Scheme>

//: We define several functions on type environments.  The operation
//: `Γ \ x` removes the binding for `x` from `Γ`.  Naturally, we call it
//: `remove`.
  /// Returns a copy of this environment without a binding for `v`.
  func remove(_ v : String) -> Γ {
    var remaining = self.unEnv
    remaining.removeValue(forKey: v)
    return Γ(unEnv: remaining)
  }

//: The `generalize` function abstracts a type over all type variables
//: which are free in the type but not free in the given type environment.
  // Γ ⊢ e : σ    α ∉ free(Γ)
  // −−−−−−−−−−−−−−−−−------−
  // Γ ⊢ e : ∀ α . σ
  /// Quantifies `t` over every variable that is free in `t` but not free in
  /// this environment.
  ///
  /// The quantified variables are emitted in sorted order so that the
  /// resulting scheme (and anything printed from it) is deterministic rather
  /// than dependent on `Set` iteration order.
  func generalize(_ t : Type) -> Scheme {
    let quantified = t.freeTypeVariables.subtracting(self.freeTypeVariables)
    return Scheme(vars: quantified.sorted(), type: t)
  }
}

extension Γ : TypeCarrier {
  /// The union of the free type variables of every scheme in the environment.
  var freeTypeVariables : Set<String> {
    return self.unEnv.values.reduce(into: Set<String>()) { (acc, scheme) in
      acc.formUnion(scheme.freeTypeVariables)
    }
  }

  /// Applies `s` to every scheme in the environment, keeping the keys as-is.
  func apply(_ s : Substitution) -> Γ {
    return Γ(unEnv: self.unEnv.mapValues { $0.apply(s) })
  }
}

/// The ways in which type inference can fail.
///
/// Conforms to `CustomStringConvertible` so that `print(error)` and string
/// interpolation produce the readable message instead of the case name.
enum TypeError : Error, CustomStringConvertible {
  /// The two types have incompatible shapes.
  case unificationFailed(Type, Type)
  /// Binding the variable to the type would create an infinite type.
  case occursCheckFailed(String, Type)
  /// The named term variable has no binding in the environment.
  case foundUnbound(String)

  var description : String {
    switch self {
    case let .unificationFailed(t1, t2):
      return "types do not unify: \(t1) vs. \(t2)"
    case let .occursCheckFailed(u, t):
      return "'occurs' check failed: \(u) appears in its own type: \(t)"
    case let .foundUnbound(n):
      return "unbound variable \(n)"
    }
  }
}

//: Several operations, for example type scheme instantiation, require
//: fresh names for newly introduced type variables.  This is implemented
//: by using an appropriate monad which takes care of generating fresh
//: names.  It is also capable of passing a dynamically scoped
//: environment, error handling and performing I/O, but we will not go
//: into details here.
final class Inferencer {
  /// Counter used to mint fresh type variable names.
  var supply : Int = 0

//: The instantiation function replaces all bound type variables in a type
//: scheme with fresh type variables.
//:
//: In more creative type systems, this rule acts as a trapdoor that lets
//: you insert whatever you want "subtype" to mean - note this is not necessarily the
//: subtype in the object-oriented sense, but rather in a type-scheme sense.
//: For example, the type `int -> int` is subtype of `∀ a . a -> a` because we
//: only intend "subtype" to mean "is more/less specialized than".
  // Γ ⊢ e : σ′    σ′ ⊑ σ
  // −−−−−−−−−−−−−−-------
  // Γ ⊢ e : σ
  /// Produces a monomorphic instance of `s` by renaming each quantified
  /// variable to a freshly created type variable.
  func instantiate(_ s : Scheme) throws -> Type {
    var renaming = Substitution()
    for bound in s.vars {
      // One fresh variable per quantifier, minted in declaration order.
      renaming[bound] = self.createTypeVariable(named: "τ")
    }
    return s.type.apply(renaming)
  }

//: This is the unification function for types.
  /// Computes a substitution under which `t1` and `t2` become equal.
  ///
  /// - Throws: `TypeError.unificationFailed` when the two types have
  ///   different shapes, or `TypeError.occursCheckFailed` when a variable
  ///   would have to be bound to a type that mentions it.
  func unify(_ t1 : Type, with t2 : Type) throws -> Substitution {
    switch (t1, t2) {
    case (.int, .int), (.bool, .bool):
      // Identical base types need no bindings.
      return [:]
    case let (.arrow(lhsIn, lhsOut), .arrow(rhsIn, rhsOut)):
      // Solve the inputs first, then the outputs under that solution.
      let inputSubst = try self.unify(lhsIn, with: rhsIn)
      let outputSubst = try self.unify(lhsOut.apply(inputSubst), with: rhsOut.apply(inputSubst))
      return inputSubst.merging(outputSubst) { (_, newer) in newer }
    case let (.typeVar(name), other), let (other, .typeVar(name)):
      // A variable on either side is bound to whatever is on the other side;
      // the left-hand variable wins when both sides are variables.
      return try self.bindTypeVariable(name, to: other)
    default:
      throw TypeError.unificationFailed(t1, t2)
    }
  }

//: The function `bindTypeVariable` attempts to bind a type variable to a type
//: and return that binding as a substitution, but avoids binding a
//: variable to itself and performs the occurs check.
  /// Builds the single-entry substitution `[u: t]`, or the empty one when
  /// `t` is the variable `u` itself.
  ///
  /// - Throws: `TypeError.occursCheckFailed` when `u` occurs free in `t`.
  private func bindTypeVariable(_ u : String, to t : Type) throws -> Substitution {
    guard t != .typeVar(u) else {
      return [:]
    }
    guard !t.freeTypeVariables.contains(u) else {
      throw TypeError.occursCheckFailed(u, t)
    }
    return [u: t]
  }

//: ## Main type inference function
//: The function `inferType` infers the types for expressions.  The type
//: environment must contain bindings for all free variables of the
//: expressions.
//: The returned substitution records the type constraints
//: imposed on type variables by the expression, and the returned type is
//: the type of the expression.

//: Algorithm W takes a "top-down" (context-free) approach to type inference at the
//: expense of failing later and generating more constraints than Algorithm M.  As
//: presented here, Algorithm W will only fail at the boundary of an application expression
//: where function and argument fail to unify.  By the time that occurs, type checking of
//: both will have been performed, and the resulting error will apply to the entirety of the
//: application rather than an erroneous subexpression.
//:
//: Algorithm W was proven sound and complete [by Milner](http://web.cs.wpi.edu/~cs4536/c12/milner-type-poly.pdf)
//: in his original presentation of the type system for which this algorithm was built.
  /// Infers a type for `exp` in the environment `env` using Algorithm W.
  ///
  /// - Parameters:
  ///   - exp: The expression to type.
  ///   - env: Bindings for the free term variables of `exp`.
  /// - Returns: The substitution accumulated while typing `exp`, paired with
  ///   the type found for it.
  /// - Throws: `TypeError.foundUnbound` for a variable missing from `env`,
  ///   plus any error raised by `unify`.
  private func inferTypeW(of exp : Expression, in env : Γ) throws -> (Substitution, Type) {
    switch exp {
    // x : σ ∈ Γ
    // --−−−−−−−
    // Γ ⊢ x : σ
    case let .evar(n):
      // The only thing we can do is lookup in the context to check if
      // it contains an entry for the variable we're interested in.  If it
      // doesn't then the variable must be unbound.
      guard let sigma = env.unEnv[n] else {
        throw TypeError.foundUnbound(n)
      }
      let t = try self.instantiate(sigma)
      return ([:], t)

    // We require no premises to infer types for boolean and integer literals.
    case let .lit(l):
      switch l {
      //
      // --−−−−−−−------
      // Γ ⊢ true : bool
      //
      // --−−−−−−−-------
      // Γ ⊢ false : bool
      case .bool(_):
        return ([:], Type.bool)
      //
      // --−−−−−−−-------
      // Γ ⊢ [0-9]+ : int
      case .int(_):
        return ([:], Type.int)
      }

    // Γ , x : τ ⊢ e : τ′
    // −−−−−−−−−−−−−−------
    // Γ ⊢ λ x . e : τ → τ′
    case let .abs(n, body):
      // Create a new type variable to solve for the abstraction.
      let tv = self.createTypeVariable(named: "τ")
      // Setup a new environment that binds the type of the bound variable to our new type variable.
      var updatedEnv = env.unEnv
      updatedEnv[n] = Scheme(vars: [], type: tv)
      // Infer the type of the body.
      let (bodySubst, bodyTy) = try self.inferTypeW(of: body, in: Γ(unEnv: updatedEnv))
      return (bodySubst, .arrow(tv.apply(bodySubst), bodyTy))

    // Γ ⊢ e0 : τ → τ′    Γ ⊢ e1 : τ
    // −--------−−−−−−−−−−−−−−−−−−−−
    // Γ ⊢ e0(e1) : τ′
    case let .app(e0, e1):
      // Create a new type variable to solve for the application.
      let tv = self.createTypeVariable(named: "τ")
      // Infer the type of the function...
      let (funcSubst, funcTy) = try self.inferTypeW(of: e0, in: env)
      // Then infer the type of the argument.
      let (argSubst, argTy) = try self.inferTypeW(of: e1, in: env.apply(funcSubst))
      // Now, apply function type to argument type to get back whatever substitutions
      // we need to perform to yield the correct output type.
      let subst = try self.unify(funcTy.apply(argSubst), with: .arrow(argTy, tv))
      // Yield the output with those substitutions.  Dump all the work we've done
      // up to this point in the context for good measure, too.
      let combined = subst
        .merging(argSubst, uniquingKeysWith: {$1})
        .merging(funcSubst, uniquingKeysWith: {$1})
      return (combined, tv.apply(subst))

    // Γ ⊢ e0 : σ    Γ , x : σ ⊢ e1 : τ
    // −------------−−−−−−−−−−−−−−−−−−−−
    // Γ ⊢ let x = e0 in e1 : τ
    case let .elet(x, e0, e1):
      // First, infer the type of the body of the binding.
      let (boundSubst, boundTy) = try self.inferTypeW(of: e0, in: env)
      // Update the context with that information.
      var updatedEnv = env.unEnv
      updatedEnv[x] = env.apply(boundSubst).generalize(boundTy)
      // Now infer the type of the body
      let (bodySubst, bodyTy) = try self.inferTypeW(of: e1, in: Γ(unEnv: updatedEnv).apply(boundSubst))
      return (boundSubst.merging(bodySubst, uniquingKeysWith: {$1}), bodyTy)

    // This rule (recursion) is not a part of the original system.  Nonetheless its typing rules
    // are easy enough we may as well support it.
    //
    // Γ, f : τ ⊢ λ x . e : τ
    // ----------------------
    // Γ ⊢ fix f λx . e : τ
    case let .fix(f, n, body):
      // Create a new type variable to solve for the fixpoint.
      let tv = self.createTypeVariable(named: "τ")
      // Setup a new environment that binds the type of the recursor.
      var updatedEnv = env.unEnv
      updatedEnv[f] = Scheme(vars: [], type: tv)
      // Infer the type of the lambda.
      let (lamSubst, lamTy) = try self.inferTypeW(of: .abs(n, body), in: Γ(unEnv: updatedEnv))
      // The recursor and the lambda it names must have the same type.
      let subst = try self.unify(tv.apply(lamSubst), with: lamTy)
      return (subst.merging(lamSubst, uniquingKeysWith: {$1}), lamTy.apply(subst))
    }
  }

//: Algorithm M carries a type constraint from the context of an expression and stops when the
//: expression cannot satisfy the current type constraint.
//:
//: Algorithm M was proven [by Lee and Yi](https://ropas.snu.ac.kr/~kwang/paper/98-toplas-leyi.pdf) to be
//: sound and complete.  It also has the desirable property that it generates less constraints overall than
//: Algorithm W and has better locality of type errors at the expense needing to carry a context through the
//: computation.
  /// Checks `exp` against the expected type `rho` in the environment `env`
  /// using Algorithm M.
  ///
  /// - Parameters:
  ///   - exp: The expression to type.
  ///   - env: Bindings for the free term variables of `exp`.
  ///   - rho: The type the surrounding context expects `exp` to have.
  /// - Returns: The substitution under which `exp` has type `rho`.
  /// - Throws: `TypeError.foundUnbound` for a variable missing from `env`,
  ///   plus any error raised by `unify`.
  private func inferTypeM(of exp : Expression, in env : Γ, against rho : Type) throws -> Substitution {
    switch exp {
    // x : σ ∈ Γ
    // --−−−−−−−
    // Γ ⊢ x : σ
    case let .evar(n):
      // The only thing we can do is lookup in the context to check if
      // it contains an entry for the variable we're interested in.  If it
      // doesn't then the variable must be unbound.
      guard let sigma = env.unEnv[n] else {
        throw TypeError.foundUnbound(n)
      }
      let t = try self.instantiate(sigma)
      return try self.unify(rho, with: t)

    // We require no premises to infer types for boolean and integer literals.
    case let .lit(l):
      switch l {
      //
      // --−−−−−−−------
      // Γ ⊢ true : bool
      //
      // --−−−−−−−-------
      // Γ ⊢ false : bool
      case .bool(_):
        return try self.unify(rho, with: .bool)
      //
      // --−−−−−−−
      // Γ ⊢ [0-9]+ : int
      case .int(_):
        return try self.unify(rho, with: .int)
      }

    // Γ , x : τ ⊢ e : τ′
    // −−−−−−−−−−−−−−------
    // Γ ⊢ λ x . e : τ → τ′
    case let .abs(n, body):
      // Create new type variables to solve for the abstraction.
      let inputTV = self.createTypeVariable(named: "τ")
      let outputTV = self.createTypeVariable(named: "τ")
      // Check that the type we were handed is an arrow.
      let arrowSubst = try self.unify(rho, with: .arrow(inputTV, outputTV))
      // Setup a new environment that binds the type of the bound variable to our new type variable.
      var updatedEnv = env.unEnv
      updatedEnv[n] = Scheme(vars: [], type: inputTV.apply(arrowSubst))
      // Infer the type of the body in the updated context.
      let bodySubst = try self.inferTypeM(of: body,
                                          in: Γ(unEnv: updatedEnv).apply(arrowSubst),
                                          against: outputTV.apply(arrowSubst))
      return bodySubst.merging(arrowSubst, uniquingKeysWith: {$1})

    // Γ ⊢ e0 : τ → τ′    Γ ⊢ e1 : τ
    // −--------−−−−−−−−−−−−−−−−−−−−
    // Γ ⊢ e0(e1) : τ′
    case let .app(e0, e1):
      // Create a new type variable to solve for the application.
      let tv = self.createTypeVariable(named: "τ")
      // The function must map the (as yet unknown) argument type to the expected type...
      let funcSubst = try self.inferTypeM(of: e0, in: env, against: .arrow(tv, rho))
      // ...and the argument must have whatever that argument type turned out to be.
      let argSubst = try self.inferTypeM(of: e1, in: env.apply(funcSubst), against: tv.apply(funcSubst))
      return argSubst.merging(funcSubst, uniquingKeysWith: {$1})

    // Γ ⊢ e0 : σ    Γ , x : σ ⊢ e1 : τ
    // −------------−−−−−−−−−−−−−−−−−−−−
    // Γ ⊢ let x = e0 in e1 : τ
    case let .elet(x, e0, e1):
      let tv = self.createTypeVariable(named: "τ")
      let boundSubst = try self.inferTypeM(of: e0, in: env, against: tv)
      // Insert the bound variable into the context.  Generalization must be
      // performed against the environment *after* applying `boundSubst`, exactly
      // as `inferTypeW` does: otherwise a variable that typing `e0` has tied to
      // an environment entry looks unconstrained and is wrongly quantified.
      var updatedEnv = env.unEnv
      updatedEnv[x] = env.apply(boundSubst).generalize(tv.apply(boundSubst))
      // Infer the type of the body.
      let bodySubst = try self.inferTypeM(of: e1,
                                          in: Γ(unEnv: updatedEnv).apply(boundSubst),
                                          against: rho.apply(boundSubst))
      return bodySubst.merging(boundSubst, uniquingKeysWith: {$1})

    // Γ, f : τ ⊢ λ x . e : τ
    // ----------------------
    // Γ ⊢ fix f λ x . e : τ
    case let .fix(f, n, body):
      // Add the recursor to the environment.
      var updatedEnv = env.unEnv
      updatedEnv[f] = Scheme(vars: [], type: rho)
      // Infer the type of the rest as a lambda term.
      return try self.inferTypeM(of: Expression.abs(n, body), in: Γ(unEnv: updatedEnv), against: rho)
    }
  }

  /// Term variables paired with the type variable assumed for each use.
  typealias Assumptions = [(String, Type)]
  /// A flat list of generated constraints.
  typealias ConstraintGraph = [Constraint]

  /// A constraint produced by `gatherConstraints`.
  enum Constraint : CustomStringConvertible {
    case equivalent(Type, Type)
    case explicitInstance(Type, Scheme)
    case implicitInstance(Type, Set<String>, Type)

    var description : String {
      switch self {
      case let .equivalent(t1, t2):
        return "\(t1) = \(t2)"
      case let .explicitInstance(t, s):
        return "\(t) <~ \(s)"
      case let .implicitInstance(t1, m, t2):
        return "\(t1) <= (\(m.joined(separator: ","))) \(t2)"
      }
    }
  }

//: A cute function that translates the work HM-inference is doing behind the scenes to an explicit
//: constraint-based format.
  /// Collects the assumptions and constraints needed to type `exp`.
  ///
  /// - Parameters:
  ///   - exp: The expression to analyse.
  ///   - m: The names of the type variables treated as monomorphic at this point.
  /// - Returns: The outstanding assumptions about free term variables, the
  ///   constraints generated so far, and the type assigned to `exp`.
  func gatherConstraints(for exp : Expression, _ m : Set<String>) -> (Assumptions, ConstraintGraph, Type) {
    switch exp {
    case let .evar(n):
      // If we find a variable, spawn a type variable for it and add it to the list
      // of assumptions needed to solve this system.
      let tv = self.createTypeVariable(named: "τ")
      return ([(n, tv)], [], tv)
    case .lit(.int(_)):
      // Literals spawn type variables that are solved immediately by generated constraints.
      let tv = self.createTypeVariable(named: "τ")
      return ([], [.equivalent(tv, .int)], tv)
    case .lit(.bool(_)):
      // Literals spawn type variables that are solved immediately by generated constraints.
      let tv = self.createTypeVariable(named: "τ")
      return ([], [.equivalent(tv, .bool)], tv)
    case let .app(e0, e1):
      // For an application, gather the constraints necessary to check the function.
      let (funcAssump, funcConstr, funcTy) = self.gatherConstraints(for: e0, m)
      // Then gather the constraints needed for the argument.
      let (argAssump, argConstr, argTy) = self.gatherConstraints(for: e1, m)
      // Now spawn a type variable for the application itself and take the union of all
      // generated constraints and assumptions plus a constraint on the output type.
      let b = self.createTypeVariable(named: "τ")
      let application : ConstraintGraph = [Constraint.equivalent(funcTy, .arrow(argTy, b))]
      return (funcAssump + argAssump, funcConstr + argConstr + application, b)
    case let .abs(x, body):
      // Spawn a type variable for the abstraction.
      guard case let .typeVar(vn) = self.createTypeVariable(named: "τ") else {
        fatalError("Wat?")
      }
      // Gather the constraints for the body of the lambda; the parameter's type is monomorphic there.
      let (bodyAssump, bodyConstr, bodyTy) = self.gatherConstraints(for: body, m.union([vn]))
      let b = Type.typeVar(vn)
      // Discharge every assumption about the bound variable by equating it with the parameter type.
      let outerAssump : Assumptions = bodyAssump.filter { $0.0 != x }
      let boundUses : ConstraintGraph = bodyAssump
        .filter { $0.0 == x }
        .map { Constraint.equivalent($0.1, b) }
      return (outerAssump, bodyConstr + boundUses, .arrow(b, bodyTy))
    case let .elet(x, e0, e1):
      // Gather the constraints for the binding.
      let (bindAssump, bindConstr, bindTy) = self.gatherConstraints(for: e0, m)
      // Then gather them for the body.
      let (bodyAssump, bodyConstr, bodyTy) = self.gatherConstraints(for: e1, m.subtracting([x]))
      // Every use of the newly bound variable must be an instance of the binding's type.
      let outerAssump : Assumptions = bodyAssump.filter { $0.0 != x }
      let boundUses : ConstraintGraph = bodyAssump
        .filter { $0.0 == x }
        .map { Constraint.implicitInstance($0.1, m, bindTy) }
      return (bindAssump + outerAssump, bindConstr + bodyConstr + boundUses, bodyTy)
    case let .fix(f, n, body):
      // Create a new type variable to solve for the fixpoint.
      guard case let .typeVar(vn) = self.createTypeVariable(named: "τ") else {
        fatalError("Wat?")
      }
      // Gather the constraints for the body of the fixpoint.
      let (bodyAssump, bodyConstr, bodyTy) = self.gatherConstraints(for: .abs(n, body), m.union([vn]))
      let b = Type.typeVar(vn)
      // Every recursive use refers to the recursor's type...
      let outerAssump : Assumptions = bodyAssump.filter { $0.0 != f }
      let recursiveUses : ConstraintGraph = bodyAssump
        .filter { $0.0 == f }
        .map { Constraint.equivalent($0.1, b) }
      // ...and the recursor's type is the type of the lambda itself (Γ, f : τ ⊢ λ x . e : τ).
      // Without this constraint the recursor variable is left unrelated to the function it names.
      let knot : ConstraintGraph = [Constraint.equivalent(b, bodyTy)]
      return (outerAssump, bodyConstr + recursiveUses + knot, bodyTy)
    }
  }

  /// Mints a fresh type variable named `prefix` followed by the next value
  /// of the `supply` counter.
  private func createTypeVariable(named prefix : String) -> Type {
    self.supply += 1
    return .typeVar(prefix + self.supply.description)
  }

  /// Selects which inference algorithm `inferType` runs: `.topDown` runs
  /// `inferTypeW`, `.bottomUp` runs `inferTypeM`.
  enum Direction {
    case topDown
    case bottomUp
  }

//: This is the main entry point to the type inferencer.  It simply calls
//: `inferType` and applies the returned substitution to the returned type.
  /// Infers the type of the closed expression `e` in an empty environment.
  ///
  /// - Parameters:
  ///   - e: The expression to type.
  ///   - direction: Which algorithm to run; defaults to `.topDown`.
  /// - Returns: The inferred type with the final substitution applied.
  /// - Throws: Any `TypeError` raised by the selected algorithm.
  func inferType(of e : Expression, direction : Direction = .topDown) throws -> Type {
    switch direction {
    case .topDown:
      let (subst, type) = try self.inferTypeW(of: e, in: Γ(unEnv: [:]))
      return type.apply(subst)
    case .bottomUp:
      // Algorithm M needs an expected type to start from; a fresh variable imposes nothing.
      let tv = self.createTypeVariable(named: "τ")
      let subst = try self.inferTypeM(of: e, in: Γ(unEnv: [:]), against: tv)
      return tv.apply(subst)
    }
  }
}

//: ## Tests
//: The following simple expressions (partly taken from [Heeren](https://pdfs.semanticscholar.org/8983/233b3dff2c5b94efb31235f62bddc22dc899.pdf))
//: are provided for testing the type inference function.
let e0 : Expression = .elet("id", .abs("x", .evar("x")), .evar("id"))
let e1 : Expression = .elet("id", .abs("x", .evar("x")), .app(.evar("id"), .evar("id")))
let e2 : Expression = .elet("id", .abs("x", .elet("y", .evar("x"), .evar("y"))), .app(.evar("id"), .evar("id")))
let e3 : Expression = .elet("id", .abs("x", .elet("y", .evar("x"), .evar("y"))), .app(.app(.evar("id"), .evar("id")), .lit(.int(2))))
let e4 : Expression = .elet("id", .abs("x", .app(.evar("x"), .evar("x"))), .evar("id"))
let e5 : Expression = .abs("m", .elet("y", .evar("m"), .elet("x", .app(.evar("y"), .lit(.bool(true))), .evar("y"))))
let e6 : Expression = .fix("id", "x", .app(.evar("id"), .evar("x")))

//: This simple set of tests tries to infer the type for the given
//: expression.
//: If successful, it prints the expression together with its
//: type, otherwise, it prints the error message.
for exp in [e0, e1, e2, e3, e4, e5, e6] {
  do {
    print(exp)

    // Each run gets its own inferencer so variable numbering restarts.
    let typeFromW = try Inferencer().inferType(of: exp, direction: .topDown)
    print("Algorithm W: ", typeFromW)

    // Only the constraint list of the gathered triple is shown.
    let (_, constraints, _) = Inferencer().gatherConstraints(for: exp, Set())
    print(constraints)

    let typeFromM = try Inferencer().inferType(of: exp, direction: .bottomUp)
    print("Algorithm M: ", typeFromM)

    print("---------")
  } catch let e as TypeError {
    // The first failure ends the report for this expression.
    print(e.description)
    print("---------")
  }
}