Skip to content

Instantly share code, notes, and snippets.

@zoon
Last active October 12, 2025 19:23
Show Gist options
  • Save zoon/4cfc4081981d7bc3d1307cc8f1225dd4 to your computer and use it in GitHub Desktop.
Save zoon/4cfc4081981d7bc3d1307cc8f1225dd4 to your computer and use it in GitHub Desktop.
A simple unit testing harness for Luau.
--!strict
--!optimize 2
-- Copyright (c) 2025 Andrei Zhilin https://github.com/zoon
-- Licensed under the MIT License.
--[=[
A simple unit testing harness for Luau.
Usage:
```luau
local m = require("./testit")
local suite = m.suite("example test suite")
suite
:test("example test")
:case("#1", function() m.expect.truthy "A" end)
:case("#2", function() m.expect.equal(false, true) end) -- will fail
:case("#3", function() m.expect.near_equal(1, 2, 0.1) end) -- will fail
suite
:test("another example test")
:case("#11", function() m.expect.deep_equal({ "A" }, { "A" }) end)
:case("#12", function() m.expect.type("A", "string") end)
:case("#13", function() assert(nil, "Expect not to be nil") end) -- plain assert() is ok
suite:run ""
local single = m.suite("single test suite")
single
:test("single test example")
:case("#21", function()
m.expect.throws(function() error "out of range" end, "range")
end)
:case("#22", function() m.expect.not_nil(suite) end)
single:run "verbose"
```
Cross platform harness ideas:
$ fd --follow -t file -g *.spec.lua? -x luau.exe -g2 -O1
$ elvish -c "luau -g2 -O1 **.spec.lua?"
--]=]
-- stylua: ignore start
local CASE_PASS = "PASS"
local CASE_FAIL = "FAIL"
local TEST_PASS = "✅"
local TEST_FAIL = "❌"
local INDENT = " "
local CONNECTOR = " ╰─ "
-- CI mode: if true, treat any test failure as a hard error (non-zero exit code).
local CI_MODE = false
-- Sink for every line of harness output; replaceable through m.set_printer.
local PRINTER: (msg: string) -> () = print
local m = {}
--- Enable or disable CI mode.
--- @param ci boolean When true, TestSuite.run raises after printing a failing summary.
function m.set_ci_mode(ci: boolean)
    CI_MODE = ci
end
--- Redirect all harness output to a custom printer (e.g. to capture it).
--- @param printer (msg: string) -> () Receives each formatted output string.
--- @error Throws when printer cannot possibly be called (nil, string, number, boolean),
---        instead of failing later in the middle of a test run.
function m.set_printer(printer: (msg: string) -> ())
    local kind = type(printer)
    -- Tables/userdata are allowed through: they may be callable via __call.
    if kind ~= "function" and kind ~= "table" and kind ~= "userdata" then
        error(`Expected printer to be a function, got {kind}`, 2)
    end
    PRINTER = printer
end
--- Restore the default printer (the global `print`).
function m.reset_printer()
    PRINTER = print
end
export type RunOption = "verbose" | "normal" | ""
-- One named case: a thunk that signals failure by raising an error.
type TestCase = {
    name: string,
    thunk: () -> (),
}
-- Instance state of a Test: display name plus its cases in insertion order.
type TestInner = {
    _name: string,
    _cases: { TestCase },
}
-- Method table shared by all Test instances (also serves as their metatable).
type TestImpl = {
    __index: TestImpl,
    case: (self: Test, name: string, thunk: () -> ()) -> Test,
    run_cases: (self: Test, suite_name: string, opt: RunOption?) -> (number, number),
}
-- Instance state of a TestSuite: display name plus its tests in insertion order.
type TestSuiteInner = {
    _name: string,
    _tests: { Test },
}
-- Method table shared by all TestSuite instances (also serves as their metatable).
type TestSuiteImpl = {
    __index: TestSuiteImpl,
    test: (self: TestSuite, name: string) -> Test,
    run: (self: TestSuite, opt: RunOption?) -> (boolean, number, number),
}
export type Test = typeof(setmetatable({} :: TestInner, {} :: TestImpl))
export type TestSuite = typeof(setmetatable({} :: TestSuiteInner, {} :: TestSuiteImpl))
local Test = {} :: TestImpl
local TestSuite = {} :: TestSuiteImpl
-- Self-referential __index so instances find their methods on the shared table.
Test.__index = Test
TestSuite.__index = TestSuite
-----------------------------
-- util
-----------------------------
-- Split an error string of the form "path:line: message" into ("path:line", "message").
-- Yields nil when the string has no ":<digits>:" location prefix.
local function split_error(err: string): (string?, string?)
    -- The lazy `.-` makes the location end at the first ":<digits>:" boundary.
    local LOCATION_PATTERN = "^(.-:%d+):%s*(.*)%s*$"
    return string.match(err, LOCATION_PATTERN)
end
-- Render a raised error value for reports. A string with a "path:line:" prefix becomes
-- "message" followed by a connector line holding the path; anything else is returned as text.
local function format_error(err: any): string
    if type(err) == "string" then
        local path, msg = split_error(err)
        if path == nil or msg == nil then
            return err
        end
        return msg .. "\n" .. INDENT .. CONNECTOR .. path
    end
    return tostring(err)
end
--- Structural equality, recursing into tables.
--- Identical values compare equal via `==`; otherwise both must be tables with the same
--- key set whose values are pairwise deep-equal. Handles cyclic tables.
--- @param a any
--- @param b any
--- @param visiting table Pairs (a, b) already assumed equal; pass `{}` from the top-level call.
--- @return boolean
local function deep_eq(a: any, b: any, visiting: { [any]: { [any]: boolean } }): boolean
    if a == b then return true end
    if type(a) ~= "table" or type(b) ~= "table" then return false end
    -- Track (a, b) pairs, not a single partner per `a`: the same table may legitimately be
    -- compared against several distinct-but-equal tables (a shared sub-table in `a`, copies
    -- in `b`). Meeting a pair already in progress means we followed a cycle, so assume it is
    -- equal. That is sound because any mismatch found anywhere makes the whole call false.
    local partners = visiting[a]
    if partners == nil then
        partners = {}
        visiting[a] = partners
    elseif partners[b] then
        return true
    end
    partners[b] = true
    -- Compare all key-value pairs of a against b
    local keys_seen = {}
    for key, value_a in a do
        keys_seen[key] = true
        if not deep_eq(value_a, b[key], visiting) then return false end
    end
    -- Any key present in b but not in a makes them differ
    for key in b do
        if not keys_seen[key] then return false end
    end
    return true
end
-- Argument validators for the m.expect.* assertions. They raise at level 3 so the
-- reported location is the caller of the m.expect.* function, not the assertion itself.
local function check_string(s: any, name: string)
    local actual_type = type(s)
    if actual_type == "string" then return end
    error(`Expected {name} to be a string, got {actual_type}`, 3)
end
-- Accepts only finite numbers: rejects non-numbers, NaN, +inf and -inf.
local function check_number(n: any, name: string)
    local problem: string? = nil
    if type(n) ~= "number" then
        problem = type(n)
    elseif n ~= n then
        problem = "NaN"
    elseif n == math.huge then
        problem = "+inf"
    elseif n == -math.huge then
        problem = "-inf"
    end
    if problem ~= nil then
        error(`Expected {name} to be a number, got {problem}`, 3)
    end
end
-----------------------------
-- assertions
-----------------------------
m.expect = {}
--- Assert two numbers are approximately equal within an absolute tolerance.
--- @param actual number Measured value.
--- @param expected number Expected value.
--- @param tolerance? number Absolute tolerance (sign ignored); defaults to 1e-6 when omitted.
--- @error Throws when |actual - expected| > tolerance, or when any argument
---        (including a supplied tolerance) is not a finite number.
function m.expect.near_equal(actual: number, expected: number, tolerance: number?)
    -- Only a missing tolerance gets the default. A wrong-typed one is reported by
    -- check_number rather than silently replaced with 1e-6.
    local tol = if tolerance == nil then 1e-6 else tolerance
    check_number(tol, "tol")
    check_number(actual, "actual")
    check_number(expected, "expected")
    tol = math.abs(tol)
    local delta = math.abs(expected - actual)
    if delta > tol then
        -- Relative error for the message; the floor avoids dividing by zero when expected == 0.
        local denom = math.max(math.abs(expected), 1e-12)
        local ratio = delta / denom
        error(`Expected {actual} to be near {expected} (Δ={delta}, ε={ratio}, τ={tol})`, 2)
    end
end
--- Assert that a number has no fractional part.
--- @param n number Value to check; must be a finite number.
--- @error Throws when n is not a finite number or is not integral.
function m.expect.integer(n: number)
    check_number(n, "n")
    local is_integral = n == math.floor(n)
    if not is_integral then
        error(`Expected {n} to be an integer`, 2)
    end
end
--- Assert strict numeric ordering: actual < expected.
--- @param actual number Left-hand value.
--- @param expected number Exclusive upper bound.
--- @error Throws when either argument is not a finite number, or when actual >= expected.
function m.expect.less(actual: number, expected: number)
    check_number(actual, "actual")
    check_number(expected, "expected")
    -- NaN is already rejected above, so `not (a < b)` is exactly `a >= b`.
    if actual < expected then return end
    error(`Expected {actual} < {expected}`, 2)
end
--- Assert non-strict numeric ordering: actual <= expected.
--- @param actual number Left-hand value.
--- @param expected number Inclusive upper bound.
--- @error Throws when either argument is not a finite number, or when actual > expected.
function m.expect.less_or_equal(actual: number, expected: number)
    check_number(actual, "actual")
    check_number(expected, "expected")
    -- NaN is already rejected above, so `not (a <= b)` is exactly `a > b`.
    if actual <= expected then return end
    error(`Expected {actual} <= {expected}`, 2)
end
--- Assert strict string ordering (Luau's `<` on strings): actual < expected.
--- @param actual string Left-hand string.
--- @param expected string Exclusive upper bound.
--- @error Throws when either argument is not a string, or when actual >= expected.
function m.expect.lex_less(actual: string, expected: string)
    check_string(actual, "actual")
    check_string(expected, "expected")
    if actual < expected then return end
    error(`Expected '{actual}' < '{expected}' lexicographically`, 2)
end
--- Assert non-strict string ordering (Luau's `<=` on strings): actual <= expected.
--- @param actual string Left-hand string.
--- @param expected string Inclusive upper bound.
--- @error Throws when either argument is not a string, or when actual > expected.
function m.expect.lex_less_or_equal(actual: string, expected: string)
    check_string(actual, "actual")
    check_string(expected, "expected")
    if actual <= expected then return end
    error(`Expected '{actual}' <= '{expected}' lexicographically`, 2)
end
--- Assert that Luau's `typeof(value)` matches a type name.
--- @param value unknown Value to inspect.
--- @param expected string Type name, e.g. "string", "number", "table", "buffer".
--- @error Throws when typeof(value) differs from expected.
function m.expect.type(value: unknown, expected: string)
    local actual_type = typeof(value)
    if actual_type == expected then return end
    error(`Expected '{actual_type}' to be type '{expected}'`, 2)
end
--- Assert deep (structural) equality.
--- Values of different types always fail. Non-table values are compared with `==`.
--- Tables are compared recursively by key set and values.
--- @param actual T
--- @param expected T
--- @param strict_mt? boolean When true, the two top-level tables must also share the
---        same metatable (nested tables' metatables are not compared).
--- @error Throws on type mismatch, metatable mismatch (strict_mt), or any structural difference.
function m.expect.deep_equal<T>(actual: T, expected: T, strict_mt: boolean?)
    local actual_type, expected_type = type(actual), type(expected)
    if actual_type ~= expected_type then
        error(`Expected type to be '{expected_type}', got '{actual_type}'`, 2)
    end
    if actual_type ~= "table" then
        if actual ~= expected then error(`Expected {actual} to equal {expected}`, 2) end
        return
    end
    if strict_mt then
        local actual_mt = getmetatable(actual :: any)
        local expected_mt = getmetatable(expected :: any)
        if actual_mt ~= expected_mt then
            error(`Expected metatable {actual_mt} to equal {expected_mt}`, 2)
        end
    end
    if not deep_eq(actual, expected, {}) then
        error(`Expected {actual} to deep equal {expected}`, 2)
    end
end
--- Assert `==` equality, reporting a type mismatch separately from a value mismatch.
--- @param actual T
--- @param expected T
--- @error Throws when the two values have different types or are not `==`.
function m.expect.equal<T>(actual: T, expected: T)
    local actual_type, expected_type = type(actual), type(expected)
    if actual_type ~= expected_type then
        error(`Expected type to be '{expected_type}', got '{actual_type}'`, 2)
    elseif actual ~= expected then
        error(`Expected {actual} to equal {expected}`, 2)
    end
end
--- Assert a value is truthy (neither nil nor false). Extra arguments are ignored,
--- so results of multi-return calls can be passed straight through.
--- @param value unknown
--- @error Throws when value is nil or false.
function m.expect.truthy(value: unknown, ...)
    if value then return end
    error(`Expected {value} value to be truthy`, 2)
end
--- Assert a value is falsy (nil or false). Extra arguments are ignored.
--- @param value unknown
--- @error Throws for any other value, including 0 and "".
function m.expect.falsy(value: unknown, ...)
    if not value then return end
    error(`Expected {value} value to be falsy`, 2)
end
--- Assert a value is not nil (false is accepted). Extra arguments are ignored.
--- @param value unknown
--- @error Throws when value is nil.
function m.expect.not_nil(value: unknown, ...)
    if value ~= nil then return end
    error(`Expected value not to be nil`, 2)
end
--- Assert that calling a function raises an error, optionally with a given substring.
--- The substring is searched as plain text (no pattern magic) in the error message
--- with any leading "path:line:" location removed.
--- @param thunk ()->() Function expected to raise.
--- @param expected_msg? string Substring required in the error message.
--- @error Throws when thunk returns normally, or when the message lacks expected_msg.
function m.expect.throws(thunk: () -> (), expected_msg: string?)
    local succeeded, raised: any? = pcall(thunk)
    if succeeded then
        error("Expected function to throw an error", 2)
    end
    if not expected_msg then return end
    local text = tostring(raised)
    local _, body = split_error(text)
    local haystack = body or text
    if not string.find(haystack, expected_msg, 1, true) then
        error(`Error message {haystack} does not contain expected "{expected_msg}"`, 2)
    end
end
--- Unconditionally fail the current case.
--- @param msg string Reason shown in the report (converted with tostring).
--- @error Always raises, attributed to the caller's line.
function m.expect.fail(msg: string)
    local reason = tostring(msg)
    error(reason, 2)
end
--- Assert that calling a function completes without raising.
--- @param thunk ()->() Function expected to return normally.
--- @error Throws with the formatted original error when thunk raises.
function m.expect.not_throws(thunk: () -> ())
    local succeeded, raised: any? = pcall(thunk)
    if succeeded then return end
    error(`Expected function not to throw, got: {format_error(raised)}`, 2)
end
-----------------------------
-- constructors
-----------------------------
--- Construct an empty test suite.
--- @param name string Name shown in the suite's report lines.
--- @return TestSuite A suite with no tests yet; add them with suite:test(name).
function m.suite(name: string): TestSuite
    local inner: TestSuiteInner = { _name = name, _tests = {} }
    return setmetatable(inner, TestSuite :: TestSuiteImpl)
end
--- Construct an empty test (a named container of cases).
--- @param name string Name shown in the test's report line.
--- @return Test A test with no cases yet; add them with test:case(name, thunk).
local function _test(name: string): Test
    local inner: TestInner = { _name = name, _cases = {} }
    return setmetatable(inner, Test :: TestImpl)
end
-------------------------------
-- methods
-------------------------------
--- Create a new test, register it at the end of this suite, and return it for chaining.
--- @param name string Test name shown in the report.
--- @return Test The newly registered test.
function TestSuite.test(self: TestSuite, name: string): Test
    local new_test = _test(name)
    self._tests[#self._tests + 1] = new_test
    return new_test
end
--- Run all tests in the suite, in the order they were added.
--- @param opt? RunOption "verbose" prints a "[RUNNING]" header, per-case PASS lines, and always
---        prints the summary; any other value prints the summary only when a case failed.
--- @return boolean true when no case failed.
--- @return number Total cases run.
--- @return number Cases that failed.
--- @error In CI mode (m.set_ci_mode(true)), raises after printing the summary if any case failed.
function TestSuite.run(self: TestSuite, opt: RunOption?): (boolean, number, number)
    local verbose = opt == "verbose"
    if verbose then
        local fmt = "[RUNNING] '%*'"
        PRINTER(fmt:format(self._name))
    end
    local count, failed = 0, 0
    for _, test in self._tests do
        local c, f = test:run_cases(self._name, opt)
        count += c
        failed += f
    end
    local fmt = "[SUMMARY] '%*': total: %d, passed: %d, failed: %d"
    local summary = fmt:format(self._name, count, count - failed, failed)
    if failed > 0 then
        PRINTER(summary)
        if CI_MODE then
            -- Raise so the process exits non-zero in CI; the summary is already printed,
            -- so the error message itself is left empty.
            error("", 2)
        end
    elseif verbose then
        PRINTER(summary)
    end
    return failed == 0, count, failed
end
--- Append a named case to this test.
--- @param name string Identifier shown in PASS/FAIL lines.
--- @param thunk ()->() Case body; it fails by raising an error.
--- @return Test self, so calls can be chained.
function Test.case(self: Test, name: string, thunk: () -> ())
    local entry: TestCase = { name = name, thunk = thunk }
    self._cases[#self._cases + 1] = entry
    return self
end
--- Run every case of this test under xpcall and print its report.
--- Prints one ✅/❌ header line. Failing cases are always listed beneath it;
--- passing cases are listed only when opt == "verbose".
--- @param suite_name string Name shown in brackets on the header line.
--- @param opt? RunOption
--- @return number Cases run.
--- @return number Cases failed.
function Test.run_cases(self: Test, suite_name: string, opt: RunOption?)
    local verbose = opt == "verbose"
    local lines: { string } = {}
    local failures = 0
    for _, case in self._cases do
        -- When the case raises, the handler's return value (a formatted FAIL line) becomes `report`.
        local passed, report: string? = xpcall(case.thunk, function(problem)
            return `{INDENT}{CASE_FAIL} {case.name} -- {format_error(problem)}`
        end)
        if not passed then
            failures += 1
            table.insert(lines, report :: string)
        elseif verbose then
            table.insert(lines, `{INDENT}{CASE_PASS} {case.name}`)
        end
    end
    local mark = if failures == 0 then TEST_PASS else TEST_FAIL
    PRINTER(`{mark} [{suite_name}] {self._name}`)
    -- Without failures, `lines` can only hold PASS lines, which exist only in verbose mode.
    if failures > 0 or (verbose and #lines > 0) then
        PRINTER(table.concat(lines, "\n"))
    end
    return #self._cases, failures
end
-----------------------------
-- quick bench submodule
-----------------------------
m.bench = {}
do
    -- Signature of a benchmarked routine: receives the optional argument; the result is discarded.
    type Fn = (any) -> any
    -- Timing tunables; changed via m.bench.configure, read via m.bench.get_config.
    local MIN_BENCH_DURATION = 0.100 -- 100 ms
    local MAX_BENCH_DURATION = 5.0 -- seconds
    local TARGET_REL_STDDEV = 0.02 -- 2%
    -- Every routine result is stored here and read back through _barrier().
    -- NOTE(review): presumably this keeps the results from being treated as dead values.
    local _SINK = nil
    local function _barrier()
        if _SINK == nil then _SINK = false end
        return _SINK
    end
    -- Ordinary Least Squares regression through origin: y = kx
    -- Returns slope k (seconds per iteration)
    local function _ols_slope(iters: { number }, elapsed: { number }): number
        local sum_x2, sum_xy = 0, 0
        for i = 1, #iters do
            local x, y = iters[i], elapsed[i]
            sum_x2 = sum_x2 + x * x
            sum_xy = sum_xy + x * y
        end
        return sum_xy / sum_x2
    end
    -- Time n back-to-back calls of routine(param) using os.clock.
    -- Returns the elapsed seconds and the current sink value.
    local function _sample(routine: Fn, n: number, param: any)
        local start = os.clock()
        for i = 1, n do
            _SINK = routine(param)
        end
        local dt = os.clock() - start
        return dt, _barrier()
    end
    -- Adaptive benchmark. Each round times blocks of n and 2n iterations (doubling n every
    -- round) and fits a line through the origin to that pair of samples. It stops once the fit
    -- residual is below TARGET_REL_STDDEV of the fitted block time and MIN_BENCH_DURATION
    -- has passed, or unconditionally once MAX_BENCH_DURATION has passed.
    local function _bench(fn: Fn, parameter): BenchSummary
        local time_start = os.clock()
        local n = 1
        local total_iter = n
        local t_prev = _sample(fn, n, parameter)
        -- Early exit for extremely long-running benchmarks
        if (os.clock() - time_start) > MAX_BENCH_DURATION then
            -- Synthetic duplicate sample to avoid downstream division issues
            local iters = { n, n }
            local elapsed = { t_prev, t_prev + 1e-6 }
            local per_iter = _ols_slope(iters, elapsed)
            return {
                sec_per_iter = per_iter,
                rel_noise = 0,
                total_iter = total_iter,
                warn = "<timeout>",
            }
        end
        -- Main data collection loop
        while true do
            local t_now = _sample(fn, n * 2, parameter)
            total_iter += (n * 2)
            -- OLS slope through origin from (1,t_prev), (2,t_now): time per block.
            local k = (t_prev + 2.0 * t_now) / 5.0
            local stdev = math.sqrt((t_prev - k) ^ 2 + (t_now - 2.0 * k) ^ 2)
            local elapsed_wall = os.clock() - time_start
            if
                (stdev < TARGET_REL_STDDEV * k and elapsed_wall > MIN_BENCH_DURATION)
                or elapsed_wall > MAX_BENCH_DURATION
            then
                local iters = { n, n * 2 }
                local elapsed = { t_prev, t_now }
                local per_iter = _ols_slope(iters, elapsed)
                -- Residual norm and relative noise
                local rss = 0
                for i = 1, #iters do
                    local pred = per_iter * iters[i]
                    local r = elapsed[i] - pred
                    rss = rss + r * r
                end
                -- A zero slope (both samples measured as 0 s) would otherwise give 0/0 = NaN,
                -- printed as "±nan%"; report zero noise instead.
                local scale = per_iter * iters[1]
                local rel_noise = if scale > 0 then math.sqrt(rss) / scale else 0
                -- stdev above target means the convergence test failed, so the loop
                -- exited because MAX_BENCH_DURATION elapsed.
                local timeout = (stdev > TARGET_REL_STDDEV * k)
                -- NOTE(review): for this two-point fit rel_noise equals stdev / k (up to rounding),
                -- so "<unstable>" can only fire when stdev == TARGET_REL_STDDEV * k exactly.
                local warn = ""
                if timeout then
                    warn = "<timeout>"
                elseif rel_noise >= TARGET_REL_STDDEV then
                    warn = "<unstable>"
                end
                return {
                    sec_per_iter = per_iter,
                    rel_noise = rel_noise,
                    total_iter = total_iter,
                    warn = warn,
                }
            end
            n = n * 2
            t_prev = t_now
        end
    end
    -- Human-readable duration with ns/us/ms/s units and two decimals.
    -- stylua: ignore
    local function _fmt_time(seconds: number): string
        local ns = seconds * 1e9
        if ns < 1e3 then return string.format("%.2f ns", ns) end; ns /= 1000
        --- @note: use 'us', not 'μs' as it may be different width in some terminals
        if ns < 1e3 then return string.format("%.2f us", ns) end; ns /= 1000
        if ns < 1e3 then return string.format("%.2f ms", ns) end; ns /= 1000
        return string.format("%.2f s", ns)
    end
    -- Compact number: positive powers of two get K/M/G (1024-based) suffixes, other
    -- integers print plainly, non-integers use %g.
    -- stylua: ignore
    local function _fmt_num(v: number): string
        if math.floor(v) == v then
            local vlg2 = math.log(v, 2)
            if v > 0 and vlg2 == math.floor(vlg2) then
                if v < 1024 then return string.format("%d ", v) end; v /= 1024
                if v < 1024 then return string.format("%dK", v) end; v /= 1024
                if v < 1024 then return string.format("%dM", v) end; v /= 1024
                return string.format("%dG", v)
            else
                return string.format("%d", v)
            end
        end
        return string.format("%g", v)
    end
    -- Short label for a benchmark argument: "" for nil, truncated quoted strings,
    -- element counts for tables, byte length for buffers, tostring otherwise.
    --stylua: ignore
    local function _fmt_arg(arg: any): string
        local t = typeof(arg)
        if t == "nil" then return "" end
        if t == "string" then return if #arg > 10 then string.format("%q...", arg:sub(1, 7)) else string.format("%q", arg) end
        if t == "number" then return _fmt_num(arg) end
        if t == "table" then
            local c = 0; for _ in arg do c += 1 end
            return if c == #arg then `array[{_fmt_num(c)}]` else `<table:{_fmt_num(c)}>`
        end
        if t == "buffer" then return `<buffer:{_fmt_num(buffer.len(arg))}>` end
        return tostring(arg)
    end
    --- Summary statistics produced by the benchmark.
    --- @field sec_per_iter number Estimated seconds per iteration
    --- @field rel_noise number Approximate relative standard deviation
    --- @field total_iter number Total iterations executed across samples
    --- @field warn string Warning marker: "" when none, else "<timeout>" or "<unstable>"
    export type BenchSummary = {
        sec_per_iter: number,
        rel_noise: number,
        total_iter: number,
        warn: string,
    }
    --- Comparison result for a single routine.
    --- @field tag string Label for this routine
    --- @field summary BenchSummary Benchmark results
    export type CompareResult = {
        tag: string,
        summary: BenchSummary,
    }
    -- One-line "<time>/iter ±<noise>%[ <warn>]" rendering of a summary.
    local function fmt_summary(s: BenchSummary): string
        local warn = if s.warn ~= "" then " " .. s.warn else ""
        return string.format(
            "%9s/iter ±%.2f%%%s",
            _fmt_time(s.sec_per_iter),
            s.rel_noise * 100.0,
            warn
        )
    end
    --- Run a quick benchmark and print a one-line summary.
    --- @param tag string Label for the benchmark row.
    --- @param routine fun(arg:any):any Function under test; returns are discarded.
    --- @param arg? any Optional argument passed to the routine.
    function m.bench.run(tag: string, routine: Fn, arg: any?)
        local summary = _bench(routine, arg)
        local arg_fmt = _fmt_arg(arg)
        local label = if arg_fmt ~= "" then tag .. " " .. arg_fmt else tag
        PRINTER(string.format("%-40s %*", label, fmt_summary(summary)))
    end
    --- Measure a benchmark without printing results.
    --- @param routine fun(arg:any):any Function under test.
    --- @param arg? any Optional argument passed to the routine.
    --- @return BenchSummary Summary of timing results.
    function m.bench.measure(routine: Fn, arg: any?): BenchSummary
        return _bench(routine, arg)
    end
    --- @deprecated alias for backward compatibility
    m.bench.sample = m.bench.measure
    --- Configure benchmark timing parameters. Omitted (nil) arguments keep their current value.
    --- @param min_duration? number Minimum benchmark duration in seconds (default: 0.1).
    --- @param max_duration? number Maximum benchmark duration in seconds (default: 5.0).
    --- @param target_rel_stddev? number Target relative standard deviation (default: 0.02).
    --- @error Throws when a supplied value is not a number or is NaN. All values are
    ---        validated before any is applied, so nothing changes on error.
    function m.bench.configure(min_duration: number?, max_duration: number?, target_rel_stddev: number?)
        -- NaN makes every comparison false: a NaN max duration would never stop a
        -- non-converging benchmark, so reject it up front.
        local function check_setting(value: any, name: string)
            if not value then return end
            if type(value) ~= "number" then
                error(`Expected {name} to be a number, got {type(value)}`, 3)
            end
            if value ~= value then
                error(`Expected {name} to be a number, got NaN`, 3)
            end
        end
        check_setting(min_duration, "min_duration")
        check_setting(max_duration, "max_duration")
        check_setting(target_rel_stddev, "target_rel_stddev")
        if min_duration then MIN_BENCH_DURATION = min_duration end
        if max_duration then MAX_BENCH_DURATION = max_duration end
        if target_rel_stddev then TARGET_REL_STDDEV = target_rel_stddev end
    end
    --- Get current benchmark configuration.
    --- @return number min_duration Minimum benchmark duration in seconds.
    --- @return number max_duration Maximum benchmark duration in seconds.
    --- @return number target_rel_stddev Target relative standard deviation.
    function m.bench.get_config(): (number, number, number)
        return MIN_BENCH_DURATION, MAX_BENCH_DURATION, TARGET_REL_STDDEV
    end
    --- Measure and compare multiple routines without printing.
    --- @param arg? any Argument passed to all routines.
    --- @param routines {[string]: Fn} Table mapping tags to routines.
    --- @return {CompareResult} Array of results sorted by speed (fastest first).
    --- @error Throws when routines is not a table or any entry is not a function.
    function m.bench.measure_compare(arg: any?, routines: { [string]: Fn }): { CompareResult }
        -- Validate input
        if type(routines) ~= "table" then
            error(`Expected routines to be a table, got {type(routines)}`, 2)
        end
        -- Collect benchmark results
        local results: { CompareResult } = {}
        for tag, routine in routines do
            if type(routine) ~= "function" then
                error(`Expected routine for "{tag}" to be a function, got {type(routine)}`, 2)
            end
            local summary = _bench(routine, arg)
            table.insert(results, { tag = tag, summary = summary })
        end
        -- Sort by speed (fastest first)
        table.sort(results, function(a, b) return a.summary.sec_per_iter < b.summary.sec_per_iter end)
        return results
    end
    --- Compare multiple routines with the same argument and print results.
    --- @param arg? any Argument passed to all routines.
    --- @param routines {[string]: Fn} Table mapping tags to routines.
    function m.bench.compare(arg: any?, routines: { [string]: Fn })
        local results = m.bench.measure_compare(arg, routines)
        -- Handle edge cases
        if #results == 0 then
            PRINTER("No routines to compare")
            return
        end
        if #results == 1 then
            local result = results[1]
            local arg_fmt = _fmt_arg(arg)
            local label = if arg_fmt ~= "" then result.tag .. " " .. arg_fmt else result.tag
            PRINTER(string.format("%-40s %*", label, fmt_summary(result.summary)))
            return
        end
        -- Print header
        local arg_fmt = _fmt_arg(arg)
        local header = if arg_fmt ~= "" then `Comparing with {arg_fmt}:` else "Comparing:"
        PRINTER(header)
        -- Print results with comparison
        local baseline = results[1].summary.sec_per_iter
        for i, result in results do
            local tag = result.tag
            local summary = result.summary
            local time_str = fmt_summary(summary)
            if i == 1 then
                -- Fastest is baseline
                PRINTER(string.format(" %-38s %* [baseline]", tag, time_str))
            else
                -- Show relative performance; a zero baseline would make the ratio inf/NaN.
                local rel_str = if baseline > 0
                    then string.format("(%.2fx slower)", summary.sec_per_iter / baseline)
                    else "(n/a)"
                PRINTER(string.format(" %-38s %* %s", tag, time_str, rel_str))
            end
        end
    end
end
--[[ quick demo
-- benchmark demo
local function fib(n: number): number
if n <= 1 then return n end
return fib(n - 1) + fib(n - 2)
end
for arg = 20, 30, 2 do
m.bench.run("fibonacci of", fib, arg)
end
-- test demo
local suite = m.suite "example test suite"
suite
:test("example test")
:case("#1", function() m.expect.truthy "A" end)
:case("#2", function() m.expect.equal(false, true) end) -- will fail
:case("#3", function() m.expect.near_equal(1, 2, 0.1) end) -- will fail
suite
:test("another example test")
:case("#11", function() m.expect.deep_equal({ "A" }, { "A" }) end)
:case("#12", function() m.expect.type("A", "string") end)
:case("#13", function() assert(nil, "Expect not to be nil") end) -- plain assert is ok
suite
:test("single test example")
:case("#21", function()
m.expect.throws(function() error "out of range" end, "range")
end)
:case("#22", function() m.expect.not_nil(suite) end)
suite
:test("extras: fail & not_throws")
:case("#31 fail throws", function()
m.expect.throws(function() m.expect.fail "boom" end, "boom")
end)
:case("#32 not_throws ok", function()
m.expect.not_throws(function() end)
end)
:case("#33 not_throws fails", function()
m.expect.throws(function()
m.expect.not_throws(function() error "kaboom" end)
end, "not to throw")
end)
m.set_ci_mode(false) -- true to set CI mode
suite:run "normal"
--]]
-- stylua: ignore end
-- Public module table: m.expect (assertions), m.bench (benchmarks), m.suite, and the
-- printer / CI-mode setters (set_printer, reset_printer, set_ci_mode).
return m
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment