Skip to content

Instantly share code, notes, and snippets.

@vollmerm
Last active September 16, 2016 18:48
Show Gist options
  • Save vollmerm/53befef5ab2d0a4be471f2cc33b6cee5 to your computer and use it in GitHub Desktop.
Save vollmerm/53befef5ab2d0a4be471f2cc33b6cee5 to your computer and use it in GitHub Desktop.
Type-level array lengths
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Data.SizedArray where
import GHC.TypeLits
import Unsafe.Coerce
import Data.Proxy
import Data.Vector
import Prelude hiding (concat)
-- | A 'Vector' tagged with its length as a type-level natural @n@.
--
-- The index @n@ appears only in the type; the wrapped vector is stored as-is.
-- The 'SArr' constructor itself performs no length check, so 'arrNew' is the
-- checked way to build a value.
newtype SArr (n :: Nat) (a :: *) = SArr
  { getVector :: Vector a -- ^ the underlying, unindexed vector
  }
-- define accessor functions
-- | Build a sized array from a list and a singleton for the claimed length.
--
-- The list length is compared against the value of @n@ at run time; on a
-- mismatch the result is an 'error' call rather than a value.
arrNew :: [a] -> SNat n -> SArr n a
arrNew ls n
  | expected == actual = SArr (fromList ls)
  | otherwise          = error "Invalid list length"
  where
    expected = snatToInteger n
    -- 'length' is qualified because Data.Vector is imported unqualified.
    actual   = toInteger (Prelude.length ls)
-- | Look up the element at index @m@ of an array of length @n@.
--
-- The constraint @(m + 1) <= n@ has to be discharged by the type checker, so
-- an index that is not smaller than the length index is rejected at compile
-- time. The lookup itself indexes the wrapped vector with '!'.
arrRef :: ((m + 1) <= n) => SArr n a -> SNat m -> a
arrRef (SArr vec) ix = vec ! fromInteger (snatToInteger ix)
-- | Append two sized arrays; the length index of the result is @m + n@.
arrConcat :: SArr m a -> SArr n a -> SArr (m + n) a
arrConcat (SArr xs) (SArr ys) = SArr (concat [xs, ys])
-- test it:
-- | In-bounds lookup: index 1 of a three-element array.
--
-- The signature pins the element type to 'Integer', which is what the
-- numeric literals default to when no signature is given.
test1 :: Integer
test1 =
  let a = arrNew [1, 2, 3] (snat :: SNat 3)
      i = snat :: SNat 1
  in arrRef a i -- 2
-- | Length mismatch: a three-element list is paired with a claimed length
-- of 4.
--
-- This type-checks (index 1 is within the claimed length 4), but evaluating
-- it reaches the length check in 'arrNew' and raises its 'error'.
test2 :: Integer
test2 =
  let a = arrNew [1, 2, 3] (snat :: SNat 4) -- Exception: Invalid list length
      i = snat :: SNat 1
  in arrRef a i
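-- Out-of-bounds index, rejected by the type checker (kept commented out so
-- the module still builds):
-- test3 =
--   let a = arrNew [1,2,3] (snat :: SNat 3)
--       i = snat :: SNat 3
--   in arrRef a i -- fails with "expected 'True, got (3 + 1) <=? 3"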
-- Nat stuff
-- | Singleton for a type-level natural number.
--
-- The constructor captures the 'KnownNat' dictionary, so pattern matching on
-- an 'SNat' brings @KnownNat n@ back into scope.
data SNat (n :: Nat) where
  SNat :: KnownNat n => Proxy n -> SNat n
-- | Renders as the letter @d@ followed by the number, e.g. @d3@.
instance Show (SNat n) where
  showsPrec _ (SNat proxy) = showChar 'd' . shows (natVal proxy)
-- | Create a singleton literal for a type-level natural number.
--
-- The number is chosen by the expected result type, e.g. @snat :: SNat 3@.
snat :: KnownNat n => SNat n
snat = withSNat id
{-# INLINE snat #-}
-- | Supply a function with the singleton natural @n@ determined by the
-- surrounding 'KnownNat' context.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat = ($ SNat Proxy)
{-# INLINE withSNat #-}
-- | Recover the run-time 'Integer' value of a singleton natural.
snatToInteger :: SNat n -> Integer
snatToInteger sn = case sn of
  -- Matching on the constructor exposes the 'KnownNat' evidence 'natVal' needs.
  SNat proxy -> natVal proxy
{-# INLINE snatToInteger #-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment