Created
December 23, 2021 19:07
-
-
Save ahaym/f8e07ac833f3a11cf417c7d5fde7e66a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE ExistentialQuantification #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE RebindableSyntax #-} | |
module Circom where | |
import Prelude | |
import Control.Monad | |
import Control.Monad.Trans.State.Strict | |
import Control.Monad.IO.Class | |
import qualified Data.Map.Strict as HashMap | |
import Data.Map.Strict (Map, (!)) | |
-- | Name used by the demo code for its string-keyed maps. The underlying
-- structure is the ordered, strict 'Map' from "Data.Map.Strict", which is
-- also imported qualified under the module alias @HashMap@.
type HashMap = Map
-- | Arithmetic expressions over circuit signals; 'showExpr' turns them into
-- Circom source text.
data Expr
  = Const Integer      -- ^ integer literal
  | Var String         -- ^ a named signal or component
  | Add Expr Expr      -- ^ rendered with @+@
  | Mul Expr Expr      -- ^ rendered with @*@
  | Minus Expr Expr    -- ^ rendered with @-@
  | Div Expr Expr      -- ^ rendered with @/@
  | Mod Expr Expr      -- ^ rendered with @%@
  | Ix Expr Integer    -- ^ array element, rendered as @e[i]@
  | Dot Expr String    -- ^ member access, rendered as @e.field@
  deriving (Show)
-- | Lets ordinary numeric syntax build 'Expr' trees: integer literals become
-- 'Const' and the arithmetic operators become the matching constructors.
instance Num Expr where
  a + b = Add a b
  a - b = Minus a b
  a * b = Mul a b
  -- Negation is encoded as subtraction from the constant zero.
  negate e = Minus (Const 0) e
  -- No sign is modelled for expressions; 'abs' hands its argument back.
  abs e = e
  signum = error "signum unimplemented"
  fromInteger n = Const n
-- | Visibility of a circuit input signal; 'showPrivacy' gives the keyword
-- each variant contributes to a declaration.
data Privacy
  = Public
  | Private
  deriving (Show)
-- | One statement of the generated @Main@ template body (see 'emitCircom1'
-- for the exact text each constructor produces).
data CircomLang
  = Assign Expr Expr               -- ^ emitted as @lhs <-- rhs;@
  | Assert Expr Expr               -- ^ emitted as @lhs === rhs;@
  | Input Privacy String           -- ^ input signal declaration
  | Output String                  -- ^ output signal declaration
  | Signal String                  -- ^ intermediate signal declaration
  | Component String               -- ^ bare component declaration
  | Extern String String [Integer]
    -- ^ template name, instance name, template arguments:
    --   emitted as @component name = Template ( args ) ;@
  deriving (Show)
-- | State threaded through 'CircomM' while a circuit is being built.
data CircomState = CircomState
  { circomOutput :: [CircomLang] -> [CircomLang]
    -- ^ statements emitted so far, as a difference list so appending is a
    --   single function composition
  , anonIncr :: Int
    -- ^ counter used by 'getName' to make names unique inside 'scoped'
  , includes :: [FilePath]
    -- ^ registered include paths; 'addInclude' prepends, so newest first
  , isScoped :: Bool
    -- ^ whether 'getName' should currently add a numeric suffix
  }
-- | The circuit-building monad: 'StateT' over 'CircomState' on top of 'IO',
-- so builders can also perform I/O through 'liftIO'. All instances are
-- inherited from the underlying transformer.
newtype CircomM a = CircomM
  { unCircomM :: StateT CircomState IO a
  }
  deriving (Functor, Applicative, Monad, MonadIO)
-- | Run a circuit builder starting from an empty state.
--
-- Returns the include paths in the order they were registered with
-- 'addInclude' (which stores them newest-first, hence the 'reverse') and the
-- emitted statements in program order.
runCircomM :: CircomM () -> IO ([FilePath], [CircomLang])
runCircomM (CircomM builder) = do
  final <- execStateT builder (CircomState id 0 [] False)
  return (reverse (includes final), circomOutput final [])
-- | Run an action with 'isScoped' switched on, so every name it requests
-- through 'getName' gets a unique numeric suffix. The previous flag value is
-- restored afterwards, so nesting is harmless.
scoped :: CircomM a -> CircomM a
scoped action = do
  saved <- CircomM (gets isScoped)
  setScoped True
  result <- action
  setScoped saved
  return result
  where
    setScoped flag = CircomM (modify (\st -> st {isScoped = flag}))
-- | Pick the name for a new declaration.
--
-- Outside 'scoped' the requested name is used as-is. Inside it, the name gets
-- a @__temp<n>@ suffix, where @n@ is the current counter value, and the
-- counter is advanced so the next request receives a different suffix.
getName :: String -> CircomM String
getName name = CircomM $ do
  st <- get
  case isScoped st of
    False -> return name
    True -> do
      let n = anonIncr st
      put st {anonIncr = n + 1}
      return (name ++ "__temp" ++ show n)
-- | Append a statement to the circuit output. The output is a difference
-- list, so appending composes one more @(stmt :)@ onto the end.
addStatement :: CircomLang -> CircomM ()
addStatement lang = CircomM (modify appendStmt)
  where
    appendStmt st = st {circomOutput = circomOutput st . (lang :)}
-- | Register a file to be emitted as an @include@ line. Paths are pushed onto
-- the front of the list, i.e. stored newest-first.
addInclude :: FilePath -> CircomM ()
addInclude path = CircomM (modify pushPath)
  where
    pushPath st = st {includes = path : includes st}
-- | Render a complete Circom program as lines of text.
--
-- The output is one @include "path";@ line per path, then a @Main@ template
-- containing every statement (via 'emitCircom1'), then the @component main@
-- instantiation.
emitCircom :: ([FilePath], [CircomLang]) -> [String]
emitCircom (incls, cl) =
  includeLines
    ++ ["template Main() {"]
    ++ map emitCircom1 cl
    ++ ["}", "component main = Main();"]
  where
    includeLines = mkIncludeLine <$> incls
    -- Separate the keyword from the quoted path (@include "path";@); the
    -- previous form glued them together as @include"path";@.
    mkIncludeLine path = concat ["include ", "\"", path, "\"", ";"]
-- | Render a single statement as one line of Circom source.
--
-- Declarations are rendered as space-separated tokens. Empty tokens are
-- dropped, so a public input, whose privacy keyword is empty, does not get a
-- doubled space. An 'Extern' with no template arguments renders as @( )@
-- instead of crashing.
emitCircom1 :: CircomLang -> String
emitCircom1 stmt = case stmt of
  Assign lhs rhs -> binary " <-- " lhs rhs
  Assert lhs rhs -> binary " === " lhs rhs
  Input privacy name -> tokens ["signal", showPrivacy privacy, "input", name, ";"]
  Output name -> tokens ["signal", "output", name, ";"]
  Signal name -> tokens ["signal", name, ";"]
  Component name -> tokens ["component", name, ";"]
  Extern func name args ->
    tokens ["component", name, "=", func, "(", commaSep (map show args), ")", ";"]
  where
    binary op lhs rhs = showExpr lhs ++ op ++ showExpr rhs ++ ";"
    tokens = unwords . filter (not . null)
    -- Total replacement for the old init/last-based join, which failed on [].
    commaSep [] = ""
    commaSep xs = foldr1 (\a b -> a ++ ", " ++ b) xs
-- | Keyword placed before @input@ in a signal declaration: @private@ for
-- private inputs, nothing (the empty string) for public ones.
showPrivacy :: Privacy -> String
showPrivacy privacy = case privacy of
  Public -> ""
  Private -> "private"
-- | Render an expression as Circom source text. Binary operators go through
-- 'showBinOp'; indexing and member access are rendered directly.
showExpr :: Expr -> String
showExpr expr = case expr of
  Const n -> show n
  Var s -> s
  Add l r -> showBinOp "+" l r
  Mul l r -> showBinOp "*" l r
  Minus l r -> showBinOp "-" l r
  Div l r -> showBinOp "/" l r
  Mod l r -> showBinOp "%" l r
  Dot e field -> showExpr e ++ "." ++ field
  Ix e i -> showExpr e ++ "[" ++ show i ++ "]"
-- | Render a binary operation with both operands wrapped in parentheses, so
-- the result never depends on operator precedence in the output language.
showBinOp :: String -> Expr -> Expr -> String
showBinOp op lhs rhs = parens lhs ++ op ++ parens rhs
  where
    parens e = "(" ++ showExpr e ++ ")"
-- | Lower-case synonym for 'Private', so circuit code can be written as
-- @input private "leaf"@.
private :: Privacy
private = Private

-- | Lower-case synonym for 'Public', for use as @input public "root"@.
public :: Privacy
public = Public
-- | Declare an input signal with the given visibility and return a reference
-- to it. The actual name comes from 'getName' (suffixed inside 'scoped').
input :: Privacy -> String -> CircomM Expr
input privacy name =
  getName name >>= \actual -> Var actual <$ addStatement (Input privacy actual)
-- | Declare an output signal and return a reference to it. The actual name
-- comes from 'getName' (suffixed inside 'scoped').
output :: String -> CircomM Expr
output name =
  getName name >>= \actual -> Var actual <$ addStatement (Output actual)
-- | Declare a fresh intermediate signal. It is always named inside 'scoped',
-- so each call yields a distinct @sig__temp<n>@ name.
signal :: CircomM Expr
signal = scoped (getName "sig" >>= declare)
  where
    declare actual = Var actual <$ addStatement (Signal actual)
-- | Declare a fresh component. It is always named inside 'scoped', so each
-- call yields a distinct @cm__temp<n>@ name.
component :: CircomM Expr
component = scoped (getName "cm" >>= declare)
  where
    declare actual = Var actual <$ addStatement (Component actual)
-- | Declare an array of @size@ elements with the given declaration
-- constructor, and return one reference per element:
-- @name[0]@ .. @name[size-1]@. A non-positive size yields no references.
mkArray :: (String -> CircomLang) -> String -> Integer -> CircomM [Expr]
mkArray fn name size = do
  name' <- getName name
  -- The declaration itself is written as @name[size]@.
  addStatement $ fn $ showExpr $ Ix (Var name') size
  -- Valid indices are 0 .. size-1. The previous @[0 .. size]@ returned an
  -- extra, out-of-bounds reference to @name[size]@.
  return $ map (Ix (Var name')) [0 .. size - 1]
-- | Declare an input signal array with the given visibility; returns one
-- reference per element (see 'mkArray').
inputArray :: Privacy -> String -> Integer -> CircomM [Expr]
inputArray privacy name size = mkArray (Input privacy) name size
-- | Declare an output signal array; returns one reference per element (see
-- 'mkArray').
outputArray :: String -> Integer -> CircomM [Expr]
outputArray name size = mkArray Output name size
-- | Declare a freshly named intermediate signal array (base name @sig@, made
-- unique by 'scoped').
signalArray :: Integer -> CircomM [Expr]
signalArray size = scoped (mkArray Signal "sig" size)
-- | Declare a freshly named component array (base name @cm@, made unique by
-- 'scoped').
componentArray :: Integer -> CircomM [Expr]
componentArray size = scoped (mkArray Component "cm" size)
-- | Emit the statement @lhs === rhs;@.
(===) :: Expr -> Expr -> CircomM ()
lhs === rhs = addStatement (Assert lhs rhs)

infix 4 ===
-- | Emit the statement @lhs <-- rhs;@.
(<--) :: Expr -> Expr -> CircomM ()
lhs <-- rhs = addStatement (Assign lhs rhs)

infix 4 <--
-- | Emit @lhs <-- rhs;@ followed by @lhs === rhs;@, i.e. an assignment
-- together with the matching constraint.
(<==) :: Expr -> Expr -> CircomM ()
lhs <== rhs = (lhs <-- rhs) >> (lhs === rhs)

infix 4 <==
-- | Target of @if ... then ... else@ under RebindableSyntax: the expression
-- @if c then t else e@ is handled by @ifThenElse c t e@.
class IfThenElse a b where
  ifThenElse :: a -> b -> b -> b

-- | Ordinary Boolean conditional.
instance IfThenElse Bool a where
  ifThenElse True a _ = a
  ifThenElse False _ b = b

-- | Arithmetic select on circuit values, computed as @(b - a) * cond + a@.
--
-- NOTE(review): this gives @a@ (the @then@ branch) when @cond == 0@ and @b@
-- (the @else@ branch) when @cond == 1@. That is the reverse of the usual
-- reading if 1 means "true". Callers may depend on this orientation, so it is
-- kept unchanged. @cond@ is assumed to be 0 or 1; nothing here constrains it.
instance IfThenElse Expr Expr where
  ifThenElse cond a b = (b - a) * cond + a
-- | Build a circuit and print the generated Circom source to stdout.
runTest :: CircomM () -> IO ()
runTest circom = runCircomM circom >>= putStrLn . unlines . emitCircom
-- | Build a circuit and write the generated Circom source to the given file,
-- replacing any existing contents.
compile :: CircomM () -> FilePath -> IO ()
compile circom fp = runCircomM circom >>= writeFile fp . unlines . emitCircom
-- | Instantiate the @MiMCSponge@ template on the given inputs.
--
-- Arguments, in order:
--
-- * the input expressions;
-- * @s@, passed as the template's second argument (presumably the round
--   count; confirm against the template's definition);
-- * the key expression @k@;
-- * the number of outputs @outN@.
--
-- Each input is wired to @hasher.ins[i]@ and @k@ to @hasher.k@ with '<=='.
-- Returns references to the @outN@ outputs @hasher.outs[0]@ ..
-- @hasher.outs[outN-1]@.
mimcSponge :: [Expr] -> Integer -> Expr -> Integer -> CircomM [Expr]
mimcSponge ins s k outN = scoped $ do
  name <- getName "hasher"
  let inN = fromIntegral $ length ins
  addStatement $ Extern "MiMCSponge" name [inN, s, outN]
  -- The wiring actions return (), so discard their results with forM_.
  forM_ (zip [0 ..] ins) $ \(ix, inp) -> Ix (Dot (Var name) "ins") ix <== inp
  Dot (Var name) "k" <== k
  -- outN outputs occupy indices 0 .. outN-1. The previous @[0 .. outN]@
  -- returned one reference too many.
  return $ map (Ix (Dot (Var name) "outs")) [0 .. outN - 1]
-- Demo Starts Here ============================================================== | |
-- Functions | |
-- | Hash two values with a single 'mimcSponge' instance (two inputs, 220 as
-- the template's second argument, key 0, one output) and return that output.
hashLeftRight :: Expr -> Expr -> CircomM Expr
hashLeftRight l r = do
  outs <- mimcSponge [l, r] 220 0 1
  -- Pattern match instead of (!! 0) so an unexpected empty result fails
  -- with a descriptive message.
  case outs of
    (h : _) -> return h
    [] -> error "hashLeftRight: mimcSponge returned no outputs"
-- | Order two values by a selector bit. The result maps @"first"@ and
-- @"second"@ to @l@ and @r@ respectively when @s == 0@, and to @r@ and @l@
-- when @s == 1@.
--
-- The selector is constrained with @s * (1 - s) === 0@. Without it, @s@
-- could be any value, and the outputs (@l + s * (r - l)@ and
-- @r + s * (l - r)@) could then be something other than @l@ or @r@.
dualMux :: Expr -> Expr -> Expr -> CircomM (HashMap String Expr)
dualMux l r s = do
  first <- signal
  second <- signal
  -- Linear selects written out directly. @(r - l) * s + l@ equals l at
  -- s = 0 and r at s = 1; the second output is the mirror image.
  first <== (r - l) * s + l
  second <== (l - r) * s + r
  -- Force the selector to be a bit (0 or 1).
  s * (1 - s) === 0
  -- Arbitrary Data Structures
  return $ HashMap.fromList [("first", first), ("second", second)]
-- | Demo circuit: proves that the private @leaf@, hashed up a tree of
-- @levels@ levels with the private @pathElements@ / @pathIndices@, yields the
-- public @root@. The number of levels is read from stdin at generation time.
test :: CircomM ()
test = do
  -- IO
  levels <- liftIO askLevels
  addInclude "../node_modules/circomlib/circuits/mimcsponge.circom"
  leaf <- input private "leaf"
  root <- input public "root"
  pathElements <- inputArray private "pathElements" levels
  pathIndices <- inputArray private "pathIndices" levels
  let
    -- Functions: one tree level orders (prevHash, sibling) by the path
    -- index, then hashes the pair.
    step prevHash (sibling, pathIndex) = scoped $ do
      pair <- dualMux prevHash sibling pathIndex
      hashLeftRight (pair ! "first") (pair ! "second")
    -- Pair each level's sibling with its index; 'take' bounds the walk to
    -- exactly @levels@ steps regardless of the list lengths returned.
    levelInputs = take (fromIntegral levels) (zip pathElements pathIndices)
  endHash <- foldM step leaf levelInputs
  root === endHash
  where
    -- Prompt for the tree depth, rejecting non-numeric or negative input
    -- with a clear error instead of the bare failure of 'read'.
    askLevels :: IO Integer
    askLevels = do
      putStrLn "How many levels?"
      line <- getLine
      case reads line of
        [(n, rest)] | all (`elem` " \t\r") rest, n >= 0 -> return n
        _ -> ioError (userError ("expected a non-negative integer, got: " ++ show line))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment