Last active
March 18, 2025 10:28
-
-
Save Arrow7000/97f5880818c6bfb12a16269278839374 to your computer and use it in GitHub Desktop.
Very simple type universe and zonking implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Std.Data.HashMap | |
import Std.Data.HashSet | |
import Mathlib | |
open Std | |
/-- Order-independent hash for hash sets.

Each element hash is scrambled with `mixHash` and the results are combined with wrapping `UInt64`
addition. Addition is commutative and associative, so two sets holding the same elements get the
same hash regardless of the order in which the fold visits them. Hashing `HashSet.toArray` instead
would make the result depend on the internal iteration order of the set. -/
instance {k} [BEq k] [Hashable k] : Hashable (HashSet k) where
  hash s := s.fold (init := (0 : UInt64)) fun acc x => acc + mixHash 7 (hash x)
/-- The elements of `a` that are also contained in `b`: folds over `a`, starting from the empty set, and inserts each element `x` for which `b.contains x` holds.
`intersection_mem_both` begins with `rw [intersection]`, so its proof depends on this exact definition. -/
def Std.HashSet.intersection {k : Type u} [BEq k] [Hashable k] (a : HashSet k) (b : HashSet k) : HashSet k := | |
a.fold (init := ∅) fun s x => if b.contains x then s.insert x else s | |
/-- Soundness of `intersection`: every `v` that is a member of `a.intersection b` is a member of both `a` and `b`. -/
def Std.HashSet.intersection_mem_both {k : Type u} [BEq k] [Hashable k] [EquivBEq k] [LawfulHashable k] (a : HashSet k) (b : HashSet k) : ∀ v ∈ (a.intersection b), v ∈ a ∧ v ∈ b := by | |
rw [intersection] | |
-- Generalise to a statement about `List.foldl` over an arbitrary list `a` and accumulator `q`;
-- the original goal is recovered by instantiating it with `a.toList` and `∅`.
suffices ∀ a : List k, ∀ q : HashSet k, ∀ (v : k), v ∈ List.foldl (fun s x => if b.contains x = true then s.insert x else s) q a → v ∈ q ∨ a.contains v ∧ v ∈ b by | |
simpa [HashSet.fold_eq_foldl_toList, HashSet.mem_iff_contains] using this a.toList ∅ | |
intro a q | |
-- Induction on the list, with the accumulator generalised.
induction a generalizing q with | |
| nil => simp | |
| cons hd tl ih => | |
simp | |
intro v hv | |
-- Apply the induction hypothesis to the tail: `v` is either in the updated accumulator or comes from the tail.
obtain (h|h) := ih _ _ hv | |
-- Case: `v` is in the accumulator after processing `hd`; split on whether `hd` was inserted.
· split at h | |
· next ht => | |
rw [HashSet.mem_insert] at h | |
obtain (h|h) := h | |
· refine Or.inr ⟨Or.inl (BEq.symm h), ?_⟩ | |
rwa [HashSet.mem_iff_contains, ← HashSet.contains_congr h] | |
· exact Or.inl h | |
· exact Or.inl h | |
-- Case: `v` comes from the tail, so it is also contained in `hd :: tl`.
· exact Or.inr ⟨Or.inr h.1, h.2⟩ | |
/-- Set difference: keeps exactly those elements of `a` for which `b.contains` is `false`. -/
def Std.HashSet.diff {k : Type u} [BEq k] [Hashable k] (a : HashSet k) (b : HashSet k) : HashSet k :=
  a.filter fun x => !b.contains x
/--
A disjoint-set map where:
- sub-keys are of type `k`
- keys are sets of sub-keys, of type `HashSet k`
- values are of type `v`, stored once per key set
- `union` combines overlapping key sets into a single key set, and the values of the combined sets are merged with the stored `mergeFn`
- `find?` and `find` take a sub-key and return the value associated with the full key set it is a member of
- sub-key lookups go through `outerMap`, a map from each sub-key to its full key set and value that is computed from `innerMap`

Design notes kept from the original description: values were described as dependent on their keys (`Set k → v`), with operations `findKeys`, `findValueByKey` and `findValue`.
NOTE(review): the structure as declared here stores plain `v` values, and `mergeFn` receives the new key set as its last argument – confirm which description is intended.
-/ | |
structure DisjointSetMap (k : Type u) [BEq k] [Hashable k] (v : Type u) where | |
/-- This maps from _sets_ of `k`s to `v`s -/ | |
innerMap : HashMap (HashSet k) v | |
/-- How to merge two values together when key sets are merged; the last argument is the new, merged key set -/ | |
mergeFn : v → v → (keysNew : HashSet k) → v | |
namespace DisjointSetMap | |
variable {k : Type u} [instBeq : BEq k] [instHash : Hashable k] {v : Type u} | |
/-- Flattens a map keyed by whole key sets into a map keyed by the individual members:
every `key` of every `keySet` in `innerMap` is sent to `(keySet, value)`, where `value` is the
value `innerMap` stores for `keySet`. -/
private def getOuterMapFromInnerMap {k v : Type u} [BEq k] [Hashable k] (innerMap : HashMap (HashSet k) v) : HashMap k (HashSet k × v) :=
  innerMap.fold (init := (∅ : HashMap k (HashSet k × v))) fun acc keySet value =>
    let entry := (keySet, value)
    keySet.fold (init := acc) fun acc' key => acc'.insert key entry
/-- The per-key view of a `DisjointSetMap`, computed from `innerMap`.
Looking up a single key, e.g. `d.outerMap[key]?`, yields a pair `HashSet k × v`:
a) the full set of keys that the single key belongs to, and
b) the value stored for that set.
-/
def outerMap (d : DisjointSetMap k v) : HashMap k (HashSet k × v) :=
  d.innerMap |> getOuterMapFromInnerMap
/-- A `DisjointSetMap` with no key sets; `mergeFn` is stored as the map's value-merging function. -/
def empty (mergeFn : v → v → (keysNew : HashSet k) → v) : DisjointSetMap k v :=
  ⟨∅, mergeFn⟩
/-- Adds the key set `newKeySet` with value `val` to the map.

* If `newKeySet` shares no key with any stored key set, it is inserted as a new entry.
* Otherwise the stored key sets that overlap are snowballed together with `newKeySet` into one
  merged set, their values are combined with `d.mergeFn` (starting from `val`, and passing the
  merged set as the last argument), and the overlapping entries are replaced by the single merged
  entry.

Stored key sets that do not overlap `newKeySet` are kept unchanged. -/
def addSet (d : DisjointSetMap k v) (newKeySet : HashSet k) (val : v) : DisjointSetMap k v :=
  let overlappingSets :=
    d.innerMap.fold (init := (newKeySet, [])) fun acc currSet value =>
      if (acc.1.intersection currSet).isEmpty then
        -- The no overlap case, so there's nothing to add here
        acc
      else
        -- Some overlap: snowball the current set into the accumulated set, and remember its
        -- value so it can be merged later
        (acc.1.union currSet, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps, just insert the new set with its value
    { d with innerMap := d.innerMap.insert newKeySet val }
  | (mergedSet, valuesToMerge) =>
    -- Merge the values of all overlapping sets into the new value
    let mergedValue :=
      valuesToMerge.foldl (init := val)
        (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove the old overlapping sets. The fold starts from the existing map (not an empty one)
    -- so that every non-overlapping set is preserved.
    let newInnerMap :=
      d.innerMap.fold (init := d.innerMap) fun newInnerMap currSet _ =>
        if currSet.any newKeySet.contains then
          newInnerMap.erase currSet
        else
          newInnerMap
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }
/-- Merges the stored key sets that overlap `keysToMerge` into a single key set, without adding a
new value.

* If none of the keys are in the map this is a no-op, because there is no value to store for them.
* Otherwise the overlapping key sets are snowballed together with `keysToMerge` into one merged
  set, their values are combined with `d.mergeFn` (passing the merged set as the last argument),
  and the overlapping entries are replaced by the single merged entry.

Stored key sets that do not overlap `keysToMerge` are kept unchanged. -/
def union (d : DisjointSetMap k v) (keysToMerge : HashSet k) : DisjointSetMap k v :=
  let overlappingSets :=
    d.innerMap.fold (init := (keysToMerge, [])) fun acc currSet value =>
      if (acc.1.intersection currSet).isEmpty then
        -- The no overlap case, so there's nothing to add here
        acc
      else
        -- Some overlap: snowball the current set into the accumulated set, and remember its
        -- value so it can be merged later
        (acc.1.union currSet, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps: none of the keys are in the map, and we have no value to set for them
    d
  | (mergedSet, firstVal :: restValsToMerge) =>
    -- Merge the values of all overlapping sets
    let mergedValue :=
      restValsToMerge.foldl (init := firstVal)
        (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove the old overlapping sets. The fold starts from the existing map (not an empty one)
    -- so that every non-overlapping set is preserved.
    let newInnerMap :=
      d.innerMap.fold (init := d.innerMap) fun newInnerMap currSet _ =>
        if currSet.any keysToMerge.contains then
          newInnerMap.erase currSet
        else
          newInnerMap
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }
/-- Looks up `key` in `d.outerMap` and returns the value stored for the key set it belongs to,
or `none` when `key` is not in the map. -/
def find? (d : DisjointSetMap k v) (key : k) : Option v :=
  (d.outerMap[key]?).map Prod.snd
/-- Returns the value stored for the key set that `key` belongs to, given a proof `h` that `key`
is a key of `d.outerMap`. -/
def find (d : DisjointSetMap k v) (key : k) (h : key ∈ d.outerMap) : v :=
  (d.outerMap[key]'h).snd
end DisjointSetMap |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Std.Data.HashMap | |
import DisjointSetMap -- make sure this points to the right namespace for the module! | |
open Std | |
open DisjointSetMap | |
/-- A type clash between two types that are not compatible. The two clashing types are stored as `t1` and `t2`. -/
structure Incompatible (t : Type u) where | |
-- One side of the clash.
t1 : t | |
-- The other side of the clash.
t2 : t | |
/-- The primitive types: `unit`, `string`, `int`, `bool` and `nat`. Equality on primitives is decidable. -/
inductive Primitive where | |
| unit | |
| string | |
| int | |
| bool | |
| nat | |
deriving DecidableEq | |
/-- A type variable, identified by a natural-number `index`. `zonkSimpleTypeMap` starts at index 0 and increments the index for each new type variable. -/
structure TypeVar where | |
index : Nat | |
deriving DecidableEq | |
/-- A unification variable, identified by a natural-number `index`. It derives `Hashable` and `DecidableEq`, and is used as the sub-key type of `SimpleTypeMap'`. -/
structure UnificationVar where | |
index : Nat | |
deriving Hashable, DecidableEq | |
-- Short alias for `UnificationVar`.
abbrev UniVar := UnificationVar | |
/-- A type that does away with the frills, and is purpose built from the ground up for how we do inference: namely it only represents primitives and type constructors, and anything beyond that is a univar that references an `Option SimpleType` in the UniVarsMap. -/
inductive SimpleType where | |
/-- A primitive type -/
| primitive : Primitive → SimpleType | |
/-- A type name with the unification variables standing for its arguments -/ | |
| typeCtor : String → List UniVar → SimpleType | |
/-- A univar is a member of a `SimpleType` when the type is a `typeCtor` whose univar list contains it; a `primitive` has no members.
Later proofs unfold this instance by name with `simp`, so its body should not be changed independently. -/
instance instMembershipUniVarSimpleType : Membership UniVar SimpleType where | |
mem container item := | |
match container with | |
| .primitive _ => False | |
| .typeCtor _ uniVars => uniVars.contains item | |
/-- No univar is ever a member of a `Primitive`: membership is constantly `False`. -/
instance instMembershipUniVarPrimitive : Membership UniVar Primitive :=
  ⟨fun _ _ => False⟩
/-- A univar is a member of a `DisjointSetMap` when the map's `outerMap` contains it as a key.
`IsValidUniVarKey.mk` passes a proof of this membership as the bound proof of `m.outerMap[uniVar]`, so this body should not be changed independently. -/
instance {a} : Membership UniVar (DisjointSetMap UniVar a) where | |
mem container item := | |
container.outerMap.contains item | |
/-- A univar is a member of `some t` when it is a member of `t`; `none` has no members.
Later proofs unfold this instance by name with `simp`. -/
instance instMembershipUniVarOptionSimpleType : Membership UniVar (Option SimpleType) where | |
mem container item := | |
match container with | |
| .some t => item ∈ t | |
| .none => False | |
/-- Membership in an `outerMap` entry `(keySet, value)`: a univar is a member when it is a member of the second component (the `Option SimpleType`); the key set is ignored.
Later proofs unfold this instance by name with `simp`. -/
instance instMembershipUniVarHashSetUniVarOptionSimpleType : Membership UniVar (HashSet UniVar × Option SimpleType) where | |
mem container item := instMembershipUniVarOptionSimpleType.mem container.2 item | |
/-- A disjoint-set map from sets of univars to an optional `SimpleType`, without any validity invariant attached (see `SimpleTypeMap` for the validated version). -/
abbrev SimpleTypeMap' := DisjointSetMap UniVar (Option SimpleType) | |
/-- The inductive predicate that a uniVar is a valid key in a `SimpleTypeMap'`, that every univar at a value in the map is also a valid key in the map, and that there are no cycles of univars looping round to point to each other! -/ | |
inductive IsValidUniVarKey (m : SimpleTypeMap') : UniVar → Prop where | |
-- `hmem`: the univar is a member of the map.
-- `h`: every univar that is a member of the entry stored for `uniVar` in `m.outerMap` is itself a valid key.
| mk (uniVar : UniVar) | |
(hmem : uniVar ∈ m) | |
(h : ∀ (uv : UniVar), uv ∈ m.outerMap[uniVar]'hmem → IsValidUniVarKey m uv) : IsValidUniVarKey m uniVar | |
/-- A valid univar key is in particular a member of the map: projects the `hmem` argument of `IsValidUniVarKey.mk`. -/
theorem IsValidUniVarKey.mem (h : IsValidUniVarKey m u) : u ∈ m := by
  cases h
  assumption
/-- Every univar that is a member of the entry stored for a valid key `u` is itself a valid key: projects the last argument of `IsValidUniVarKey.mk`. The entry is looked up using `h.mem` as the membership proof. -/
theorem IsValidUniVarKey.children {m u} (h : IsValidUniVarKey m u) : ∀ (uv : UniVar), uv ∈ m.outerMap[u]'h.mem → IsValidUniVarKey m uv := | |
match h with | |
| .mk _ _ children => children | |
/-- A `SimpleTypeMap'` bundled with a proof that every key of its `outerMap` satisfies `IsValidUniVarKey`. -/
def SimpleTypeMap := { map : SimpleTypeMap' // ∀ uv ∈ map.outerMap, IsValidUniVarKey map uv } | |
/-- If `u` is a valid key of `m`, then the pair `(m, u)` is accessible under the relation that relates `x` to `y` when both share the same map and `x`'s univar is a member of the entry stored for `y`'s univar. The relation is the same one used by the `WellFoundedRelation (SimpleTypeMap × UniVar)` instance. -/
theorem accessible (m : SimpleTypeMap) (u : UniVar) (h : IsValidUniVarKey m.1 u) : | |
Acc (fun x y => x.1 = y.1 ∧ ∃ h : y.2 ∈ y.1.1, x.2 ∈ y.1.1.outerMap[y.2]'h) (m, u) := by | |
-- Induction on the validity proof: `ih` gives accessibility for every child univar.
induction h with | mk u mem h ih => | |
constructor | |
intro (m', u') | |
-- The first component of the relation forces `m' = m`.
rintro ⟨rfl, mem, ht⟩ | |
dsimp only at h ih mem ht | |
exact ih u' ht | |
/-- Well-founded relation on `(map, univar)` pairs: `x` is below `y` when both have the same map and `x`'s univar is a member of the entry stored for `y`'s univar in that map. Well-foundedness is derived from the validity proof carried by `SimpleTypeMap`, via `accessible` and `IsValidUniVarKey.children`. -/
instance : WellFoundedRelation (SimpleTypeMap × UniVar) where | |
-- x "<" y | |
rel x y := x.1 = y.1 ∧ ∃ h : y.2 ∈ y.1.1, x.2 ∈ y.1.1.outerMap[y.2]'h | |
wf := by | |
constructor | |
intro a | |
constructor | |
intro (m, u) ⟨h₁, h₂, ht⟩ | |
-- `a.1.2` is the validity invariant of the map; apply it to `a`'s univar.
have h := a.1.2 a.2 h₂ | |
apply accessible | |
cases h₁ | |
exact h.children u ht | |
/-- Simp lemma: membership of a univar in `SimpleType.typeCtor ctor uniVars` is the same proposition as membership in the list `uniVars`. Proved by unfolding `instMembershipUniVarSimpleType`. -/
@[simp] theorem SimpleType.uniVar_in_typeCtor {uv : UniVar} {ctor : String} {uniVars : List UniVar} : (uv ∈ SimpleType.typeCtor ctor uniVars) = (uv ∈ uniVars) := by | |
simp [instMembershipUniVarSimpleType] | |
/-- Returns the implicit membership proof `h : key ∈ map`. The explicit argument, an equation about
`map[key].2`, is not used in the proof itself.
NOTE(review): presumably the equation is only there so that `h` can be inferred at call sites – confirm. -/
theorem HashMap.getElemPrfFrom2 {k v1 v2 : Type u} {someVal : v2} [DecidableEq k] [Hashable k] {key : k} {map : HashMap k (v1 × v2)} {h : key ∈ map} (_ : map[key].2 = someVal) : key ∈ map :=
  h
/- ZONKING -/ | |
/-- The result type of zonking: primitives, type variables, and type constructors applied to already-zonked argument types. Unlike `SimpleType`, it contains no unification variables. -/
inductive SimpleZonkedType where | |
/-- A primitive type -/
| primitive : Primitive → SimpleZonkedType | |
/-- A type variable. Might be slightly tricky to make sure we keep typevars that should be the same skolem the same, whilst ensuring that typevars that can be different are different. | |
The current idea is that at every zonking step we return the next available typevar index, and when we learn that we can generalise a univar, we replace that with this next available typevar index. At which point we create the next typevar by incrementing the index of the last one, and bubble that one up to the callers. | |
-/ | |
| typeVar : TypeVar → SimpleZonkedType | |
/-- A type name applied to a list of zonked argument types -/
| typeCtor : String → List SimpleZonkedType → SimpleZonkedType | |
-- The output of zonking: each univar mapped to its zonked type.
abbrev SimpleZonkedTypeMap := HashMap UniVar SimpleZonkedType | |
/-- Lists the keys of `map`, each bundled with a proof that it is a member of `map`.
The proofs are obtained from `List.attach` on `map.keys` and converted with `HashMap.mem_keys`. -/
def HashMap.attachKeys {k v : Type u} [DecidableEq k] [Hashable k] (map : HashMap k v) : List {key : k // key ∈ map} :=
  map.keys.attach.map fun entry => ⟨entry.val, HashMap.mem_keys.mp entry.property⟩
/-- We go through the map one univar key at a time, and zonk it until it can no longer be zonked. At which point by the end we should have a map of fully zonked types, which should be trivially convertible to a `SimpleZonkedTypeMap` I believe.
The fold threads the next unused `TypeVar` (starting at index 0) through every key, and the final counter is discarded. -/ | |
def zonkSimpleTypeMap (uniVarsMap : SimpleTypeMap) : SimpleZonkedTypeMap := | |
let firstTypeVar : TypeVar := .mk 0 | |
-- All keys of the outer map, each paired with a proof of membership.
let allUniVars := uniVarsMap.1.outerMap |> HashMap.attachKeys | |
allUniVars.foldl (init := (HashMap.empty, firstTypeVar)) (fun (acc, accTypeVar) ⟨uv, uvInMap⟩ => | |
-- The map's invariant turns key membership into a validity proof for `uv`.
let isValidUniVarKey := uniVarsMap.2 uv uvInMap | |
let zonked := zonkSingleUv uniVarsMap uv accTypeVar isValidUniVarKey | |
(acc.insert uv zonked.1, zonked.2) | |
) | |
|>.1 | |
where | |
-- Zonks a single valid univar. Returns the zonked type together with the next unused type variable.
zonkSingleUv (uniVarsMap : SimpleTypeMap) (uv : UniVar) (nextTypeVar : TypeVar) (uh : IsValidUniVarKey uniVarsMap.1 uv) : | |
SimpleZonkedType × TypeVar := | |
have hmem := uh.mem | |
match hh : uniVarsMap.1.outerMap[uv]'hmem |>.2 with | |
| none => | |
-- No type recorded for this univar: it becomes `nextTypeVar`, and the counter is incremented.
(SimpleZonkedType.typeVar nextTypeVar, { nextTypeVar with index := nextTypeVar.index + 1 }) | |
| some someVal => | |
match hh' : someVal with | |
-- Primitives zonk to themselves and consume no type variable.
| .primitive p => (SimpleZonkedType.primitive p, nextTypeVar) | |
| .typeCtor ctor uniVars => Id.run do | |
-- Zonk each argument univar in order, threading the type-variable counter through.
let mut newList := [] | |
let mut nextTypeVar := nextTypeVar | |
for hh'' : thisUv in uniVars do | |
let (zonked, next) := | |
by | |
-- The recursive call needs a validity proof for `thisUv`; it comes from `uh.children`.
refine zonkSingleUv uniVarsMap thisUv nextTypeVar ?_ | |
-- NOTE(review): `rewrittenHh` is not referenced by name in the remaining tactics – confirm whether it is needed.
have rewrittenHh := HashMap.getElemPrfFrom2 hh | |
refine uh.children thisUv ?_ | |
simp [instMembershipUniVarHashSetUniVarOptionSimpleType] | |
rw [hh] | |
simp [instMembershipUniVarOptionSimpleType] | |
exact hh'' | |
-- Results are prepended here and reversed once at the end.
newList := zonked :: newList | |
nextTypeVar := next | |
return (SimpleZonkedType.typeCtor ctor newList.reverse, nextTypeVar) | |
-- The recursion decreases on the pair `(uniVarsMap, uv)`; the goal below is discharged by showing `thisUv` is a member of the entry stored for `uv`.
termination_by (uniVarsMap, uv) | |
decreasing_by | |
rw [true_and] | |
exists hmem | |
simp [instMembershipUniVarHashSetUniVarOptionSimpleType] | |
rw [hh] | |
simp [instMembershipUniVarOptionSimpleType] | |
exact hh'' |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment