-- Source: GitHub gist plaidfinch/1b4e227e476353e775fe, by @plaidfinch
-- (last active November 15, 2015).
-- The GitHub page navigation text that preceded the code in this copy was
-- scraping residue and is not part of the module.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module TeachingGADTs where

import Control.Applicative
import Control.Monad
import Unsafe.Coerce
-- Motivation: how would you write a "typed" expression language and its evaluator?
-- We use GADT *syntax* here but we don't use GADT *semantic* features
-- | 'Maybe' re-declared in GADT *syntax*: each constructor is listed with its
-- complete type signature. No GADT-only feature is used; both constructors
-- return the plain @Maybe' a@.
data Maybe' a where
  -- | Wraps a value (the analogue of 'Just').
  Just' :: a -> Maybe' a
  -- | The empty case (the analogue of 'Nothing').
  Nothing' :: Maybe' a
-- | Untyped expression syntax tree. Every constructor produces the same
-- 'Exp' type, so the type says nothing about whether a term denotes an
-- 'Int' or a 'Bool'.
data Exp where
  -- | Addition of two sub-expressions.
  PlusExp :: Exp -> Exp -> Exp
  -- | Less-than comparison of two sub-expressions.
  LessExp :: Exp -> Exp -> Exp
  -- | Integer literal.
  IntExp :: Int -> Exp
  -- | Boolean literal.
  BoolExp :: Bool -> Exp
  -- | Conditional: condition, then-branch, else-branch.
  IfTeExp :: Exp -> Exp -> Exp -> Exp
-- | Evaluate an untyped 'Exp'.
--
-- The DSL gets no compile-time type checking: every check happens here, at
-- run time, and is easy to get wrong. The result is 'Nothing' when an operand
-- has the wrong sort of value, @Just (Left i)@ for an integer result and
-- @Just (Right b)@ for a boolean result.
eval :: Exp -> Maybe (Either Int Bool)
eval expr = case expr of
  IntExp n -> Just (Left n)
  BoolExp b -> Just (Right b)
  PlusExp l r -> onInts (\x y -> Left (x + y)) l r
  LessExp l r -> onInts (\x y -> Right (x < y)) l r
  IfTeExp c t e -> do
    -- A failed pattern match in the 'Maybe' monad yields 'Nothing'.
    Right cond <- eval c
    onTrue <- eval t
    onFalse <- eval e
    -- Both branches must have the same sort of value.
    -- This check would be easy to forget!
    guard (sameSort onTrue onFalse)
    pure (if cond then onTrue else onFalse)
  where
    -- Evaluate both operands, require both to be integers, then combine them.
    onInts :: (Int -> Int -> Either Int Bool) -> Exp -> Exp -> Maybe (Either Int Bool)
    onInts combine l r = do
      Left x <- eval l
      Left y <- eval r
      pure (combine x y)

    -- True when both values are integers or both are booleans.
    sameSort (Left _) (Left _) = True
    sameSort (Right _) (Right _) = True
    sameSort _ _ = False
-- | A small typed expression language. The index @a@ is the type of the value
-- the expression denotes; constructors may fix it (as 'Less' does) or
-- constrain it (as 'Plus' does), which is what makes this a real GADT.
data Expr a where
  -- | Addition at any numeric type.
  Plus :: Num a => Expr a -> Expr a -> Expr a
  -- | Comparison of two expressions of the same ordered type; always a 'Bool'.
  Less :: Ord a => Expr a -> Expr a -> Expr Bool
  -- | Embed an ordinary Haskell value.
  Lift :: a -> Expr a
  -- | Conditional whose condition must be a boolean expression and whose
  -- branches must agree in type.
  IfTe :: Expr Bool -> Expr a -> Expr a -> Expr a
-- | The same language as 'Expr', spelled without GADT syntax: each constructor
-- uses existential quantification plus explicit type-equality constraints
-- (@~@) to pin down the index @a@.
data Expr' a
  = forall x. (Num x, a ~ x) => Plus' (Expr' x) (Expr' x)
  | forall x. (Ord x, a ~ Bool) => Less' (Expr' x) (Expr' x)
  | forall x. (x ~ a) => Lift' x
  | forall x b. (x ~ a, b ~ Bool) => IfTe' (Expr' b) (Expr' x) (Expr' x)
-- | Evaluate a typed expression to a Haskell value of the type named by its
-- index. Unlike 'eval', no run-time type checks and no 'Maybe' are needed:
-- matching on a constructor brings its constraints ('Num', 'Ord') and type
-- equalities into scope.
evaluate :: Expr a -> a
evaluate (Lift value) = value
evaluate (Plus lhs rhs) = evaluate lhs + evaluate rhs
evaluate (Less lhs rhs) = evaluate lhs < evaluate rhs
evaluate (IfTe cond onTrue onFalse) =
  if evaluate cond then evaluate onTrue else evaluate onFalse
-- | Witness of a type-equality proof: a value of type @x :~: y@ can only be
-- built with 'Refl', whose type forces @x@ and @y@ to be the same type.
data x :~: y where
  -- | The only proof: every type equals itself.
  Refl :: x :~: x
-- | Discharge an equality proof: convert a value of type @x@ to type @y@.
--
-- The proof is pattern-matched before the value is returned, so notice what
-- happens when you pass 'undefined' as the proof: the result is bottom rather
-- than an unsound cast.
coerce :: x :~: y -> x -> y
coerce proof value = case proof of
  Refl -> value
-- | Witness of a constraint: a @Dict c a@ value can only be constructed where
-- the constraint @c a@ holds, and matching on it makes @c a@ available again.
data Dict c a where
  -- | Captures the @c a@ instance in scope at the construction site.
  Dict :: c a => Dict c a
-- | Discharge a constraint proof: run a computation that needs @c a@, using
-- the instance stored in the given 'Dict'.
--
-- The second argument has a constrained (rank-2) type, which requires the
-- RankNTypes extension.
withDict :: Dict c a -> (c a => b) -> b
withDict evidence body = case evidence of
  -- Matching on 'Dict' brings @c a@ into scope for @body@.
  Dict -> body
-- | 'Dict' specialised to the 'Show' class: an example to \"Show\" what is
-- going on. An @IsShow a@ value can only be built when @Show a@ holds.
data IsShow a where
  -- | Captures the @Show a@ instance.
  IsShow :: Show a => IsShow a
-- | A weird GADT, indexed by a promoted 'Bool' that records whether the
-- payload is known to be an 'Int'.
data IsInt (x :: Bool) where
  -- | The payload's type is constrained to equal 'Int'.
  Yep :: (a ~ Int) => a -> IsInt 'True
  -- | The payload can be of any type, which is hidden from the consumer.
  Perhaps :: a -> IsInt 'False
-- | Extract the 'Int' stored in an @IsInt 'True@.
--
-- This is total: 'Yep' is the only constructor whose result index is 'True,
-- so no 'Perhaps' case is needed, and matching on 'Yep' proves the payload
-- is an 'Int'.
--
-- The promoted constructor is written with its tick ('True), as in the
-- declaration of 'IsInt'.
getMeAnIntPlease :: IsInt 'True -> Int
getMeAnIntPlease (Yep n) = n
-- | Ordinary value-level Peano naturals. With DataKinds the constructors are
-- also promoted to the type level, where they serve as indices.
data Nat
  = S Nat -- ^ successor of a natural
  | Z     -- ^ zero
-- | Singleton naturals: a value of type @SNat n@ is a run-time witness of the
-- type-level 'Nat' @n@ (the type-level 'Nat' comes from DataKinds). Each
-- index has exactly one value, built from the matching constructors.
data SNat (n :: Nat) where
  -- | Witness of zero.
  SZ :: SNat Z
  -- | Witness of a successor, built from the witness of its predecessor.
  SS :: SNat n -> SNat (S n)
-- | Add two singleton naturals. The result's index is the type-level sum
-- @n + m@, so the type tracks the arithmetic. The recursion is on the first
-- argument, mirroring the equations of the @+@ type family.
plus :: SNat n -> SNat m -> SNat (n + m)
plus left right = case left of
  SZ -> right
  SS predecessor -> SS (plus predecessor right)
-- 'Show' instance for 'SNat', obtained with a standalone deriving declaration
-- (the StandaloneDeriving extension). This is the form this file uses to
-- derive instances for its GADTs.
deriving instance Show (SNat n)
-- | Length-indexed list: the type-level 'Nat' @n@ is the number of elements
-- of type @a@.
data Vec (n :: Nat) (a :: *) where
  -- | The empty vector, of length zero.
  Nil :: Vec Z a
  -- | Prepend one element, increasing the length index by one.
  Cons :: a -> Vec n a -> Vec (S n) a

-- | 'Show' for vectors of showable elements, via standalone deriving.
deriving instance Show a => Show (Vec n a)
-- An excursion: what we would do with DataKinds but WITHOUT TypeFamilies and
-- GADTs. The length index @n@ is only a phantom type, so the invariant has to
-- be verified by hand and enforced through MODULE ABSTRACTION.
--
-- Deriving 'Read' is the mistake that breaks the abstraction: the derived
-- instance exists for every index @n@, so 'read' can build a 'FakeVec' whose
-- list length disagrees with its phantom length.
data FakeVec (n :: Nat) (a :: *)
  = FakeVec [a]
  deriving (Read)
-- | The empty 'FakeVec'. Its phantom length index is fixed to 'Z' by the
-- type signature alone; nothing checks it against the wrapped list.
fakeNil :: FakeVec Z a
fakeNil = FakeVec []
-- | Prepend an element to a 'FakeVec'. The signature bumps the phantom length
-- from @n@ to @S n@ to match the one-element-longer list.
fakeCons :: a -> FakeVec n a -> FakeVec (S n) a
fakeCons element (FakeVec elements) = FakeVec (element : elements)
-- | Split a 'FakeVec' with a non-zero phantom length into its head and tail.
--
-- The phantom index is not checked against the wrapped list, so an empty
-- list is still possible at run time; that case calls 'error'.
fakeVecSplit :: FakeVec (S n) a -> (a, FakeVec n a)
fakeVecSplit (FakeVec elements) = case elements of
  first : rest -> (first, FakeVec rest)
  [] -> error "invariant violation: FakeVec is empty!"
-- We can violate the invariant by reading a string at the "wrong" phantom
-- type, using the derived 'Read' instance of 'FakeVec'.
-- </excursion>
-- | An example 'Vec': three characters, with the length 3 recorded in the
-- type as @S (S (S Z))@.
vec1 :: Vec (S (S (S Z))) Char
vec1 = 'A' `Cons` ('B' `Cons` ('C' `Cons` Nil))
-- | Zip two vectors of the same length into a vector of pairs.
--
-- The two equations are exhaustive (total): both arguments share the index
-- @n@, so one cannot be 'Nil' while the other is a 'Cons'.
zipSame :: Vec n a -> Vec n b -> Vec n (a, b)
zipSame (Cons a restA) (Cons b restB) = Cons (a, b) (zipSame restA restB)
zipSame Nil Nil = Nil
-- | Length of a vector as a singleton natural ('SNat') carrying the same
-- index @n@ as the vector. The elements themselves are ignored.
vecLength :: Vec n a -> SNat n
vecLength Nil = SZ
vecLength (Cons _ rest) = SS (vecLength rest)
-- | Type-level addition of 'Nat's (requires TypeFamilies and TypeOperators),
-- defined as a closed type family by recursion on the left operand.
type family (a :: Nat) + (b :: Nat) :: Nat where
  Z + m = m
  S n + m = S (n + m)
-- | Append two vectors; the result length is @m + n@.
--
-- GHC verifies the length arithmetic automatically, because this function
-- follows the exact same recursion pattern as the @+@ type family (recursion
-- on the first argument).
easyAppend :: Vec m a -> Vec n a -> Vec (m + n) a
easyAppend front back = case front of
  Nil -> back
  Cons x rest -> Cons x (easyAppend rest back)
-- | Append two vectors, with the result length written as @n + m@ instead of
-- @m + n@.
--
-- Type-level @+@ is not *automatically* known to be commutative, so this
-- signature fails to type-check unless we PROVE the equality with a lemma.
-- The proof is built from the two lengths and matched on; only inside the
-- 'Refl' branch does GHC accept the result of 'easyAppend' at this type.
hardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
hardAppend front back = case commutes of
  Refl -> easyAppend front back
  where
    commutes = additionCommutative (vecLength front) (vecLength back)
-- | Sub-lemma: zero is a right-neutral element of @+@, i.e. @n = n + Z@.
-- Proved by induction on the singleton for @n@.
rightNeutral :: SNat n -> n :~: (n + Z)
rightNeutral n = case n of
  SZ -> Refl
  -- Inductive step: the proof for the predecessor lets GHC accept 'Refl'.
  SS predecessor -> case rightNeutral predecessor of
    Refl -> Refl
-- | Sub-lemma: a successor on the right operand can be pulled out of a sum,
-- i.e. @n + S m = S (n + m)@. Proved by induction on the first singleton;
-- the second is only passed along to the recursive call.
plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
plusSucc n m = case n of
  SZ -> Refl
  SS predecessor -> case plusSucc predecessor m of
    Refl -> Refl
-- | Proof that type-level addition commutes for any two given singletons:
-- @n + m = m + n@. Built from 'rightNeutral' and 'plusSucc' by induction on
-- the first argument.
--
-- Note: the proof is an ordinary value that MUST BE EXECUTED AT RUN-TIME.
-- Each recursive step on the first argument also runs 'plusSucc', which
-- recurses over the second argument, so the cost grows with the product of
-- the two numbers (quadratic when they are of similar size).
additionCommutative :: SNat n -> SNat m -> (n + m) :~: (m + n)
additionCommutative n m = case n of
  -- Base case: Z + m = m, so m = m + Z is needed.
  SZ -> case rightNeutral m of
    Refl -> Refl
  -- Inductive step: combine the proof for the predecessor with 'plusSucc'.
  SS predecessor -> case additionCommutative predecessor m of
    Refl -> case plusSucc m predecessor of
      Refl -> Refl
-- | Type-level minimum of two 'Nat's, used to type-check a truncating zip.
-- Zero on either side gives zero; two successors recurse on the predecessors.
type family Min (m :: Nat) (n :: Nat) :: Nat where
  Min Z n = Z
  Min m Z = Z
  Min (S m) (S n) = S (Min m n)
-- | Truncating zip (à la Haskell's ordinary 'zip'): the result is as long as
-- the shorter input, which the type states as @Min m n@.
zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
zipMin left right = case left of
  Nil -> Nil
  Cons a restA -> case right of
    Nil -> Nil
    Cons b restB -> Cons (a, b) (zipMin restA restB)
-- Addendum: making things go fast again, unsafely.
--
-- This is more or less what a dependently typed language can do in some
-- circumstances: if a proof comes from a total function, the language knows
-- it is not bottom and does not need to run it.

-- | Replace an equality proof with a fabricated 'Refl' without evaluating it.
--
-- The argument is ignored, so a lazily passed proof is never forced and the
-- work of computing it is skipped.
--
-- SAFETY: this is sound only if you are ABSOLUTELY CERTAIN the erased proof
-- can never be bottom, i.e. it is the result of a DEFINITELY TOTAL function.
-- In that case the proof could only have evaluated to 'Refl', which is what
-- is returned. Erasing a bottom proof would turn a diverging "proof" into an
-- accepted one.
--
-- The signature needs no explicit @forall@ and the body needs no type
-- annotation referring to the signature's variables, so the definition does
-- not depend on ExplicitForAll or ScopedTypeVariables.
unsafeEraseProof :: (a :~: b) -> (a :~: b)
unsafeEraseProof _proof =
  -- 'Refl' at a fixed, concrete type is coerced to the requested equality.
  unsafeCoerce (Refl :: () :~: ())
-- | Same result type as 'hardAppend', but the commutativity proof is passed
-- through 'unsafeEraseProof', which ignores it. Thanks to laziness the proof
-- (and the two vector lengths it is built from) is never computed.
fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
fastHardAppend front back = case unsafeEraseProof commutes of
  Refl -> easyAppend front back
  where
    commutes = additionCommutative (vecLength front) (vecLength back)
-- | A vector whose length is existentially quantified: the index @n@ does
-- not appear in the type @SomeVec a@, so vectors of any length fit.
data SomeVec a where
  -- | Hide the length index of a 'Vec'.
  SomeVec :: Vec n a -> SomeVec a

-- | 'Show' for 'SomeVec' of showable elements, via standalone deriving.
deriving instance Show a => Show (SomeVec a)
-- | Convert a list into a vector of existentially quantified length. The
-- length of a list is not known to the type checker, so the result is
-- wrapped in 'SomeVec'.
toVec :: [a] -> SomeVec a
toVec = foldr prepend (SomeVec Nil)
  where
    -- Unwrap the vector built so far, add one element, and re-wrap it.
    prepend :: a -> SomeVec a -> SomeVec a
    prepend x (SomeVec rest) = SomeVec (Cons x rest)
-- | Append two vectors of hidden length with 'hardAppend'.
--
-- This will run slowly: 'hardAppend' evaluates the run-time commutativity
-- proof, whose cost grows with the product of the two lengths (O(n^2) for
-- lengths of similar size), on top of the append itself.
slowHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
slowHardAppendTest left right = case left of
  SomeVec xs -> case right of
    SomeVec ys -> SomeVec (hardAppend xs ys)
-- | Append two vectors of hidden length with 'fastHardAppend'.
--
-- This will run quickly: the commutativity proof is erased rather than
-- evaluated, leaving only the append, which is linear in the length of the
-- first vector.
fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
fastHardAppendTest left right = case left of
  SomeVec xs -> case right of
    SomeVec ys -> SomeVec (fastHardAppend xs ys)
-- End of gist. (The GitHub page footer that followed the code in this copy
-- was scraping residue and is not part of the module.)