Last active
May 27, 2020 00:49
-
-
Save felko/90dfeecfd2795652b8902f6169285481 to your computer and use it in GitHub Desktop.
linear lambda calculus typechecker
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Single executable; main-is + hs-source-dirs put the program in src/Main.hs
-- and no other modules are listed.
executable llc
  main-is:          Main.hs
  -- other-modules:
  -- NOTE(review): the tight base bound pins one GHC series (base-4.12 is
  -- believed to be GHC 8.6); containers >=0.6.2.1 is presumably needed for
  -- Data.Map.disjoint (used by mergeCtx) — confirm both before relaxing.
  build-depends:    base >=4.12 && <4.13
                  , containers >=0.6.2.1 && <0.7
                  , mtl >=2.2 && <2.3
                  , uuid
                  , MonadRandom
                  , these
                  , semialign
                  , semialign-indexed
                  , pretty
  hs-source-dirs:   src
  default-language: Haskell2010
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE | |
LambdaCase | |
, OverloadedLists | |
, OverloadedStrings | |
, RecordWildCards | |
, BlockArguments | |
, DeriveFunctor | |
, TypeApplications | |
, GeneralizedNewtypeDeriving | |
#-} | |
-- https://core.ac.uk/download/pdf/81933277.pdf | |
module Main where | |
import Control.Arrow ((>>>)) | |
import Control.Monad.Identity | |
import Control.Monad.State | |
import Control.Monad.Except | |
import Control.Monad.RWS | |
import Control.Monad.Random | |
import Data.Functor | |
import Data.Monoid | |
import Data.Maybe | |
import Data.Function | |
import Data.List (nub, intercalate) | |
import Data.List.NonEmpty (NonEmpty(..)) | |
import qualified Data.List.NonEmpty as NE | |
import qualified Data.Set as Set | |
import qualified Data.Map as Map | |
import Data.These | |
import Data.Semialign hiding (zip) | |
import Data.Semialign.Indexed | |
import Text.PrettyPrint.HughesPJClass hiding ((<>)) | |
import Data.UUID hiding (null) | |
import Data.Coerce | |
import Debug.Trace | |
-- | A variable binder: the source spelling plus a UUID that makes each
-- binding site distinct. The Eq and Ord instances for 'Name' compare only
-- 'uid', so two binders spelled the same way are still different names.
data Name = Name
  { display :: String -- ^ source-level spelling, used when printing
  , uid :: UUID -- ^ drawn at random by 'unique' for each new binder
  }
  deriving Show
-- | Names are equal exactly when their UUIDs are; 'display' is ignored.
instance Eq Name where
  a == b = uid a == uid b
-- | Names are ordered by UUID only, consistent with the 'Eq' instance.
instance Ord Name where
  compare a b = compare (uid a) (uid b)
-- | At the normal pretty level a name prints as its source spelling; at any
-- more detailed level the UUID is appended as @name\@uuid@ to disambiguate.
instance Pretty Name where
  pPrintPrec lvl _ Name{..}
    | lvl == prettyNormal = text display
    | otherwise = text display <> "@" <> text (show uid)
-- | Types of the linear calculus. Comments give the notation used by the
-- 'Pretty' instance.
data Type
  = VarT String -- ^ type variable; the checker generates names @$0@, @$1@, …
  | AppT Type Type -- ^ type application @F A@ (not produced by 'check')
  | TensorT Type Type -- ^ multiplicative conjunction @A ⊗ B@
  | PlusT Type Type -- ^ additive disjunction @A ⊕ B@
  | WithT Type Type -- ^ additive conjunction @A & B@
  | LolliT Type Type -- ^ linear function @A ⊸ B@
  | UnitT -- ^ unit type @1@
  | OfCourseT Type -- ^ exponential @!A@ (may be copied/ignored)
  deriving (Eq, Show)
-- | Render a type with minimal parentheses.
--
-- Precedence levels passed as the second argument (higher binds tighter):
--
-- * 0 — top level and the result of @⊸@ (right-associative);
-- * 1 — the argument of @⊸@;
-- * 3 — operands of @⊗@, @⊕@, @&@ (always parenthesised when nested in each
--   other) and the function side of an application (left-associative);
-- * 4 — an application's argument and the operand of @!@.
--
-- The original printed an application's function side at level 1, where
-- binary operators are never parenthesised, so @AppT (TensorT a b) c@
-- came out as the ambiguous @a ⊗ b c@. Output for types without 'AppT' is
-- unchanged.
instance Pretty Type where
  pPrintPrec l i = \case
    VarT n -> text n
    UnitT -> "1"
    -- `!` binds tighter than application, so it never needs parentheses.
    OfCourseT a -> "!" <> pPrintPrec l 4 a
    AppT a b -> maybeParens (i >= 4) (pPrintPrec l 3 a <+> pPrintPrec l 4 b)
    TensorT a b -> binop "⊗" a b
    PlusT a b -> binop "⊕" a b
    WithT a b -> binop "&" a b
    LolliT a b -> maybeParens (i >= 1) (pPrintPrec l 1 a <+> "⊸" <+> pPrintPrec l 0 b)
    where
      -- Non-associative binary connectives share one precedence level.
      binop op a b = maybeParens (i >= 3) (pPrintPrec l 3 a <+> op <+> pPrintPrec l 3 b)
-- | Surface terms of the linear lambda calculus. @String@ fields are source
-- identifiers that 'check' resolves through the current scope. Comments show
-- the concrete syntax produced by the 'Pretty' instance.
data Term
  = Var String -- x
  | Let String Term Term -- let x = t, e   (x is in scope in e, not in t)
  | Unit -- ⋆   (the unit value, of type 1)
  | Empty String Term -- empty x, e   (consumes x : 1)
  | App Term Term -- f x
  | Abs String Term -- λ x ⊸ e
  | Pair Term Term -- ⟨x, y⟩   (additive pair: both sides share one context)
  | Choose String String Bool Term -- choose x = p.fst, e   (False = .fst, True = .snd)
  | Tensor Term Term -- x ⊗ y   (multiplicative pair: disjoint contexts)
  | Split String String String Term -- split x ⊗ y = t, e
  | Quote Term -- `e`   (promotion to a !-type)
  | Eval String String Term -- eval x = u, e   (u : !A gives x : A)
  | Copy String String String Term -- copy (x, y) = u, e   (duplicates u : !A)
  | Ignore String Term -- ignore x, e
  | Inl Term | Inr Term -- <Left x>, <Right y>   (injections into A ⊕ B)
  | Case String String Term String Term -- case x {Left l → p | Right r → q}
  -- | Cons String Term -- <Just 1>
  -- | Case Term [(String, String, Term)] -- case x of { <Failure err> e | <Success res> f }
  deriving Show
-- | Render a term in the concrete syntax shown on the 'Term' constructors.
--
-- Level 0 is the top level; binding forms (let, λ, split, …) and @⊗@ are
-- parenthesised at level ≥ 1; an application is parenthesised at level ≥ 2.
-- Fixes: an application's argument is now printed at level 2 (it was 1), so
-- @App f (App g x)@ renders as @f (g x)@ rather than the wrong @f g x@; and
-- @case@ now has a space between the scrutinee and the brace.
instance Pretty Term where
  pPrintPrec lvl i = \case
    Var n -> text n
    Let x y e -> binder ("let " <> text x <> " = " <> body y <> ", " <> body e)
    Unit -> "⋆"
    Empty x e -> binder ("empty " <> text x <> ", " <> body e)
    Abs x e -> binder ("λ " <> text x <> " ⊸ " <> body e)
    -- Function side at 1 (left-associative); argument at 2 so a nested
    -- application in argument position gets parentheses.
    App f x -> maybeParens (i >= 2) (pPrintPrec lvl 1 f <+> pPrintPrec lvl 2 x)
    Pair x y -> "⟨" <> body x <> ", " <> body y <> "⟩"
    Choose x p False e -> binder ("choose " <> text x <> " = " <> text p <> ".fst, " <> body e)
    Choose y p True e -> binder ("choose " <> text y <> " = " <> text p <> ".snd, " <> body e)
    Tensor x y -> binder (pPrintPrec lvl 1 x <> " ⊗ " <> pPrintPrec lvl 1 y)
    Split x y z e -> binder ("split " <> text x <> " ⊗ " <> text y <> " = " <> text z <> ", " <> body e)
    Quote x -> "`" <> body x <> "`"
    Eval x u e -> binder ("eval " <> text x <> " = " <> text u <> ", " <> body e)
    Copy x y z e -> binder ("copy " <> parens (text x <> ", " <> text y) <> " = " <> text z <> ", " <> body e)
    Ignore x e -> binder ("ignore " <> text x <> ", " <> body e)
    Inl x -> "<Left " <> pPrintPrec lvl 2 x <> ">"
    Inr x -> "<Right " <> pPrintPrec lvl 2 x <> ">"
    Case e x p y q ->
      binder ("case " <> text e <+> braces ("Left " <> text x <> " → " <> body p <> " | " <> "Right " <> text y <> " → " <> body q))
    where
      -- Binding forms extend as far right as possible, so they need
      -- parentheses anywhere except the top level.
      binder d = maybeParens (i >= 1) d
      -- Sub-terms in delimited or trailing positions print at the top level.
      body t = pPrintPrec lvl 0 t
-- | Lexical scope: maps each source identifier to the unique 'Name' it
-- currently refers to (the reader environment of 'Check' has this shape).
type Scope = Map.Map String Name
-- | A usage context: the variables a term consumes, each with the type at
-- which it is used.
--
-- Semigroup/Monoid are newtype-derived from 'Map', so '<>' is a left-biased
-- union that does not detect shared names; use 'mergeCtx' when overlapping
-- contexts must be rejected.
newtype Context = Context
  { getCtx :: Map.Map Name Type }
  deriving (Show, Semigroup, Monoid)
-- | Print a context as @x : A, y : B, …@ in key order; the pretty level is
-- forwarded to each 'Name' so detailed output includes UUIDs.
instance Pretty Context where
  pPrintPrec lvl _ = cat . punctuate ", " . map entry . Map.toList . getCtx
    where
      entry (n, t) = pPrintPrec lvl 0 n <+> ":" <+> pPrint t
-- | Bind @n : t@ in a context, replacing any existing binding for @n@.
introduce :: Name -> Type -> Context -> Context
introduce n t = Context . Map.insert n t . getCtx
-- | Remove @n@ from a context (a no-op if it is absent).
consume :: Name -> Context -> Context
consume n = Context . Map.delete n . getCtx
-- | Errors reported by the checker and by unification.
data CheckError
  = ScopeError String
  -- ^ A name was looked up where it is not available: an unbound source
  -- identifier ('bound'), a name missing from a usage context
  -- ('lookupCtx'), or a name present in only one of two contexts
  -- ('unifyCtx').
  | UnboundError String
  -- ^ A variable left over in the top-level judgement ('runCheck').
  | OverlapError Context
  -- ^ Contexts that must be disjoint share these bindings ('mergeCtx').
  | TypeError Type Type
  -- ^ Two types with incompatible head constructors ('unify').
  | UnusedError Context
  -- ^ NOTE(review): not raised anywhere in the visible code.
  | OccursCheckError String Type
  -- ^ Binding the variable to the type would create an infinite type.
  deriving Show
-- | Mutable state of the checker.
data CheckState = CheckState
  { tyVarSupply :: Int } -- ^ index of the next @$n@ variable 'freshTyVar' hands out
  deriving Show
-- | A type equality emitted by 'require' during checking and discharged
-- later by 'solve'.
data Constraint
  = Type :~ Type
  deriving Show
-- | A typing judgement @ctx :⊢ t@: the term consumes exactly @ctx@ and has
-- type @t@.
data Judgement = Context :⊢ Type
  deriving Show
-- | Print a judgement as @x : A, … ⊢ T@.
instance Pretty Judgement where
  pPrintPrec lvl _ (ctx :⊢ t) = lhs <+> "⊢" <+> rhs
    where
      lhs = pPrintPrec lvl 0 ctx
      rhs = pPrint t
-- | The checker monad:
--
-- * reader — the current 'Scope' (extended with 'local');
-- * writer — the 'Constraint's emitted by 'require';
-- * state — the fresh type-variable counter;
-- * 'RandT' — randomness for the UUIDs drawn by 'unique';
-- * 'Except' — failure with the errors raised by 'checkError'.
type Check a =
  RWST Scope [Constraint] CheckState
    (RandT StdGen (Except (Last (NonEmpty CheckError))))
    a
-- | Type at which a context uses @n@; a 'ScopeError' if it is not used.
lookupCtx :: Name -> Context -> Check Type
lookupCtx n (Context ctx) = case Map.lookup n ctx of
  Just t -> pure t
  Nothing -> checkError (ScopeError (display n))
-- | Union of two contexts that must not share any variable (multiplicative
-- combination); otherwise an 'OverlapError' listing the shared bindings.
mergeCtx :: Context -> Context -> Check Context
mergeCtx (Context l) (Context r)
  | not (Map.disjoint l r) = checkError (OverlapError (Context (Map.intersection l r)))
  | otherwise = pure (Context (l `Map.union` r))
-- | Require two contexts to mention exactly the same variables (additive
-- combination). Each shared variable's two types are constrained equal and
-- the left type is kept; a variable on only one side is a 'ScopeError'.
unifyCtx :: Context -> Context -> Check Context
unifyCtx (Context l) (Context r) = Context <$> sequenceA (ialignWith both l r)
  where
    both _ (These a b) = a <$ require (a :~ b)
    both n (This _) = checkError (ScopeError (display n))
    both n (That _) = checkError (ScopeError (display n))
-- | Constrain every type in the context to be of the form @!A@ (each with
-- its own fresh @A@), visiting bindings in key order.
unrestrictedCtx :: Context -> Check ()
unrestrictedCtx (Context ctx) = forM_ ctx $ \t -> do
  inner <- freshTyVar
  require (t :~ OfCourseT inner)
-- | Abort checking with a single error.
checkError :: CheckError -> Check a
checkError = throwError . Last . Just . (:| [])
-- | Emit a type-equality constraint to be solved after checking.
require :: Constraint -> Check ()
require = tell . pure
-- | A brand-new 'Name' for a binder, tagged with a random UUID.
unique :: String -> Check Name
unique s = do
  u <- getRandom
  pure Name { display = s, uid = u }
-- | Resolve a source identifier in the current scope; 'ScopeError' if unbound.
bound :: String -> Check Name
bound s = do
  scope <- ask
  maybe (checkError (ScopeError s)) pure (Map.lookup s scope)
-- | Hand out the next type variable @$n@ and advance the counter.
freshTyVar :: Check Type
freshTyVar = state \st ->
  let n = tyVarSupply st
  in (VarT ('$' : show n), st { tyVarSupply = n + 1 })
-- | Wrap a checker so that every judgement it produces is written out via
-- 'trace' as @ctx ⊢ term : type@ before being returned unchanged.
debug :: (Term -> Check Judgement) -> Term -> Check Judgement
debug chk term = do
  j@(ctx :⊢ t) <- chk term
  let line = render (pPrint ctx <+> "⊢" <+> pPrint term <+> ":" <+> pPrint t)
  trace line (pure j)
-- | Infer the linear typing judgement @ctx :⊢ t@ of a term.
--
-- The returned context lists exactly the variables the term consumes, each
-- at the type it is used at. A bound variable must be used: its type is read
-- back with 'lookupCtx', which fails if it is missing. Type equalities are
-- emitted with 'require' and solved later. Every sub-judgement is traced
-- through 'debug'.
--
-- Elimination forms that consume an existing variable (empty, split, choose,
-- case, copy, eval, ignore) add it with 'introduceLinear', which rejects a
-- body that already used that variable. Plain 'introduce' would overwrite
-- that use and let a linear resource be consumed twice.
check :: Term -> Check Judgement
check = debug \case
  Var s -> do
    n <- bound s
    t <- freshTyVar
    pure (Context (Map.singleton n t) :⊢ t)
  Let x y e -> do
    xn <- unique x
    ctxy :⊢ a <- check y
    ctxe :⊢ t <- local (Map.insert x xn) (check e)
    a' <- lookupCtx xn ctxe
    require (a :~ a')
    ctx <- mergeCtx ctxy (consume xn ctxe)
    pure (ctx :⊢ t)
  Unit -> pure (mempty :⊢ UnitT)
  Empty x e -> do
    xn <- bound x
    ctx :⊢ t <- check e
    ctx' <- introduceLinear xn UnitT ctx
    pure (ctx' :⊢ t)
  Tensor x y -> do
    ctx1 :⊢ a <- check x
    ctx2 :⊢ b <- check y
    -- Multiplicative: the two halves must use disjoint resources.
    ctx <- mergeCtx ctx1 ctx2
    pure (ctx :⊢ TensorT a b)
  Split x y z e -> do
    (xn, yn) <- (,) <$> unique x <*> unique y
    ctx :⊢ t <- local (Map.insert x xn . Map.insert y yn) (check e)
    (a, b) <- (,) <$> lookupCtx xn ctx <*> lookupCtx yn ctx
    zn <- bound z
    ctx' <- introduceLinear zn (TensorT a b) (consume xn (consume yn ctx))
    pure (ctx' :⊢ t)
  Abs x e -> do
    xn <- unique x
    ctx :⊢ b <- local (Map.insert x xn) (check e)
    a <- lookupCtx xn ctx
    pure (consume xn ctx :⊢ LolliT a b)
  App f x -> do
    (ctxf :⊢ tf, ctxx :⊢ tx) <- (,) <$> check f <*> check x
    ty <- freshTyVar
    require (tf :~ LolliT tx ty)
    ctx <- mergeCtx ctxf ctxx
    pure (ctx :⊢ ty)
  Pair x y -> do
    (ctxx :⊢ tx, ctxy :⊢ ty) <- (,) <$> check x <*> check y
    -- Additive: both components are built from the same resources.
    ctx <- unifyCtx ctxx ctxy
    pure (ctx :⊢ WithT tx ty)
  Choose y p False e -> do
    yn <- unique y
    pn <- bound p
    ctx :⊢ t <- local (Map.insert y yn) (check e)
    a <- lookupCtx yn ctx
    b <- freshTyVar
    ctx' <- introduceLinear pn (WithT a b) (consume yn ctx)
    pure (ctx' :⊢ t)
  Choose x p True e -> do
    xn <- unique x
    pn <- bound p
    ctx :⊢ t <- local (Map.insert x xn) (check e)
    a <- freshTyVar
    b <- lookupCtx xn ctx
    ctx' <- introduceLinear pn (WithT a b) (consume xn ctx)
    pure (ctx' :⊢ t)
  Inl x -> do
    ctx :⊢ a <- check x
    b <- freshTyVar
    pure (ctx :⊢ PlusT a b)
  Inr y -> do
    ctx :⊢ b <- check y
    a <- freshTyVar
    pure (ctx :⊢ PlusT a b)
  Case x l p r q -> do
    xn <- bound x
    (ln, rn) <- (,) <$> unique l <*> unique r
    ctxp :⊢ tp <- local (Map.insert l ln) (check p)
    ctxq :⊢ tq <- local (Map.insert r rn) (check q)
    require (tp :~ tq)
    (a, b) <- (,) <$> lookupCtx ln ctxp <*> lookupCtx rn ctxq
    -- Both branches must consume the same remaining resources.
    ctx <- unifyCtx (consume ln ctxp) (consume rn ctxq)
    ctx' <- introduceLinear xn (PlusT a b) ctx
    pure (ctx' :⊢ tp)
  Copy x y z e -> do
    zn <- bound z
    (xn, yn) <- (,) <$> unique x <*> unique y
    ctx :⊢ t <- local (Map.insert x xn . Map.insert y yn) (check e)
    (ax, ay) <- (,) <$> lookupCtx xn ctx <*> lookupCtx yn ctx
    -- Contraction: only a !-typed value may be duplicated.
    bangA <- OfCourseT <$> freshTyVar
    require (bangA :~ ax)
    require (bangA :~ ay)
    ctx' <- introduceLinear zn bangA (consume xn (consume yn ctx))
    pure (ctx' :⊢ t)
  Quote e -> do
    ctx :⊢ t <- check e
    -- Promotion: everything the quoted term captures must itself be !-typed.
    unrestrictedCtx ctx
    pure (ctx :⊢ OfCourseT t)
  Eval x u e -> do
    xn <- unique x
    un <- bound u
    ctx :⊢ t <- local (Map.insert x xn) (check e)
    a <- lookupCtx xn ctx
    ctx' <- introduceLinear un (OfCourseT a) (consume xn ctx)
    pure (ctx' :⊢ t)
  Ignore x e -> do
    -- Weakening: x is discarded, which is only allowed at a !-type. This
    -- case was previously missing, so 'Ignore' crashed with a pattern-match
    -- failure.
    xn <- bound x
    ctx :⊢ t <- check e
    a <- freshTyVar
    ctx' <- introduceLinear xn (OfCourseT a) ctx
    pure (ctx' :⊢ t)

-- | Add the variable @n : t@ consumed by an elimination form to its body's
-- usage context. Fails with 'OverlapError' (via 'mergeCtx') if the body has
-- already used @n@ itself.
introduceLinear :: Name -> Type -> Context -> Check Context
introduceLinear n t = mergeCtx (Context (Map.singleton n t))
-- | Run the checker on a closed term, starting from an empty scope, type
-- variable @$0@ and a fresh random generator.
--
-- Returns the (unsolved) type and the emitted constraints. Checker errors are
-- flattened from 'Last' 'NonEmpty' into a plain list (empty 'Last' gives
-- @[]@). If the final context is not empty, the leftover variables are
-- reported as 'UnboundError's.
runCheck :: Term -> ExceptT [CheckError] IO (Type, [Constraint])
runCheck term = do
  gen <- liftIO newStdGen
  let start = CheckState { tyVarSupply = 0 }
      flatten (Identity r) = pure (either (Left . maybe [] NE.toList . getLast) Right r)
      checked = evalRandT (evalRWST (check term) mempty start) gen
  (Context leftover :⊢ typ, cs) <- mapExceptT flatten checked
  unless (Map.null leftover) $
    throwError (UnboundError . display <$> Map.keys leftover)
  pure (typ, cs)
-- | A substitution from type-variable names to types; apply it with
-- 'substitute'.
newtype Subst = Subst
  { getSubst :: Map.Map String Type }
  deriving Show
-- | Print a substitution as @{v ⇒ T, …}@ in key order.
instance Pretty Subst where
  pPrint (Subst m) = braces (cat (punctuate ", " (map binding (Map.toList m))))
    where
      binding (v, t) = text v <+> "⇒" <+> pPrint t
-- | Sequential composition: @s1 <> s2@ behaves like applying @s1@ first and
-- then @s2@. The images of @s1@ are rewritten by @s2@, and @s1@ wins for
-- variables bound by both.
instance Semigroup Subst where
  Subst earlier <> later@(Subst laterMap) =
    Subst (Map.union (Map.map (substitute later) earlier) laterMap)
-- | The empty substitution leaves every type unchanged.
instance Monoid Subst where
  mempty = Subst Map.empty
-- | Type variables of a type, without duplicates, in order of first
-- occurrence (left to right). 'infer' relies on this order when choosing
-- display names.
freeTyVars :: Type -> [String]
freeTyVars = nub . occurrences
  where
    occurrences :: Type -> [String]
    occurrences ty = case ty of
      VarT v -> [v]
      UnitT -> []
      OfCourseT inner -> occurrences inner
      TensorT l r -> both l r
      LolliT l r -> both l r
      PlusT l r -> both l r
      WithT l r -> both l r
      AppT l r -> both l r
    both l r = occurrences l ++ occurrences r
-- | Apply a substitution in a single simultaneous pass. Images are not
-- substituted again, and variables outside the domain are kept.
substitute :: Subst -> Type -> Type
substitute (Subst m) = go
  where
    go ty = case ty of
      VarT v -> Map.findWithDefault ty v m
      UnitT -> UnitT
      OfCourseT a -> OfCourseT (go a)
      TensorT a b -> TensorT (go a) (go b)
      LolliT a b -> LolliT (go a) (go b)
      PlusT a b -> PlusT (go a) (go b)
      WithT a b -> WithT (go a) (go b)
      AppT a b -> AppT (go a) (go b)
-- | Bind type variable @n@ to @t@.
--
-- Binding a variable to itself is trivially satisfied and yields the empty
-- substitution. Previously this case fell through to the occurs check and
-- reported a spurious 'OccursCheckError' for @$n ~ $n@. Any other occurrence
-- of @n@ inside @t@ would create an infinite type and is rejected.
subst1 :: String -> Type -> Either CheckError Subst
subst1 n t
  | t == VarT n = pure mempty
  | n `elem` freeTyVars t = throwError (OccursCheckError n t)
  | otherwise = pure (Subst (Map.singleton n t))
-- | Compute a unifier of two types (apply it with 'substitute').
--
-- For binary constructors the right-hand components are unified only after
-- the left-hand unifier has been applied to them, and the two unifiers are
-- composed "first, then second" with '<>'. The previous version unified both
-- halves independently, so shared variables were handled unsoundly: e.g.
-- @x ⊗ x ~ 1 ⊗ !1@ produced @{x ↦ 1}@ instead of a 'TypeError'.
unify :: Type -> Type -> Either CheckError Subst
unify = curry \case
  (VarT n, b) -> subst1 n b
  (a, VarT n) -> subst1 n a
  (TensorT a b, TensorT a' b') -> unifyBoth a a' b b'
  (LolliT a b, LolliT a' b') -> unifyBoth a a' b b'
  (PlusT a b, PlusT a' b') -> unifyBoth a a' b b'
  (WithT a b, WithT a' b') -> unifyBoth a a' b b'
  (AppT a b, AppT a' b') -> unifyBoth a a' b b'
  (OfCourseT a, OfCourseT b) -> unify a b
  (UnitT, UnitT) -> pure mempty
  (a, b) -> throwError (TypeError a b)
  where
    -- Unify the first pair, push the result into the second pair, unify
    -- that, and compose.
    unifyBoth a a' b b' = do
      s1 <- unify a a'
      s2 <- unify (substitute s1 b) (substitute s1 b')
      pure (s1 <> s2)
-- | Combine two substitutions, reconciling variables bound by both.
--
-- For every variable in both domains the two images are unified (in key
-- order, failing with the first error). The result is @s1 <> s2@ followed by
-- the composition of those unifiers.
-- NOTE(review): the per-variable unifiers are computed independently of each
-- other and of @s1@/@s2@, so the result may not be a most general unifier
-- when the bindings share variables — verify before relying on it.
mergeSubsts :: Subst -> Subst -> Either CheckError Subst
mergeSubsts s1@(Subst m1) s2@(Subst m2) = do
  unifiers <- sequenceA (Map.intersectionWith unify m1 m2)
  pure ((s1 <> s2) <> mconcat (Map.elems unifiers))
-- | Solve type-equality constraints by sequential unification.
--
-- Constraints are processed left to right. Each one is unified only after the
-- substitution accumulated so far has been applied to both sides, so facts
-- learned from earlier constraints reach later ones. The result composes into
-- a substitution that a single 'substitute' pass can apply. A constraint that
-- fails contributes its error and is otherwise skipped.
--
-- Returns the errors (in constraint order) and the accumulated substitution.
-- The previous version unified every constraint in isolation and merged the
-- results. That left chains unresolved: @[$0 ~ 1, $1 ~ $0]@ gave
-- @{$0 ↦ 1, $1 ↦ $0}@, so @$1@ was reported as @$0@ rather than @1@.
solve :: [Constraint] -> ([CheckError], Subst)
solve = go [] mempty
  where
    go errs s [] = (reverse errs, s)
    go errs s ((a :~ b) : rest) =
      case unify (substitute s a) (substitute s b) of
        Left err -> go (err : errs) s rest
        Right s' -> go errs (s <> s') rest
-- | Infer the principal type of a closed term.
--
-- The term is checked, its constraints are solved, and the solution is
-- applied. The remaining type variables are then renamed, in order of first
-- occurrence, to @A@, @B@, …, @Z@, @A1@, …, @Z1@, @A2@, … The previous supply
-- stopped at @Z@, so a 27th variable silently kept its internal @$n@ name.
-- Fails with every solver error if any constraint could not be satisfied.
infer :: Term -> ExceptT [CheckError] IO Type
infer term = do
  (t, cs) <- runCheck term
  let (errs, subst) = solve cs
  -- liftIO (putStrLn (prettyShow t) >> print subst >> putStrLn (prettyShow (substitute subst t)))
  unless (null errs) (throwError errs)
  let t' = substitute subst t
      renaming = Subst (Map.fromList (zip (freeTyVars t') (VarT <$> prettyVars)))
  pure (substitute renaming t')
  where
    -- Infinite supply of display names; internal names start with '$', so
    -- they cannot clash with these.
    prettyVars :: [String]
    prettyVars = [c : suffix | suffix <- "" : map show [1 :: Int ..], c <- ['A' .. 'Z']]
-- | Reassociation of ⊗: @(A ⊗ B) ⊗ C ⊸ A ⊗ (B ⊗ C)@.
tensorAssoc :: Term
tensorAssoc =
  Abs "xy_z" $
    Split "xy" "z" "xy_z" $
      Split "x" "y" "xy" $
        Tensor (Var "x") (Tensor (Var "y") (Var "z"))
-- | Select a component of an additive pair with a boolean encoded as 1 ⊕ 1:
-- the Left branch takes @.fst@, the Right branch takes @.snd@.
boolIndex :: Term
boolIndex =
  Abs "p" $ Abs "b" $
    Case "b"
      "u" (Empty "u" (Choose "x" "p" False (Var "x")))
      "u" (Empty "u" (Choose "x" "p" True (Var "x")))
-- | Copy a @!(A & B)@ and project each copy under promotion, producing
-- @!A ⊗ !B@.
exponentialMap :: Term
exponentialMap = Abs "u" $ Copy "x" "y" "u" $ Tensor left right
  where
    left = Quote (Eval "p" "x" (Choose "a" "p" False (Var "a")))
    right = Quote (Eval "q" "y" (Choose "b" "q" True (Var "b")))
{- | |
A ⊗ (B ⅋ C) ⊸ ((A ⊗ B) ⅋ C) | |
= A ⊗ (~B ⊸ C) ⊸ (A ⊸ ~B) ⊸ C | |
= A ⊗ (B ⊸ C) ⊸ (A ⊸ B) ⊸ C   (renaming ~B to B)
-} | |
-- | Linear distribution: given @A ⊗ (B ⊸ C)@ and @A ⊸ B@, produce @C@.
linearDistribution :: Term
linearDistribution =
  Abs "af" . Abs "g" . Split "a" "f" "af" $
    App (Var "f") (App (Var "g") (Var "a"))
-- | Apply a duplicable function twice: copy @f@, dereliction on each copy,
-- then compose the two uses.
twice :: Term
twice =
  Abs "f" . Abs "x" . Copy "f1" "f2" "f" . Eval "f1e" "f1" . Eval "f2e" "f2" $
    App (Var "f2e") (App (Var "f1e") (Var "x"))
{-} | |
testLet = Abs "x" (Abs "f" | |
(Copy "f1" "f2" "f" (Eval "f1e" "f1" (Eval "f2e" "f2" (Let "y" (App (Var "f1e") (Var "x")) (App (Var "f2e") (Var "y"))))))) | |
-} | |
-- | Infer a term's type and print it. On failure, print each error on its
-- own line instead.
test :: Term -> IO ()
test term = do
  result <- runExceptT (infer term)
  either (mapM_ print) (putStrLn . prettyShow) result
-- | Type-check each example term in turn and print the results.
main :: IO ()
main = forM_ examples test
  -- print (mergeSubsts (Subst (Map.fromList [("$0", LolliT (VarT "$1") (VarT "$2"))])) (Subst (Map.fromList [("$0", LolliT UnitT (OfCourseT (VarT "$3")))])))
  where
    examples :: [Term]
    examples = [tensorAssoc, boolIndex, exponentialMap, linearDistribution, twice]
{- | |
⊗-assoc : ∀ A B C. (A ⊗ B) ⊗ C ⊸ A ⊗ (B ⊗ C) | |
⊗-assoc = λ xy_z ⊸ split (xy, z) = xy_z, | |
split (x, y) = xy, | |
x ⊗ (y ⊗ z) | |
bool-index : ∀ A. A & A ⊸ 1 ⊕ 1 ⊸ A | |
bool-index = λ p ⊸ λ b ⊸ case b of | |
true -> choose x = p.fst, x | |
false -> choose y = p.snd, y | |
exponentialMap : ∀ A B. !(A & B) ⊸ !A ⊗ !B | |
exponentialMap = λ u ⊸ | |
copy (x, y) = u, `eval p = x, choose a = p.fst, a` | |
⊗ `eval q = y, choose b = q.snd, b`
linearDistribution : ∀ A B C. A ⊗ (B ⊸ C) ⊸ (A ⊸ B) ⊸ C | |
linearDistribution = λ af ⊸ λ g ⊸ split (a, f) = af, f (g a) | |
-} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment