Created March 24, 2016 19:06
-
-
Save raboof/6c83bee675f328fb04a9 to your computer and use it in GitHub Desktop.
Small Scala macros example: compile-time calculation of factorials
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object CompileTimeFactorial {
  import scala.language.experimental.macros
  import scala.reflect.macros.blackbox.Context

  /** Computes `n!` at compile time.
    *
    * The function exposed to consumers has a normal Scala type, but it is
    * implemented as a macro: every call is expanded by [[factorial_impl]]
    * during compilation and replaced by an integer literal. The argument must
    * therefore be an `Int` literal; anything else is a compile error.
    *
    * @param n an `Int` literal
    * @return `n!`; for negative `n`, `-(|n|!)`
    */
  def factorial(n: Int): Int = macro factorial_impl

  /** Macro implementation of [[factorial]].
    *
    * Receives the compiler `Context` and the AST of the argument passed to
    * `factorial`. When that AST is an `Int` literal, the factorial is computed
    * now (at compile time) and emitted as a literal AST. Non-literal arguments,
    * and results that do not fit in an `Int`, are reported as compile errors
    * at the call site via `c.abort`.
    */
  def factorial_impl(c: Context)(n: c.Expr[Int]): c.Expr[Int] = {
    import c.universe._
    // We can pattern-match on the AST of the argument:
    n.tree match {
      case Literal(Constant(nValue: Int)) =>
        // Perform the calculation at compile time. Overflow must become a
        // compile error rather than a silently wrong constant in the bytecode.
        val result =
          try normalFactorial(nValue)
          catch {
            case _: ArithmeticException =>
              c.abort(n.tree.pos, s"factorial($nValue) does not fit in an Int")
          }
        // And produce an AST for the result of the computation:
        c.Expr[Int](Literal(Constant(result)))
      case other =>
        // Report a proper compile error instead of crashing the compiler with `???`.
        c.abort(other.pos, s"factorial requires an Int literal argument, got: ${showCode(other)}")
    }
  }

  /** The actual implementation is regular old-fashioned Scala code.
    *
    * Returns `n!` for `n >= 0`, and `-(|n|!)` for negative `n`.
    *
    * @throws ArithmeticException if the result does not fit in an `Int`
    *         (any `|n| > 12`) or if `n == Int.MinValue`
    */
  private def normalFactorial(n: Int): Int =
    if (n < 0) -normalFactorial(Math.negateExact(n)) // negateExact rejects Int.MinValue instead of looping forever
    else (2 to n).foldLeft(1)((acc, k) => Math.multiplyExact(acc, k)) // iterative: no deep recursion; fails fast on overflow
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.