Created February 11, 2016 10:50
Save LPGhatguy/1f87d0fb51a91628e184 to your computer and use it in GitHub Desktop.
A parser for C-style numeric expressions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local bit = require("bit") | |
-- Operators that Lua already accepts with the same spelling, so they are
-- copied verbatim into the generated source.
-- NOTE(review): Lua's "/" is always floating-point division, unlike C's
-- integer division on integer operands.
local natop = {}
for symbol in ("+-*/"):gmatch(".") do
  natop[symbol] = true
end
-- Operators with no native Lua 5.1 equivalent: each symbol maps to the name
-- of the `bit` library function that implements it.
local transop = {
  ["<<"] = "lshift", -- NOTE(review): bit.rshift below is a logical shift;
  [">>"] = "rshift", -- C's >> on negative signed values is usually arithmetic (bit.arshift).
  ["&"] = "band",
  ["|"] = "bor",
  ["^"] = "bxor",    -- in C "^" is XOR, whereas in Lua it is exponentiation
  ["~"] = "bnot",    -- unary in C; the infix-only normalizer does not translate "~x"
}
-- Longest operator symbol in transop: how many characters the tokenizer
-- must examine at once to recognize multi-character operators like "<<".
local elookahead = 1
for symbol in pairs(transop) do
  elookahead = math.max(elookahead, #symbol)
end
--- Tokenize an expression string into a baseline AST: a flat list of nodes,
-- with each balanced parenthesized sub-expression nested as a group.
-- Node shapes produced:
--   {"group", subAST}   -- contents of a balanced "(...)" pair
--   {"op", char}        -- an operator from natop, copied verbatim
--   {"val", number}     -- a decimal or 0x-prefixed hexadecimal integer literal
--   {"transop", fname}  -- an operator from transop, mapped to its bit-library name
-- Any other character (whitespace, unknown symbols, a stray ")") is skipped.
-- An unmatched "(" is dropped and the text after it is scanned at the same level.
-- @param exp expression string
-- @return array of nodes
local function baseAST(exp)
  local ast = {}
  local i = 1
  while i <= #exp do
    local char = exp:sub(i, i)
    if char == "(" then
      local level = 1
      for j = i + 1, #exp do
        local jc = exp:sub(j, j)
        if jc == "(" then
          level = level + 1
        elseif jc == ")" then
          level = level - 1
        end
        if level == 0 then
          table.insert(ast, {"group", baseAST(exp:sub(i + 1, j - 1))})
          -- Leave i on the closing paren; the increment at the bottom of
          -- the loop steps past it (adding 1 here as well skipped a char).
          i = j
          break
        end
      end
    elseif natop[char] then
      table.insert(ast, {"op", char})
    elseif char:find("%d") then
      -- Consume the whole literal: hex "0x..." first, else a run of digits.
      local _, last = exp:find("^0[xX]%x+", i)
      if not last then
        _, last = exp:find("^%d+", i)
      end
      table.insert(ast, {"val", tonumber(exp:sub(i, last))})
      -- i now rests on the literal's final character.
      i = last
    elseif transop[char] then
      table.insert(ast, {"transop", transop[char]})
    else
      -- Try multi-character operators (e.g. "<<") starting at i.
      for j = i + 1, i + elookahead - 1 do
        local piece = exp:sub(i, j)
        if transop[piece] then
          table.insert(ast, {"transop", transop[piece]})
          -- Rest on the operator's last character; see the note above.
          i = j
          break
        end
      end
    end
    i = i + 1
  end
  return ast
end
--- Rewrite infix "transop" tokens into call nodes.
-- Every `left <transop> right` triple becomes {"call", fname, left, right}.
-- Runs of transops fold left-associatively, so "1 | 2 | 4" becomes
-- bor(bor(1, 2), 4). Groups are normalized recursively, including groups
-- used as a transop operand.
-- Note: there is no precedence handling, so a transop grabs its immediate
-- neighbours: "1 + 2 << 1" becomes "1 + lshift(2, 1)", unlike C.
-- For truncated input such as "1 |" the call's right operand is nil.
-- @param ast node list produced by baseAST
-- @return a new node list (the input list is not written to)
local function normalizeAST(ast)
  -- Normalize a single operand: groups recurse, anything else (including
  -- nil for a missing operand) is returned unchanged.
  local function normalizeOperand(node)
    if node and node[1] == "group" then
      return {"group", normalizeAST(node[2])}
    end
    return node
  end

  local newAST = {}
  local i = 1
  while i <= #ast do
    local node = normalizeOperand(ast[i])
    -- Absorb each following "<transop> operand" pair, using everything
    -- folded so far as the left operand.
    while ast[i + 1] and ast[i + 1][1] == "transop" do
      node = {"call", ast[i + 1][2], node, normalizeOperand(ast[i + 2])}
      i = i + 2
    end
    newAST[#newAST + 1] = node
    i = i + 1
  end
  return newAST
end
--- Parse an expression string into a normalized AST: tokenize it with
-- baseAST, then rewrite infix bit operators into call nodes with normalizeAST.
local function astify(exp)
  return normalizeAST(baseAST(exp))
end
-- Forward declaration: showNode and deastify call each other.
local deastify

--- Render a single AST node as Lua source text.
-- "op" and "val" nodes become their value, "call" nodes become
-- `fname(left, right)`, and "group" nodes are re-wrapped in parentheses.
-- Any other node kind is reported on stdout and rendered as "".
local function showNode(node)
  local kind = node[1]
  if kind == "op" or kind == "val" then
    return tostring(node[2])
  elseif kind == "call" then
    local left = showNode(node[3])
    local right = showNode(node[4])
    return node[2] .. "(" .. left .. ", " .. right .. ")"
  elseif kind == "group" then
    return "(" .. deastify(node[2]) .. ")"
  end
  print("unknown", kind)
  return ""
end
-- Generate Lua source from an AST: render each node with showNode and join
-- the pieces with single spaces. This assigns the `deastify` local that is
-- forward-declared before showNode.
function deastify(ast)
  local pieces = {}
  for idx, node in ipairs(ast) do
    pieces[idx] = showNode(node)
  end
  return table.concat(pieces, " ")
end
--- Translate a C-style expression string into equivalent Lua source text
-- (parse with astify, then render with deastify).
local function sourceify(exp)
  return deastify(astify(exp))
end
--- Compile an expression string into a zero-argument Lua function.
-- The translated source is wrapped as `return function() return (...) end`,
-- compiled, and run inside an environment table whose only entries are the
-- bit-library functions named in transop. On a compile error the message is
-- printed and nothing is returned.
-- NOTE(review): loadstring/setfenv are Lua 5.1/LuaJIT APIs (5.2+ uses load
-- with an env argument instead).
local function functate(exp)
  local wrapped = "return function() return (" .. sourceify(exp) .. ") end"
  local factory, err = loadstring(wrapped)
  if not factory then
    print("Error generating function:", err)
    return
  end
  local sandbox = {}
  for _, fname in pairs(transop) do
    sandbox[fname] = bit[fname]
  end
  setfenv(factory, sandbox)
  -- Running the factory yields the inner closure, which sees `sandbox`.
  return factory()
end
--- Evaluate a C-style expression string.
-- @return the value of the expression, or nothing if it failed to compile
local function eval(exp)
  local compiled = functate(exp)
  if not compiled then
    return
  end
  -- The generated function takes no arguments.
  return compiled()
end
--- Render an AST as an indented, human-readable multi-line string.
-- Each node becomes one line "<index> <kind>: <value>" prefixed by `level`
-- tabs; groups recurse one level deeper. A call node is shown as
-- "<index> call: fname(left, right)", where a group operand is expanded on
-- the following lines, a nested call operand is rendered inline, and a
-- missing operand is shown as "nil".
-- @param ast node list (e.g. the result of astify)
-- @param level indentation depth, defaults to 0
-- @return the rendered string, lines joined with "\n"
local function viewAST(ast, level)
  level = level or 0
  local indent = ("\t"):rep(level)

  -- Render one operand of a call node.
  local function viewOperand(node)
    if node == nil then
      -- Truncated input such as "1 |" leaves the right operand empty.
      return "nil"
    elseif node[1] == "group" then
      return "group: \n" .. viewAST(node[2], level + 1) .. "\n" .. indent
    elseif node[1] == "call" then
      return node[2] .. "(" .. viewOperand(node[3]) .. ", " .. viewOperand(node[4]) .. ")"
    end
    return tostring(node[2])
  end

  local buffer = {}
  for i, v in ipairs(ast) do
    local line
    if v[1] == "group" then
      line = indent .. i .. " group:\n" .. viewAST(v[2], level + 1)
    elseif v[1] == "call" then
      -- Both operands go through the same renderer (previously the right
      -- operand checked a[2] instead of b[1] and printed a table address).
      line = indent .. i .. " call: " .. v[2] .. "(" .. viewOperand(v[3]) .. ", " .. viewOperand(v[4]) .. ")"
    else
      line = indent .. i .. " " .. v[1] .. ": " .. v[2]
    end
    buffer[#buffer + 1] = line
  end
  return table.concat(buffer, "\n")
end
--- Evaluate `exp` and report on stdout when the result differs from `v`.
-- A nil evaluation result is compared as the string "nil". On mismatch the
-- expression, both values, and a dump of its AST are printed; nothing is
-- returned either way.
local function assertEq(exp, v)
  local got = eval(exp) or "nil"
  if got ~= v then
    print("Failure: " .. exp)
    print("", "Got", got, "but expected", v)
    print("AST:")
    print(viewAST(astify(exp), 1))
  end
end
--- Run the built-in self-test cases; failures are reported by assertEq.
local function test()
  local cases = {
    {"8 << 1", 16},
    {"8 >> 1", 4},
    {"2 | 4", 6},
    {"6 & 4", 4},
    {"17 ^ 1", 16},
    {"1 << 30", 2^30},
    {"(2 + 2) << 1", 8},
    {"(((2 + 2)))", 4},
  }
  for _, case in ipairs(cases) do
    assertEq(case[1], case[2])
  end
end
--- Walk one sample expression through every stage of the pipeline,
-- printing the input, its AST, the generated source, and the result.
local function demo()
  local src = "(1 + 1) * 3"
  print(src)

  local tree = astify(src)
  print("AST:")
  print(viewAST(tree, 1))

  local generated = sourceify(src)
  print("generated", generated)

  -- Only exercises function generation; the compiled function is unused.
  functate(src)
  print("got fun")

  local result = eval(src)
  print(result)
end
-- Public module interface.
local exports = {}
exports.eval = eval
exports.astify = astify
exports.sourceify = sourceify
exports.test = test
exports.demo = demo
return exports
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.