Skip to content

Instantly share code, notes, and snippets.

@Arrow7000
Last active March 18, 2025 10:28
Show Gist options
  • Save Arrow7000/97f5880818c6bfb12a16269278839374 to your computer and use it in GitHub Desktop.
Save Arrow7000/97f5880818c6bfb12a16269278839374 to your computer and use it in GitHub Desktop.
Very simple type universe and zonking implementation
import Std.Data.HashMap
import Std.Data.HashSet
import Mathlib
open Std
/-- Order-independent hash for `HashSet`.

The element hashes are combined with wrapping `UInt64` addition, which is commutative and
associative, so the result does not depend on the order in which the set enumerates its
elements. Hashing `HashSet.toArray` instead ties the hash to that enumeration order, which is
not guaranteed to agree for two sets holding the same elements (it can depend on insertion
history). That is unsafe when sets are used as `HashMap` keys, as `DisjointSetMap` does. -/
instance {k} [BEq k] [Hashable k] : Hashable (HashSet k) where
  hash s := s.fold (init := (0 : UInt64)) fun acc x => acc + hash x
/-- Gets the items that are in both `a` and `b`.

Folds over `a`, starting from the empty set, and inserts every element that `b` contains.
NOTE: `intersection_mem_both` rewrites with this definition, so its proof depends on this exact shape. -/
def Std.HashSet.intersection {k : Type u} [BEq k] [Hashable k] (a : HashSet k) (b : HashSet k) : HashSet k :=
a.fold (init := ∅) fun s x => if b.contains x then s.insert x else s
/-- Soundness of `HashSet.intersection`: every member of `a.intersection b` is a member of both `a` and `b`.

NOTE(review): this is a proposition but is declared with `def`; `theorem` would be the conventional keyword. -/
def Std.HashSet.intersection_mem_both {k : Type u} [BEq k] [Hashable k] [EquivBEq k] [LawfulHashable k] (a : HashSet k) (b : HashSet k) : ∀ v ∈ (a.intersection b), v ∈ a ∧ v ∈ b := by
rw [intersection]
-- Generalise to a `List.foldl` over an arbitrary list `a` with an arbitrary starting accumulator `q`:
-- anything in the result was either already in `q`, or is an element of the list that is also in `b`.
suffices ∀ a : List k, ∀ q : HashSet k, ∀ (v : k), v ∈ List.foldl (fun s x => if b.contains x = true then s.insert x else s) q a → v ∈ q ∨ a.contains v ∧ v ∈ b by
-- Instantiate with `a.toList` and the empty accumulator to recover the original statement.
simpa [HashSet.fold_eq_foldl_toList, HashSet.mem_iff_contains] using this a.toList ∅
intro a q
induction a generalizing q with
| nil => simp
| cons hd tl ih =>
simp
intro v hv
-- Induction hypothesis on the tail, with the accumulator obtained after processing `hd`.
obtain (h|h) := ih _ _ hv
-- Case: `v` is in the accumulator after `hd`; split on whether `hd` was inserted.
· split at h
· next ht =>
rw [HashSet.mem_insert] at h
obtain (h|h) := h
-- `v` is (BEq-)equal to `hd`, which `b` contains, so `v` is in the list and in `b`.
· refine Or.inr ⟨Or.inl (BEq.symm h), ?_⟩
rwa [HashSet.mem_iff_contains, ← HashSet.contains_congr h]
-- `v` was already in `q`.
· exact Or.inl h
· exact Or.inl h
-- Case: `v` comes from the tail and is in `b`.
· exact Or.inr ⟨Or.inr h.1, h.2⟩
/-- Set difference: keeps exactly the elements of `a` that `b` does not contain. -/
def Std.HashSet.diff {k : Type u} [BEq k] [Hashable k] (a : HashSet k) (b : HashSet k) : HashSet k :=
  a.filter fun x => !b.contains x
/--
A disjoint-set map where:
- sub-keys are of type `k`
- full keys are sets of sub-keys, i.e. `HashSet k`
- each full key is associated with one value of type `v` (stored in `innerMap`)
- merging key sets combines their values with `mergeFn`, which receives the two values being merged followed by the new, merged key set: `v → v → (keysNew : HashSet k) → v`
- lookups by a single sub-key go through `outerMap`, a map computed from `innerMap` that sends every sub-key to its full key set together with the value stored for that set
- `find?` / `find` take a sub-key and return the value associated with the full key it is a member of

NOTE(review): nothing in this structure enforces that the key sets in `innerMap` are pairwise disjoint; that presumably has to be maintained by only building maps through `empty`, `addSet` and `union` — confirm.
-/
structure DisjointSetMap (k : Type u) [BEq k] [Hashable k] (v : Type u) where
/-- This maps from _sets_ of `k`s to `v`s -/
innerMap : HashMap (HashSet k) v
/-- How to merge values together when key sets are merged: takes the two values to combine and the merged key set they will be stored under. -/
mergeFn : v → v → (keysNew : HashSet k) → v
namespace DisjointSetMap
variable {k : Type u} [instBeq : BEq k] [instHash : Hashable k] {v : Type u}
/-- Flattens a map keyed by key *sets* into a map keyed by individual keys: every member of a
key set is mapped to the pair of that whole key set and the value stored for it. -/
private def getOuterMapFromInnerMap {k v : Type u} [BEq k] [Hashable k] (innerMap : HashMap (HashSet k) v) : HashMap k (HashSet k × v) :=
  innerMap.fold (init := HashMap.empty) fun outer keySet value =>
    -- Register every member of this key set against the (key set, value) pair.
    keySet.fold (init := outer) fun acc key => acc.insert key (keySet, value)
/-- Per-key view of the map, computed from `innerMap` each time it is called. It sends every
individual key to the full key set it belongs to, paired with the value stored for that set
(`HashSet k × v`), so that:
a) an entry can be looked up from a single key via `d.outerMap[key]`
b) the lookup yields both the full key set and its value
-/
def outerMap (d : DisjointSetMap k v) : HashMap k (HashSet k × v) :=
  d.innerMap |> getOuterMapFromInnerMap
/-- The map containing no key sets; `mergeFn` is stored and used to combine values whenever key sets are merged later. -/
def empty (mergeFn : v → v → (keysNew : HashSet k) → v) : DisjointSetMap k v :=
  ⟨HashMap.empty, mergeFn⟩
/-- Adds the key set `newKeySet` with value `val` to the map.

Every existing key set that overlaps the new one is absorbed into a single merged key set, and
the absorbed values are folded into `val` with `d.mergeFn` (which also receives the merged key
set). If nothing overlaps, `newKeySet ↦ val` is simply inserted. Key sets that are not absorbed
stay in the map unchanged. -/
def addSet (d : DisjointSetMap k v) (newKeySet : HashSet k) (val : v) : DisjointSetMap k v :=
  let overlappingSets :=
    d.innerMap.fold (init := (newKeySet, [])) fun acc currSet value =>
      let intersection := acc.1.intersection currSet
      if intersection == ∅ then
        -- The no overlap case, so there's nothing to add here
        acc
      else
        -- If there is some overlap, we snowball the current set into the accumulated set, and
        -- include its value in the list of values to merge later
        let merged := acc.1.union currSet
        (merged, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps, just insert the new set with its value
    { d with innerMap := d.innerMap.insert newKeySet val }
  | (mergedSet, valuesToMerge) =>
    -- Merge all overlapping sets and their values
    let mergedValue :=
      valuesToMerge.foldl (init := val)
        (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove the old sets. We start from the *existing* map (starting from an empty map would
    -- silently drop every unrelated entry) and erase each key set that shares an element with
    -- `mergedSet`. Every absorbed set is a non-empty subset of `mergedSet`, so this removes all
    -- of them, including ones that were only pulled in transitively.
    let newInnerMap :=
      d.innerMap.fold (init := d.innerMap) fun remaining currSet _ =>
        if currSet.any mergedSet.contains then
          remaining.erase currSet
        else
          remaining
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }
/-- This merges multiple sets without adding a new value. If none of the keys are in the map then this is a no-op because we have no value to set for it!

Every existing key set that shares an element with `keysToMerge` is absorbed into one merged key
set (which also contains all of `keysToMerge`), and the absorbed values are combined with
`d.mergeFn`. Key sets that are not absorbed stay in the map unchanged. -/
def union (d : DisjointSetMap k v) (keysToMerge : HashSet k) : DisjointSetMap k v :=
  let overlappingSets :=
    d.innerMap.fold (init := (keysToMerge, [])) fun acc currSet value =>
      let intersection := acc.1.intersection currSet
      if intersection == ∅ then
        -- The no overlap case, so there's nothing to add here
        acc
      else
        -- If there is some overlap, we snowball the current set into the accumulated set, and
        -- include its value in the list of values to merge later
        let merged := acc.1.union currSet
        (merged, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps, none of the keys are in the map so we do nothing because we have no value to set for it
    d
  | (mergedSet, firstVal :: restValsToMerge) =>
    -- Merge all overlapping sets and their values
    let mergedValue :=
      restValsToMerge.foldl (init := firstVal)
        (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove the old sets. We start from the *existing* map (starting from an empty map would
    -- silently drop every unrelated entry) and erase each key set that shares an element with
    -- `mergedSet`. Every absorbed set is a non-empty subset of `mergedSet`, so this removes all
    -- of them, including ones that were only pulled in transitively.
    let newInnerMap :=
      d.innerMap.fold (init := d.innerMap) fun remaining currSet _ =>
        if currSet.any mergedSet.contains then
          remaining.erase currSet
        else
          remaining
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }
/-- Looks up the value stored for the key set that `key` belongs to; `none` when `key` is not in `d.outerMap`. -/
def find? (d : DisjointSetMap k v) (key : k) : Option v :=
  match d.outerMap[key]? with
  | some (_, value) => some value
  | none => none
/-- Total variant of `find?`: given a proof `h` that `key` is in `d.outerMap`, returns the value stored for the key set that `key` belongs to. -/
def find (d : DisjointSetMap k v) (key : k) (h : key ∈ d.outerMap) : v :=
  (d.outerMap[key]'h).snd
end DisjointSetMap
import Std.Data.HashMap
import DisjointSetMap -- make sure this points to the right namespace for the module!
open Std
open DisjointSetMap
/-- A type clash between two types that are not compatible. Records the two clashing types; `t` is the representation of types being compared. -/
structure Incompatible (t : Type u) where
/-- One of the two clashing types. -/
t1 : t
/-- The other clashing type. -/
t2 : t
/-- The primitive types; these are the leaves wrapped by `SimpleType.primitive` and `SimpleZonkedType.primitive`. -/
inductive Primitive where
| unit
| string
| int
| bool
| nat
deriving DecidableEq
/-- A type variable, identified by a numeric index. Zonking mints these starting at index `0` and incrementing. -/
structure TypeVar where
/-- The number identifying this type variable. -/
index : Nat
deriving DecidableEq
/-- A unification variable, identified by a numeric index. It derives `Hashable` and `DecidableEq` because it is used as the sub-key type of the `DisjointSetMap` behind `SimpleTypeMap'`. -/
structure UnificationVar where
/-- The number identifying this unification variable. -/
index : Nat
deriving Hashable, DecidableEq
/-- Short alias for `UnificationVar`. -/
abbrev UniVar := UnificationVar
/-- A type that does away with the frills, and is purpose built from the ground up for how we do inference: it only represents primitives and type constructors, and anything beyond that is a univar that references an `Option SimpleType` in the univars map (`SimpleTypeMap'`). -/
inductive SimpleType where
/-- A primitive type. -/
| primitive : Primitive → SimpleType
/-- A type name with the type variables: the arguments are univars, not nested `SimpleType`s. -/
| typeCtor : String → List UniVar → SimpleType
/-- `uv ∈ t` for a `SimpleType`: never holds for a primitive; for a type constructor it holds when `uv` occurs in the argument list (checked with `List.contains`). -/
instance instMembershipUniVarSimpleType : Membership UniVar SimpleType where
mem container item :=
match container with
| .primitive _ => False
| .typeCtor _ uniVars => uniVars.contains item
/-- No univar is ever a member of a `Primitive`: membership is constantly `False`. -/
instance instMembershipUniVarPrimitive : Membership UniVar Primitive where
mem _ _ := False
/-- `uv ∈ d` for a `DisjointSetMap` keyed by univars: holds when `uv` is a key of the computed `d.outerMap`. -/
instance {a} : Membership UniVar (DisjointSetMap UniVar a) where
mem container item :=
container.outerMap.contains item
/-- `uv ∈ o` for an `Option SimpleType`: defers to membership in the contained type for `some t`, and never holds for `none`. -/
instance instMembershipUniVarOptionSimpleType : Membership UniVar (Option SimpleType) where
mem container item :=
match container with
| .some t => item ∈ t
| .none => False
/-- `uv ∈ (keySet, value)` for an `outerMap` entry: only the value component is inspected; the key set (first component) is ignored. -/
instance instMembershipUniVarHashSetUniVarOptionSimpleType : Membership UniVar (HashSet UniVar × Option SimpleType) where
mem container item := instMembershipUniVarOptionSimpleType.mem container.2 item
/-- The raw univars map: a `DisjointSetMap` from sets of univars to an optional `SimpleType`, with no validity guarantees attached (see `SimpleTypeMap` for the validated version). -/
abbrev SimpleTypeMap' := DisjointSetMap UniVar (Option SimpleType)
/-- The inductive predicate that a uniVar is a valid key in a `SimpleTypeMap'`, that every univar occurring in the value stored for it is itself a valid key in the map, and that there are no cycles of univars looping round to point to each other! (Being an inductive predicate, a derivation must be finite, so a univar on a cycle has no derivation.) -/
inductive IsValidUniVarKey (m : SimpleTypeMap') : UniVar → Prop where
| mk (uniVar : UniVar)
-- `uniVar` is a key of `m.outerMap`.
(hmem : uniVar ∈ m)
-- Every univar mentioned by the entry stored at `uniVar` is valid too.
(h : ∀ (uv : UniVar), uv ∈ m.outerMap[uniVar]'hmem → IsValidUniVarKey m uv) : IsValidUniVarKey m uniVar
/-- A valid univar key is a member of the map (projects the `hmem` argument of `IsValidUniVarKey.mk`). -/
theorem IsValidUniVarKey.mem (h : IsValidUniVarKey m u) : u ∈ m :=
match h with
| .mk _ mem _ => mem
/-- Every univar occurring in the entry stored at a valid key is itself a valid key (projects the `h` argument of `IsValidUniVarKey.mk`). -/
theorem IsValidUniVarKey.children {m u} (h : IsValidUniVarKey m u) : ∀ (uv : UniVar), uv ∈ m.outerMap[u]'h.mem → IsValidUniVarKey m uv :=
match h with
| .mk _ _ children => children
/-- A univars map bundled with the proof that every key of its `outerMap` satisfies `IsValidUniVarKey`. -/
def SimpleTypeMap := { map : SimpleTypeMap' // ∀ uv ∈ map.outerMap, IsValidUniVarKey map uv }
/-- If `u` is a valid univar key of `m`, then `(m, u)` is accessible for the relation used by the
`WellFoundedRelation (SimpleTypeMap × UniVar)` instance: `x` is below `y` when they share the same
map and `x.2` occurs in the entry stored at `y.2`. -/
theorem accessible (m : SimpleTypeMap) (u : UniVar) (h : IsValidUniVarKey m.1 u) :
Acc (fun x y => x.1 = y.1 ∧ ∃ h : y.2 ∈ y.1.1, x.2 ∈ y.1.1.outerMap[y.2]'h) (m, u) := by
-- Induct on the validity derivation; `ih` yields accessibility for each child univar.
induction h with | mk u mem h ih =>
constructor
intro (m', u')
-- Anything below `(m, u)` has the same map and a univar occurring in `u`'s entry.
rintro ⟨rfl, mem, ht⟩
dsimp only at h ih mem ht
exact ih u' ht
/-- Well-founded order on `(map, univar)` pairs, used as the termination measure of `zonkSingleUv`:
`x` is below `y` when both carry the same map and `x`'s univar occurs in the entry stored at `y`'s univar. -/
instance : WellFoundedRelation (SimpleTypeMap × UniVar) where
-- x "<" y
rel x y := x.1 = y.1 ∧ ∃ h : y.2 ∈ y.1.1, x.2 ∈ y.1.1.outerMap[y.2]'h
wf := by
constructor
intro a
constructor
intro (m, u) ⟨h₁, h₂, ht⟩
-- `a.2` is a key of `a.1`, so the subtype property of `a.1` makes it a valid key.
have h := a.1.2 a.2 h₂
apply accessible
cases h₁
-- Children of a valid key are valid, which is what `accessible` needs for `u`.
exact h.children u ht
/-- Simp lemma: membership of a univar in a `typeCtor` is the same proposition as membership in its argument list. -/
@[simp] theorem SimpleType.uniVar_in_typeCtor {uv : UniVar} {ctor : String} {uniVars : List UniVar} : (uv ∈ SimpleType.typeCtor ctor uniVars) = (uv ∈ uniVars) := by
simp [instMembershipUniVarSimpleType]
/-- Returns the implicit membership proof `h : key ∈ map`. The explicit equation argument is ignored by the proof itself; presumably it is there so that `h` can be inferred from a hypothesis of the form `map[key].2 = someVal` at the call site — TODO confirm. -/
theorem HashMap.getElemPrfFrom2 {k v1 v2 : Type u} {someVal : v2} [DecidableEq k] [Hashable k] {key : k} {map : HashMap k (v1 × v2)} {h : key ∈ map} (_ : map[key].2 = someVal) : key ∈ map := by
exact h
/- ZONKING -/
/-- The result of zonking a `SimpleType`: no univars remain. Type constructor arguments are fully nested zonked types, and unresolved univars are replaced by type variables. -/
inductive SimpleZonkedType where
/-- A primitive type. -/
| primitive : Primitive → SimpleZonkedType
/-- A type variable. Might be slightly tricky to make sure we keep typevars that should be the same skolem the same, whilst ensuring that typevars that can be different are different.
The current idea is that every zonking step returns the next available typevar index; when we learn that a univar can be generalised, we replace it with that next available typevar, then create the following typevar by incrementing the index and bubble it up to the callers.
-/
| typeVar : TypeVar → SimpleZonkedType
/-- A type name applied to already-zonked argument types. -/
| typeCtor : String → List SimpleZonkedType → SimpleZonkedType
/-- The output of zonking: a map from each univar to its zonked type. -/
abbrev SimpleZonkedTypeMap := HashMap UniVar SimpleZonkedType
/-- Lists the keys of `map`, each bundled with a proof that it is a member of `map`. -/
def HashMap.attachKeys {k v : Type u} [DecidableEq k] [Hashable k] (map : HashMap k v) : List {key : k // key ∈ map} :=
  map.keys.attach.map fun entry =>
    -- `entry.property` says the key is in `map.keys`; `mem_keys` turns that into membership in `map`.
    ⟨entry.val, HashMap.mem_keys.mp entry.property⟩
/-- We go through the map one univar key at a time, and zonk it until it can no longer be zonked. By the end we have a map from every univar key to its fully zonked type, i.e. a `SimpleZonkedTypeMap`.

The type-variable counter starts at index `0` and is threaded through the fold, so each call to `zonkSingleUv` continues from where the previous one stopped.

NOTE(review): every univar key is zonked independently and nothing is memoised, so an unresolved (`none`) univar appears to receive a fresh type variable each time it is reached — including when it is reached again through a type constructor argument, or when two univars share the same key set. Confirm whether that is intended. -/
def zonkSimpleTypeMap (uniVarsMap : SimpleTypeMap) : SimpleZonkedTypeMap :=
let firstTypeVar : TypeVar := .mk 0
-- All univar keys, each paired with the proof that it is in the outer map.
let allUniVars := uniVarsMap.1.outerMap |> HashMap.attachKeys
allUniVars.foldl (init := (HashMap.empty, firstTypeVar)) (fun (acc, accTypeVar) ⟨uv, uvInMap⟩ =>
-- The subtype property of `SimpleTypeMap` turns key membership into validity.
let isValidUniVarKey := uniVarsMap.2 uv uvInMap
let zonked := zonkSingleUv uniVarsMap uv accTypeVar isValidUniVarKey
(acc.insert uv zonked.1, zonked.2)
)
-- Keep the map, drop the final type-variable counter.
|>.1
where
-- Zonks a single univar. Returns the zonked type together with the next unused type variable.
zonkSingleUv (uniVarsMap : SimpleTypeMap) (uv : UniVar) (nextTypeVar : TypeVar) (uh : IsValidUniVarKey uniVarsMap.1 uv) :
SimpleZonkedType × TypeVar :=
have hmem := uh.mem
match hh : uniVarsMap.1.outerMap[uv]'hmem |>.2 with
| none =>
-- Unresolved univar: use `nextTypeVar` for it and advance the counter.
(SimpleZonkedType.typeVar nextTypeVar, { nextTypeVar with index := nextTypeVar.index + 1 })
| some someVal =>
-- NOTE(review): `hh'` is not referenced anywhere below.
match hh' : someVal with
| .primitive p => (SimpleZonkedType.primitive p, nextTypeVar)
| .typeCtor ctor uniVars => Id.run do
-- Zonk each argument in order, threading the counter; results are consed on, hence the `reverse` at the end.
let mut newList := []
let mut nextTypeVar := nextTypeVar
for hh'' : thisUv in uniVars do
let (zonked, next) :=
by
-- The recursive call needs a proof that `thisUv` is a valid key: it is a child of `uv`.
refine zonkSingleUv uniVarsMap thisUv nextTypeVar ?_
-- NOTE(review): `rewrittenHh` is not referenced by the remaining steps.
have rewrittenHh := HashMap.getElemPrfFrom2 hh
refine uh.children thisUv ?_
simp [instMembershipUniVarHashSetUniVarOptionSimpleType]
rw [hh]
simp [instMembershipUniVarOptionSimpleType]
exact hh''
newList := zonked :: newList
nextTypeVar := next
return (SimpleZonkedType.typeCtor ctor newList.reverse, nextTypeVar)
-- Terminates by the `WellFoundedRelation (SimpleTypeMap × UniVar)` instance.
termination_by (uniVarsMap, uv)
decreasing_by
-- The map component is unchanged, so only the second conjunct remains: `thisUv` occurs in the entry stored at `uv`.
rw [true_and]
exists hmem
simp [instMembershipUniVarHashSetUniVarOptionSimpleType]
rw [hh]
simp [instMembershipUniVarOptionSimpleType]
exact hh''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment