Skip to content

Instantly share code, notes, and snippets.

@plaidfinch
Last active November 15, 2015 04:14

Revisions

  1. Kenneth Foner revised this gist Nov 15, 2015. 1 changed file with 7 additions and 6 deletions.
    13 changes: 7 additions & 6 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -1,9 +1,10 @@
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE ConstraintKinds #-}
    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE StandaloneDeriving #-}
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE ConstraintKinds #-}
    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE StandaloneDeriving #-}
    {-# LANGUAGE ScopedTypeVariables #-}

    module TeachingGADTs where

  2. Kenneth Foner revised this gist Nov 15, 2015. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -235,6 +235,8 @@ unsafeEraseProof :: forall a b. (a :~: b) -> (a :~: b)
    unsafeEraseProof _proof =
    unsafeCoerce Refl :: a :~: b

    -- append two vectors using a type requiring addition to be commutative,
    -- but skip actually running the proof at runtime
    fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
    fastHardAppend v w =
    case unsafeEraseProof (additionCommutative (vecLength v) (vecLength w)) of
  3. Kenneth Foner revised this gist Nov 15, 2015. 1 changed file with 8 additions and 0 deletions.
    8 changes: 8 additions & 0 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -261,3 +261,11 @@ slowHardAppendTest (SomeVec x) (SomeVec y) =
    fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
    fastHardAppendTest (SomeVec x) (SomeVec y) =
    SomeVec $ fastHardAppend x y

    -- for instance, try out:
    -- > slowHardAppendTest (toVec [0..3000]) (toVec [0..3000])
    -- > fastHardAppendTest (toVec [0..3000]) (toVec [0..3000])

    -- Notice that there is a noticeable pause before slowHardAppendTest begins printing output.
    -- This is the time it takes to force the thunk which evaluates to the proof of addition being
    -- commutative; that is, this entire time is spent evaluating line 185 of this file.
  4. Kenneth Foner revised this gist Nov 15, 2015. 1 changed file with 45 additions and 0 deletions.
    45 changes: 45 additions & 0 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -10,6 +10,8 @@ module TeachingGADTs where
    import Control.Applicative
    import Control.Monad

    import Unsafe.Coerce

    -- Motivation: how would you write a "typed" expression language and its evaluator?
    -- We use GADT *syntax* here but we don't use GADT *semantic* features

    @@ -129,6 +131,8 @@ data Vec (n :: Nat) (a :: *) where
    Nil :: Vec Z a
    Cons :: a -> Vec n a -> Vec (S n) a

    deriving instance Show a => Show (Vec n a)

    -- An excursion: what we would do with DataKinds, but WITHOUT TypeFamilies & GADTs
    -- Here we use phantom types and MODULE ABSTRACTION to manually verify & enforce invariants

    @@ -216,3 +220,44 @@ zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
    zipMin Nil _ = Nil
    zipMin _ Nil = Nil
    zipMin (Cons x xs) (Cons y ys) = Cons (x, y) (zipMin xs ys)

    -- Addendum: making things go fast again:

    -- And here's how we can make things go fast, unsafely
    -- This is more or less what a dependently typed language can do in some circumstances,
    -- because if it is total, it knows that it doesn't actually need to run proofs to
    -- make sure they're not bottom.

    -- If there's a runtime-costly proof which you are ABSOLUTELY CERTAIN will never be equal
    -- to bottom (i.e. is the result of a DEFINITELY TOTAL function), you can wrap it in this
    -- function to avoid ever forcing it and doing the extra work to run the proof.
    unsafeEraseProof :: forall a b. (a :~: b) -> (a :~: b)
    unsafeEraseProof _proof =
    unsafeCoerce Refl :: a :~: b

    fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
    fastHardAppend v w =
    case unsafeEraseProof (additionCommutative (vecLength v) (vecLength w)) of
    Refl -> easyAppend v w

    -- existentially quantify over the length of a vector
    data SomeVec a where
    SomeVec :: Vec n a -> SomeVec a
    deriving instance Show a => Show (SomeVec a)

    -- convert a list into an existential-length-ed vector
    toVec :: [a] -> SomeVec a
    toVec [] = SomeVec Nil
    toVec (x : xs) =
    case toVec xs of
    SomeVec xs' -> SomeVec (Cons x xs')

    -- this will run sloooowwwwwly -- O(n^2)
    slowHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
    slowHardAppendTest (SomeVec x) (SomeVec y) =
    SomeVec $ hardAppend x y

    -- this will run quick -- O(n)
    fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
    fastHardAppendTest (SomeVec x) (SomeVec y) =
    SomeVec $ fastHardAppend x y
  5. Kenneth Foner revised this gist Nov 15, 2015. 1 changed file with 29 additions and 28 deletions.
    57 changes: 29 additions & 28 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -7,42 +7,43 @@

    module TeachingGADTs where

    import Prelude hiding ( Maybe(..) )
    import Control.Applicative
    import Control.Monad

    -- Motivation: how would you write a "typed" expression language and its evaluator?
    -- We use GADT *syntax* here but we don't use GADT *semantic* features

    -- introducing GADT syntax: here's how you'd do Maybe
    data Maybe a where
    Just :: a -> Maybe a
    Nothing :: Maybe a
    data Maybe' a where
    Just' :: a -> Maybe' a
    Nothing' :: Maybe' a

    data Expr where
    Plus :: Expr -> Expr -> Expr
    Less :: Expr -> Expr -> Expr
    Int :: Int -> Expr
    Bool :: Bool -> Expr
    IfTe :: Expr -> Expr -> Expr -> Expr
    data Exp where
    PlusExp :: Exp -> Exp -> Exp
    LessExp :: Exp -> Exp -> Exp
    IntExp :: Int -> Exp
    BoolExp :: Bool -> Exp
    IfTeExp :: Exp -> Exp -> Exp -> Exp

    -- type-checking happens at run-time during evaluation, and we can easily get it wrong
    eval :: Expr -> Maybe (Either Int Bool) -- we don't get compile-time type-checking for our DSL
    eval (Plus e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Left $ i1 + i2
    eval (Less e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Right $ i1 < i2
    eval (Int i) = Just $ Left i
    eval (Bool b) = Just $ Right b
    eval (IfTe e1 e2 e3) = do Right b <- eval e1
    v2 <- eval e2
    v3 <- eval e3
    guard (bothLeft v2 v3 || bothRight v2 v3) -- this line would be easy to forget!
    pure $ if b then v2 else v3
    where bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    bothRight (Right _) (Right _) = True
    bothRight _ _ = False
    eval :: Exp -> Maybe (Either Int Bool) -- we don't get compile-time type-checking for our DSL
    eval (PlusExp e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Left $ i1 + i2
    eval (LessExp e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Right $ i1 < i2
    eval (IntExp i) = Just $ Left i
    eval (BoolExp b) = Just $ Right b
    eval (IfTeExp e1 e2 e3) = do Right b <- eval e1
    v2 <- eval e2
    v3 <- eval e3
    guard (bothLeft v2 v3 || bothRight v2 v3) -- this line would be easy to forget!
    pure $ if b then v2 else v3
    where bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    bothRight (Right _) (Right _) = True
    bothRight _ _ = False

    -- A small expression language
    data Expr a where
  6. Kenneth Foner revised this gist Nov 8, 2015. 1 changed file with 4 additions and 2 deletions.
    6 changes: 4 additions & 2 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -39,8 +39,10 @@ eval (IfTe e1 e2 e3) = do Right b <- eval e1
    v3 <- eval e3
    guard (bothLeft v2 v3 || bothRight v2 v3) -- this line would be easy to forget!
    pure $ if b then v2 else v3
    where bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    where bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    bothRight (Right _) (Right _) = True
    bothRight _ _ = False

    -- A small expression language
    data Expr a where
  7. Kenneth Foner created this gist Nov 8, 2015.
    215 changes: 215 additions & 0 deletions TeachingGADTs.hs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,215 @@
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE ConstraintKinds #-}
    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE StandaloneDeriving #-}

    module TeachingGADTs where

    import Prelude hiding ( Maybe(..) )

    -- Motivation: how would you write a "typed" expression language and its evaluator?
    -- We use GADT *syntax* here but we don't use GADT *semantic* features

    -- introducing GADT syntax: here's how you'd do Maybe
    data Maybe a where
    Just :: a -> Maybe a
    Nothing :: Maybe a

    data Expr where
    Plus :: Expr -> Expr -> Expr
    Less :: Expr -> Expr -> Expr
    Int :: Int -> Expr
    Bool :: Bool -> Expr
    IfTe :: Expr -> Expr -> Expr -> Expr

    -- type-checking happens at run-time during evaluation, and we can easily get it wrong
    eval :: Expr -> Maybe (Either Int Bool) -- we don't get compile-time type-checking for our DSL
    eval (Plus e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Left $ i1 + i2
    eval (Less e1 e2) = do Left i1 <- eval e1
    Left i2 <- eval e2
    pure . Right $ i1 < i2
    eval (Int i) = Just $ Left i
    eval (Bool b) = Just $ Right b
    eval (IfTe e1 e2 e3) = do Right b <- eval e1
    v2 <- eval e2
    v3 <- eval e3
    guard (bothLeft v2 v3 || bothRight v2 v3) -- this line would be easy to forget!
    pure $ if b then v2 else v3
    where bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False

    -- A small expression language
    data Expr a where
    Plus :: (Num a) => Expr a -> Expr a -> Expr a
    Less :: (Ord a) => Expr a -> Expr a -> Expr Bool
    Lift :: a -> Expr a
    IfTe :: Expr Bool -> Expr a -> Expr a -> Expr a

    -- ... is equivalent to ...

    data Expr' a = forall x. (Num x, a ~ x) => Plus' (Expr' x) (Expr' x)
    | forall x. (Ord x, a ~ Bool) => Less' (Expr' x) (Expr' x)
    | forall x. (x ~ a) => Lift' x
    | forall x b. (x ~ a, b ~ Bool) => IfTe' (Expr' b) (Expr' x) (Expr' x)

    evaluate :: Expr a -> a -- wow! types!
    evaluate expr = case expr of
    Plus e1 e2 ->
    let v1 = evaluate e1
    v2 = evaluate e2
    in v1 + v2
    Less e1 e2 ->
    let v1 = evaluate e1
    v2 = evaluate e2
    in v1 < v2
    Lift e -> e
    IfTe e1 e2 e3 ->
    let v1 = evaluate e1
    v2 = evaluate e2
    v3 = evaluate e3
    in if v1 then v2 else v3

    -- witness of equality proof
    data x :~: y where
    Refl :: x :~: x

    -- discharge an equality proof
    -- notice what happens when you pass undefined
    coerce :: x :~: y -> x -> y
    coerce Refl x = x

    -- witness of a constraint
    data Dict c a where
    Dict :: c a => Dict c a

    -- discharge a constraint proof
    withDict :: Dict c a -> (c a => b) -> b
    withDict Dict x = x

    -- specialized version of Dict
    -- example to "Show" what's going on
    data IsShow a where
    IsShow :: (Show a) => IsShow a

    -- a weird GADT
    data IsInt (x :: Bool) where
    Yep :: (a ~ Int) => a -> IsInt 'True
    Perhaps :: a -> IsInt 'False

    -- this is total!
    getMeAnIntPlease :: IsInt True -> Int
    getMeAnIntPlease (Yep x) = x

    -- ordinary value-level Nat
    data Nat = S Nat | Z

    -- singleton value-level Nat which witnesses a type-level Nat
    -- (we get the type-level Nat via DataKinds)
    data SNat (n :: Nat) where
    SZ :: SNat Z
    SS :: SNat n -> SNat (S n)

    -- addition of SNats which preserves type index properly
    plus :: SNat n -> SNat m -> SNat (n + m)
    plus SZ n = n
    plus (SS n) m = SS (plus n m)

    -- derive instances for GADTs this way
    deriving instance Show (SNat n)

    -- length-indexed list
    data Vec (n :: Nat) (a :: *) where
    Nil :: Vec Z a
    Cons :: a -> Vec n a -> Vec (S n) a

    -- An excursion: what we would do with DataKinds, but WITHOUT TypeFamilies & GADTs
    -- Here we use phantom types and MODULE ABSTRACTION to manually verify & enforce invariants

    data FakeVec (n :: Nat) (a :: *) = FakeVec [a]
    deriving (Read) -- but OH NO we accidentally everything forever and we are now sad

    fakeNil :: FakeVec Z a
    fakeNil = FakeVec []

    fakeCons :: a -> FakeVec n a -> FakeVec (S n) a
    fakeCons a (FakeVec as) = FakeVec (a:as)

    fakeVecSplit :: FakeVec (S n) a -> (a, FakeVec n a)
    fakeVecSplit (FakeVec (a:as)) = (a, FakeVec as)
    fakeVecSplit (FakeVec []) = error "invariant violation: FakeVec is empty!"

    -- We can violate the invariant by reading a string at the "wrong" phantom type!

    -- </excursion>

    -- an example Vec
    vec1 :: Vec (S (S (S Z))) Char
    vec1 = Cons 'A' (Cons 'B' (Cons 'C' Nil))

    -- this is exhaustive (total)!
    zipSame :: Vec n a -> Vec n b -> Vec n (a, b)
    zipSame Nil Nil = Nil
    zipSame (Cons x xs) (Cons y ys) = Cons (x,y) (zipSame xs ys)

    -- length of a vector as a singleton Nat (SNat)
    vecLength :: Vec n a -> SNat n
    vecLength Nil = SZ
    vecLength (Cons x xs) = SS $ vecLength xs

    -- type level addition function (requires TypeFamilies & TypeOperators)
    type family a + b where
    Z + n = n
    S n + m = S (n + m)

    -- GHC will verify this automatically, because it's the exact same recursion pattern as (+)
    easyAppend :: Vec m a -> Vec n a -> Vec (m + n) a
    easyAppend Nil ys = ys
    easyAppend (Cons x xs) ys = Cons x (easyAppend xs ys)

    -- But...
    -- the type level addition '+' is not *automatically* provable
    -- to be commutative, and hence fails to typecheck... unless we PROVE IT using a lemma.
    hardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
    hardAppend v w =
    case additionCommutative (vecLength v) (vecLength w) of
    Refl -> easyAppend v w

    -- sub-lemma: zero is a right neutral for (+)
    rightNeutral :: SNat n -> n :~: (n + Z)
    rightNeutral SZ = Refl
    rightNeutral (SS n) =
    case rightNeutral n of
    Refl -> Refl

    -- sub-lemma: n + S m = S (n + m)
    plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
    plusSucc SZ _ = Refl
    plusSucc (SS n) m =
    case plusSucc n m of
    Refl -> Refl

    -- we can use these to prove for any given SNat n, m that addition commutes
    -- note: this is O(n^2) and MUST BE EXECUTED AT RUN-TIME
    additionCommutative :: SNat n -> SNat m -> (n + m) :~: (m + n)
    additionCommutative SZ n =
    case rightNeutral n of Refl -> Refl
    additionCommutative (SS m) n =
    case additionCommutative m n of
    Refl -> case plusSucc n m of
    Refl -> Refl

    -- we can use a type-level minimum function to type-check a truncating zip
    type family Min (m :: Nat) (n :: Nat) where
    Min Z y = Z
    Min x Z = Z
    Min (S x) (S y) = S (Min x y)

    -- truncating zip (ala Haskell's ordinary zip)
    zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
    zipMin Nil _ = Nil
    zipMin _ Nil = Nil
    zipMin (Cons x xs) (Cons y ys) = Cons (x, y) (zipMin xs ys)