Skip to content

Instantly share code, notes, and snippets.

@ahaym
Created December 23, 2021 19:07
Show Gist options
  • Save ahaym/f8e07ac833f3a11cf417c7d5fde7e66a to your computer and use it in GitHub Desktop.
Save ahaym/f8e07ac833f3a11cf417c7d5fde7e66a to your computer and use it in GitHub Desktop.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RebindableSyntax #-}
module Circom where
import Prelude

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict
import Data.List (intercalate)
import Data.Map.Strict (Map, (!))
import qualified Data.Map.Strict as HashMap
-- | Alias that lets the demo spell the map type as 'HashMap'; it is the
-- ordered 'Map' from "Data.Map.Strict", which is also imported qualified
-- under the name @HashMap@.
type HashMap = Map
-- | Expression tree used on both sides of emitted statements.
-- 'showExpr' defines how each constructor is rendered.
data Expr
  = Const Integer    -- ^ Integer literal.
  | Var String       -- ^ Bare identifier, rendered verbatim.
  | Add Expr Expr    -- ^ Binary @+@.
  | Mul Expr Expr    -- ^ Binary @*@.
  | Minus Expr Expr  -- ^ Binary @-@.
  | Div Expr Expr    -- ^ Binary @/@.
  | Mod Expr Expr    -- ^ Binary @%@.
  | Ix Expr Integer  -- ^ Subscript: @expr[n]@.
  | Dot Expr String  -- ^ Member access: @expr.name@.
  deriving (Show)
-- | Lets arithmetic operators and integer literals build 'Expr' trees
-- directly instead of computing numbers.
instance Num Expr where
  lhs + rhs = Add lhs rhs
  lhs - rhs = Minus lhs rhs
  lhs * rhs = Mul lhs rhs
  -- Negation is encoded as subtraction from the literal zero.
  negate e = Minus (Const 0) e
  -- There is no absolute-value node; the operand is passed through as-is.
  abs e = e
  signum = error "signum unimplemented"
  fromInteger n = Const n
-- | Visibility tag carried by an 'Input' declaration; see 'showPrivacy' for
-- how each value is rendered.
data Privacy
  = Public
  | Private
  deriving (Show)
-- | One emitted statement of the generated template body.
-- 'emitCircom1' turns each constructor into a line of text.
data CircomLang
  = Assign Expr Expr                 -- ^ @lhs <-- rhs;@
  | Assert Expr Expr                 -- ^ @lhs === rhs;@
  | Input Privacy String             -- ^ Input signal declaration.
  | Output String                    -- ^ Output signal declaration.
  | Signal String                    -- ^ Plain signal declaration.
  | Component String                 -- ^ Component declaration without an initialiser.
  | Extern String String [Integer]   -- ^ @component name = func ( args ) ;@ as
                                     --   @Extern func name args@.
  deriving (Show)
-- | Mutable state threaded through 'CircomM'.
data CircomState = CircomState
  { -- | Statements emitted so far, kept as a function that prepends them to
    -- its argument so that appending a statement is a composition.
    circomOutput :: [CircomLang] -> [CircomLang]
    -- | Counter used by 'getName' to suffix names generated while scoped.
  , anonIncr :: Int
    -- | Include paths registered through 'addInclude'.
  , includes :: [FilePath]
    -- | Whether 'getName' should currently generate suffixed names.
  , isScoped :: Bool
  }
-- | Circuit-building monad: a 'CircomState' threaded over 'IO'.
-- The instances are lifted from the wrapped 'StateT' via newtype deriving.
newtype CircomM a = CircomM
  { unCircomM :: StateT CircomState IO a
  }
  deriving (Functor, Applicative, Monad, MonadIO)
-- | Run a circuit description starting from an empty state (no statements,
-- counter at 0, no includes, not scoped).
--
-- Returns the registered include paths and the emitted statements, both in
-- the order in which they were added.
runCircomM :: CircomM () -> IO ([FilePath], [CircomLang])
runCircomM (CircomM action) = do
  final <- execStateT action (CircomState id 0 [] False)
  -- 'addInclude' prepends, so the stored list is newest-first; reverse it so
  -- includes come out in the order they were requested, like the statements.
  return (reverse (includes final), circomOutput final [])
-- | Run an action with 'isScoped' switched on, then put the flag back to the
-- value it had before the action started. The action's result is returned.
scoped :: CircomM a -> CircomM a
scoped body = do
  wasScoped <- CircomM (gets isScoped)
  CircomM (modify (\s -> s { isScoped = True }))
  result <- body
  CircomM (modify (\s -> s { isScoped = wasScoped }))
  return result
-- | Pick the identifier to declare for a requested base name.
--
-- Outside a scope the base name is returned unchanged. Inside a scope the
-- result is the base name followed by @__temp@ and the current 'anonIncr'
-- value, and the counter is bumped by one.
getName :: String -> CircomM String
getName base = do
  st <- CircomM get
  case isScoped st of
    False -> return base
    True -> do
      CircomM (put (st { anonIncr = 1 + anonIncr st }))
      -- The suffix uses the counter value read before the increment.
      return (base ++ "__temp" ++ show (anonIncr st))
-- | Append one statement to the end of the emitted program.
addStatement :: CircomLang -> CircomM ()
addStatement stmt = CircomM (modify append)
  where
    -- Composing on the right places the new statement after all earlier ones.
    append st = st { circomOutput = circomOutput st . (stmt :) }
-- | Register an include path. The path is prepended, so the stored list
-- holds the most recently added path first.
addInclude :: FilePath -> CircomM ()
addInclude path = CircomM (modify register)
  where
    register st = st { includes = path : includes st }
-- | Render a finished program as lines of text: one include line per path,
-- then the statements wrapped in @template Main() { … }@, then the
-- instantiation of @main@.
emitCircom :: ([FilePath], [CircomLang]) -> [String]
emitCircom (incls, cl) =
  concat
    [ map includeLine incls
    , ["template Main() {"]
    , map emitCircom1 cl
    , ["}", "component main = Main();"]
    ]
  where
    -- Note: no space is emitted between the keyword and the opening quote.
    includeLine path = "include\"" ++ path ++ "\";"
-- | Render a single statement as one line of text.
--
-- Tokens of declarations are joined with single spaces via 'unwords', so a
-- 'Public' input (whose privacy renders as the empty string) contains two
-- consecutive spaces.
emitCircom1 :: CircomLang -> String
emitCircom1 (Assign lhs rhs) = showExpr lhs ++ " <-- " ++ showExpr rhs ++ ";"
emitCircom1 (Assert lhs rhs) = showExpr lhs ++ " === " ++ showExpr rhs ++ ";"
emitCircom1 (Input privacy name) = unwords ["signal", showPrivacy privacy, "input", name, ";"]
emitCircom1 (Output name) = unwords ["signal", "output", name, ";"]
emitCircom1 (Signal name) = unwords ["signal", name, ";"]
emitCircom1 (Component name) = unwords ["component", name, ";"]
emitCircom1 (Extern func name args) =
  unwords ["component", name, "=", func, "(", argList, ")", ";"]
  where
    -- 'intercalate' is total: an empty argument list renders as the empty
    -- string instead of crashing in 'init'/'last', and non-empty lists
    -- render exactly as before (@a, b, c@).
    argList = intercalate ", " (map show args)
-- | Keyword emitted for a privacy tag: @private@ for 'Private', nothing for
-- 'Public'.
showPrivacy :: Privacy -> String
showPrivacy privacy = case privacy of
  Public -> ""
  Private -> "private"
-- | Render an expression as text. Binary operators go through 'showBinOp',
-- which parenthesises both operands.
showExpr :: Expr -> String
showExpr expr = case expr of
  Const n -> show n
  Var ident -> ident
  Add a b -> showBinOp "+" a b
  Mul a b -> showBinOp "*" a b
  Minus a b -> showBinOp "-" a b
  Div a b -> showBinOp "/" a b
  Mod a b -> showBinOp "%" a b
  Dot target field -> showExpr target ++ "." ++ field
  Ix target n -> showExpr target ++ "[" ++ show n ++ "]"
-- | Render a binary operation as @(lhs)op(rhs)@, with both operands always
-- wrapped in parentheses and no spaces around the operator.
showBinOp :: String -> Expr -> Expr -> String
showBinOp op e0 e1 = wrap e0 ++ op ++ wrap e1
  where
    wrap e = "(" ++ showExpr e ++ ")"
-- | Lower-case synonym for the 'Private' constructor, for use as the first
-- argument of 'input' / 'inputArray'.
private :: Privacy
private = Private
-- | Lower-case synonym for the 'Public' constructor, for use as the first
-- argument of 'input' / 'inputArray'.
public :: Privacy
public = Public
-- | Declare an input signal with the given privacy and return a reference to
-- it. The declared identifier is whatever 'getName' yields for @name@.
input :: Privacy -> String -> CircomM Expr
input privacy name = do
  fresh <- getName name
  Var fresh <$ addStatement (Input privacy fresh)
-- | Declare an output signal and return a reference to it. The declared
-- identifier is whatever 'getName' yields for @name@.
output :: String -> CircomM Expr
output name = getName name >>= declare
  where
    declare fresh = addStatement (Output fresh) >> return (Var fresh)
-- | Declare an anonymous intermediate signal and return a reference to it.
-- Runs under 'scoped', so the name is generated from the base @sig@.
signal :: CircomM Expr
signal = scoped (getName "sig" >>= declare)
  where
    declare fresh = addStatement (Signal fresh) >> return (Var fresh)
-- | Declare an anonymous component and return a reference to it.
-- Runs under 'scoped', so the name is generated from the base @cm@.
component :: CircomM Expr
component = scoped $ do
  fresh <- getName "cm"
  Var fresh <$ addStatement (Component fresh)
-- | Declare an array of @size@ elements and return a reference to each one.
--
-- @fn@ builds the declaration statement from the rendered declarator
-- (@name[size]@); @name@ is passed through 'getName' first. The result
-- holds exactly @size@ references, for indices @0@ to @size - 1@ (empty
-- when @size@ is 0 or negative).
mkArray :: (String -> CircomLang) -> String -> Integer -> CircomM [Expr]
mkArray fn name size = do
  name' <- getName name
  -- Reuse the subscript rendering to produce the declarator text @name[size]@.
  addStatement $ fn $ showExpr $ Ix (Var name') size
  -- An array declared with @size@ slots is indexable up to @size - 1@; the
  -- previous upper bound of @size@ produced one out-of-range reference.
  return $ map (Ix (Var name')) [0 .. size - 1]
-- | Declare an array of input signals with the given privacy, name and size;
-- see 'mkArray' for the shape of the result.
inputArray :: Privacy -> String -> Integer -> CircomM [Expr]
inputArray privacy name size = mkArray (Input privacy) name size
-- | Declare an array of output signals with the given name and size; see
-- 'mkArray' for the shape of the result.
outputArray :: String -> Integer -> CircomM [Expr]
outputArray name size = mkArray Output name size
-- | Declare an anonymous array of intermediate signals of the given size.
-- Runs under 'scoped', with the base name @sig@.
signalArray :: Integer -> CircomM [Expr]
signalArray size = scoped (mkArray Signal "sig" size)
-- | Declare an anonymous array of components of the given size.
-- Runs under 'scoped', with the base name @cm@.
componentArray :: Integer -> CircomM [Expr]
componentArray size = scoped (mkArray Component "cm" size)
infix 4 ===

-- | Emit a constraint statement @lhs === rhs@.
(===) :: Expr -> Expr -> CircomM ()
lhs === rhs = addStatement (Assert lhs rhs)
infix 4 <--

-- | Emit an assignment statement @lhs <-- rhs@ (no constraint is added).
(<--) :: Expr -> Expr -> CircomM ()
lhs <-- rhs = addStatement (Assign lhs rhs)
infix 4 <==

-- | Assign and constrain in one step: emits @target <-- value@ followed by
-- @target === value@ as two separate statements.
(<==) :: Expr -> Expr -> CircomM ()
target <== value = (target <-- value) >> (target === value)
-- | Overloadable conditional. With @RebindableSyntax@ enabled in this
-- module, @if c then t else e@ is desugared to @ifThenElse c t e@, so the
-- instances of this class decide what @if@ means for each condition type.
class IfThenElse cond branch where
  ifThenElse :: cond -> branch -> branch -> branch
-- | Ordinary conditional on 'Bool': 'True' picks the first branch, 'False'
-- the second. Written with pattern matching because @if@ itself is rebound
-- to this method inside the module.
instance IfThenElse Bool a where
  ifThenElse True a _ = a
  ifThenElse False _ b = b
-- | Arithmetic selection between two expressions: builds @(b - a) * cond + a@.
--
-- cond should be 1 or 0. With @cond = 0@ the expression equals @a@ (the
-- @then@ branch); with @cond = 1@ it equals @b@ (the @else@ branch).
-- NOTE(review): this polarity is the opposite of the 'Bool' instance, and
-- 'dualMux' depends on it as written — confirm before changing.
instance IfThenElse Expr Expr where
  ifThenElse cond a b = Add (Mul (Minus b a) cond) a
-- | Build the circuit and print the generated source to standard output.
-- The text comes from 'unlines' and is printed with 'putStrLn', so the
-- output ends with one extra blank line.
runTest :: CircomM () -> IO ()
runTest circom = putStrLn . unlines . emitCircom =<< runCircomM circom
-- | Build the circuit and write the generated source to the file @fp@.
compile :: CircomM () -> FilePath -> IO ()
compile circom fp = runCircomM circom >>= writeFile fp . unlines . emitCircom
-- | Instantiate an external @MiMCSponge@ component and wire it up.
--
-- * @ins@  – expressions fed to the component's @ins[i]@ inputs, in order.
-- * @s@    – passed as the second template argument (presumably the round
--            count — confirm against the circomlib template).
-- * @k@    – expression fed to the component's @k@ input.
-- * @outN@ – passed as the third template argument; the number of outputs.
--
-- Returns references to @outs[0]@ … @outs[outN - 1]@, i.e. exactly @outN@
-- expressions. The component name is generated under 'scoped' from the base
-- @hasher@.
mimcSponge :: [Expr] -> Integer -> Expr -> Integer -> CircomM [Expr]
mimcSponge ins s k outN = scoped $ do
  name <- getName "hasher"
  let inN = fromIntegral $ length ins
      hasher = Var name
  addStatement $ Extern "MiMCSponge" name [inN, s, outN]
  -- Only the emitted statements matter here, so discard the unit results.
  forM_ (zip [0 ..] ins) $ \(ix, inp) -> Ix (Dot hasher "ins") ix <== inp
  Dot hasher "k" <== k
  -- @outN@ outputs occupy indices 0 .. outN - 1; the previous upper bound of
  -- @outN@ returned one reference past the end of the array.
  return $ map (Ix (Dot hasher "outs")) [0 .. outN - 1]
-- Demo Starts Here ==============================================================
-- Functions
-- | Hash two expressions with 'mimcSponge' (template arguments 220 and 1,
-- key input 0) and return the first output reference.
hashLeftRight :: Expr -> Expr -> CircomM Expr
hashLeftRight l r = (!! 0) <$> mimcSponge [l, r] 220 0 1
-- | Conditional swap. Declares two fresh signals and constrains them so that
-- they hold @[l, r]@ if @s == 0@, otherwise @[r, l]@ (@s@ should be 0 or 1).
-- The @if@ below resolves to the 'IfThenElse' instance for 'Expr'.
--
-- The result maps @"first"@ and @"second"@ to the two signals.
dualMux :: Expr -> Expr -> Expr -> CircomM (HashMap String Expr)
dualMux l r s = do
  outFirst <- signal
  outSecond <- signal
  outFirst <== (if s then l else r)
  outSecond <== (if s then r else l)
  -- Arbitrary Data Structures
  pure (HashMap.fromList [("first", outFirst), ("second", outSecond)])
-- | Demo circuit: asks on standard input for a number of levels, then folds
-- 'dualMux' + 'hashLeftRight' over that many (pathElements, pathIndices)
-- pairs starting from @leaf@, and constrains the result to equal @root@.
-- The level count is parsed with 'read', so non-numeric input aborts.
test :: CircomM ()
test = do
  -- IO
  levels <- liftIO (putStrLn "How many levels?" >> fmap read getLine)
  addInclude "../node_modules/circomlib/circuits/mimcsponge.circom"
  leaf <- input private "leaf"
  root <- input public "root"
  pathElements <- inputArray private "pathElements" levels
  pathIndices <- inputArray private "pathIndices" levels
  let
    -- Functions
    step acc i = scoped $ do
      muxed <- dualMux acc (pathElements !! i) (pathIndices !! i)
      hashLeftRight (muxed ! "first") (muxed ! "second")
    lastLevel = fromIntegral levels - 1
  endHash <- foldM step leaf [0 .. lastLevel]
  root === endHash
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment