Last active
October 12, 2025 19:23
-
-
Save zoon/4cfc4081981d7bc3d1307cc8f1225dd4 to your computer and use it in GitHub Desktop.
A simple unit testing harness for Luau.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--!strict | |
--!optimize 2 | |
-- Copyright (c) 2025 Andrei Zhilin https://github.com/zoon | |
-- Licensed under the MIT License. | |
--[=[ | |
A simple unit testing harness for Luau. | |
Usage: | |
```luau | |
local m = require("./testit") | |
local suite = m.suite("example test suite") | |
suite | |
:test("example test") | |
:case("#1", function() m.expect.truthy "A" end) | |
:case("#2", function() m.expect.equal(false, true) end) -- will fail | |
:case("#3", function() m.expect.near_equal(1, 2, 0.1) end) -- will fail | |
suite | |
:test("another example test") | |
:case("#11", function() m.expect.deep_equal({ "A" }, { "A" }) end) | |
:case("#12", function() m.expect.type("A", "string") end) | |
:case("#13", function() assert(nil, "Expect not to be nil") end) -- plain assert() is ok | |
suite:run "" | |
local single = m.suite("single test suite")
single
    :test("single test example")
    :case("#21", function()
        m.expect.throws(function() error "out of range" end, "range")
    end)
    :case("#22", function() m.expect.not_nil(suite) end)
single:run "verbose"
``` | |
Cross platform harness ideas: | |
$ fd --follow -t file -g '*.spec.lua?' -x luau.exe -g2 -O1
$ elvish -c "luau -g2 -O1 **.spec.lua?" | |
--]=] | |
-- stylua: ignore start | |
-- Result markers: CASE_* label individual cases, TEST_* prefix a whole test's line.
local CASE_PASS = "PASS"
local CASE_FAIL = "FAIL"
local TEST_PASS = "✅"
local TEST_FAIL = "❌"
-- Layout of per-case report lines and of the "path:line" connector under a failure.
local INDENT = " "
local CONNECTOR = " ╰─ "
-- When true, TestSuite.run raises after reporting a failing suite, so the host
-- process exits with a non-zero status (intended for CI).
local CI_MODE = false
-- Sink for every line the harness reports; the global `print` by default.
local PRINTER: (msg: string) -> () = print

local m = {}

--- Turn CI mode on or off.
--- @param ci boolean true = raise when a suite has failures.
function m.set_ci_mode(ci: boolean)
	CI_MODE = ci
end

--- Send all harness output to `printer` instead of `print`.
--- @param printer (msg: string) -> () Receives one report string per call.
function m.set_printer(printer: (msg: string) -> ())
	PRINTER = printer
end

--- Restore the default output sink (`print`).
function m.reset_printer()
	PRINTER = print
end
-- "verbose" prints per-case PASS lines and summaries; anything else is terse.
export type RunOption = "verbose" | "normal" | ""

-- One named case; `thunk` signals failure by raising an error.
type TestCase = {
	name: string,
	thunk: () -> (),
}

-- Per-instance state of a Test: its name and ordered list of cases.
type TestInner = {
	_name: string,
	_cases: { TestCase },
}

-- Method table shared by every Test through its metatable.
type TestImpl = {
	__index: TestImpl,
	case: (self: Test, name: string, thunk: () -> ()) -> Test,
	run_cases: (self: Test, suite_name: string, opt: RunOption?) -> (number, number),
}

-- Per-instance state of a TestSuite: its name and ordered list of tests.
type TestSuiteInner = {
	_name: string,
	_tests: { Test },
}

-- Method table shared by every TestSuite through its metatable.
type TestSuiteImpl = {
	__index: TestSuiteImpl,
	test: (self: TestSuite, name: string) -> Test,
	run: (self: TestSuite, opt: RunOption?) -> (boolean, number, number),
}

export type Test = typeof(setmetatable({} :: TestInner, {} :: TestImpl))
export type TestSuite = typeof(setmetatable({} :: TestSuiteInner, {} :: TestSuiteImpl))

-- Each method table is its own __index, so instances resolve methods through it.
local Test = {} :: TestImpl
Test.__index = Test

local TestSuite = {} :: TestSuiteImpl
TestSuite.__index = TestSuite
----------------------------- | |
-- util | |
----------------------------- | |
--- Split a Lua error string of the form "<path>:<line>: <message>".
--- @param err string Raw error message.
--- @return string? "<path>:<line>" prefix, or nil when the string has no location.
--- @return string? The message with leading/trailing whitespace removed, or nil.
local function split_error(err: string): (string?, string?)
	-- The path is matched lazily so it stops at the first ":<digits>:" boundary.
	-- The message capture must also be lazy: a greedy `(.*)` swallows trailing
	-- whitespace and leaves the final `%s*` with nothing to strip.
	return err:match "^(.-:%d+):%s*(.-)%s*$"
end
--- Render an error value for a report line.
--- A string carrying a "<path>:<line>:" prefix becomes the bare message followed
--- by the location on an indented connector line. Other strings are returned
--- unchanged, and non-strings go through `tostring`.
--- @param err any Value raised by `error`.
--- @return string
local function format_error(err: any): string
	if type(err) == "string" then
		local location, message = split_error(err)
		if location == nil or message == nil then
			return err
		end
		return message .. `\n{INDENT}{CONNECTOR}{location}`
	end
	return tostring(err)
end
--- Recursive structural equality used by m.expect.deep_equal.
--- Identical values, and non-table values, compare with `==`. Two tables are equal
--- when they have the same keys and deep-equal values under each key. Values in `b`
--- are read with `rawget`, so an `__index` fallback on `b` cannot stand in for a key
--- that `a` has and `b` lacks. Metatables themselves are not compared.
--- @param a any
--- @param b any
--- @param visiting table Scratch map of table pairs already under comparison;
---        callers pass a fresh `{}`.
--- @return boolean
local function deep_eq(a: any, b: any, visiting: { [any]: { [any]: boolean } }): boolean
	if a == b then return true end
	local ta, tb = type(a), type(b)
	if ta ~= tb or ta ~= "table" then return false end
	-- Remember (a, b) pairs rather than one partner per `a`. The same table may
	-- appear several times in `a` while `b` holds distinct but equal copies;
	-- keying on `a` alone wrongly reports those as unequal, and the result then
	-- depends on argument order. Re-entering a pair means we are following a
	-- cycle, so treat it as equal and let the remaining comparisons decide.
	local partners = visiting[a]
	if partners == nil then
		partners = {}
		visiting[a] = partners
	elseif partners[b] then
		return true
	end
	partners[b] = true
	-- Every key of `a` must be present in `b` with a deep-equal value.
	local keys_seen = {}
	for key, value_a in a do
		keys_seen[key] = true
		if not deep_eq(value_a, rawget(b, key), visiting) then return false end
	end
	-- `b` must not carry keys that `a` lacks.
	for key in b do
		if not keys_seen[key] then return false end
	end
	return true
end
--- Argument guard: raise unless `s` is a string.
--- The error is reported at level 3, i.e. at the caller of the function that
--- invoked this guard.
--- @param s any Value to validate.
--- @param name string Argument name used in the error message.
local function check_string(s: any, name: string)
	local actual_type = type(s)
	if actual_type == "string" then
		return
	end
	error(`Expected {name} to be a string, got {actual_type}`, 3)
end
--- Argument guard: raise unless `n` is a finite number.
--- NaN, +inf and -inf are all rejected. The error is reported at level 3, i.e. at
--- the caller of the function that invoked this guard.
--- @param n any Value to validate.
--- @param name string Argument name used in the error message.
local function check_number(n: any, name: string)
	local problem: string? = nil
	if type(n) ~= "number" then
		problem = type(n)
	elseif n ~= n then
		problem = "NaN"
	elseif n == math.huge then
		problem = "+inf"
	elseif n == -math.huge then
		problem = "-inf"
	end
	if problem ~= nil then
		error(`Expected {name} to be a number, got {problem}`, 3)
	end
end
----------------------------- | |
-- assertions | |
----------------------------- | |
m.expect = {}

--- Assert two numbers are approximately equal within an absolute tolerance.
--- @param actual number Measured value.
--- @param expected number Expected value.
--- @param tolerance? number Absolute tolerance (its sign is ignored); 1e-6 when nil.
--- @error Throws when |actual - expected| > tolerance, or when actual, expected or a
---        supplied tolerance is not a finite number.
function m.expect.near_equal(actual: number, expected: number, tolerance: number?)
	-- Only a missing tolerance falls back to the default. A supplied non-number is
	-- reported by check_number instead of being silently replaced by 1e-6.
	local tol: number = if tolerance == nil then 1e-6 else tolerance
	check_number(tol, "tol")
	check_number(actual, "actual")
	check_number(expected, "expected")
	tol = math.abs(tol)
	local delta = math.abs(expected - actual)
	if delta > tol then
		-- Relative error against |expected|, floored to avoid dividing by zero.
		local denom = math.max(math.abs(expected), 1e-12)
		local ratio = delta / denom
		error(`Expected {actual} to be near {expected} (Δ={delta}, ε={ratio}, τ={tol})`, 2)
	end
end
--- Assert that a number has no fractional part.
--- @param n number Value to check; it must also pass check_number (finite, not NaN).
--- @error Throws when n is not a finite number or is not integral.
function m.expect.integer(n: number)
	check_number(n, "n")
	local is_integral = n == math.floor(n)
	if not is_integral then
		error(`Expected {n} to be an integer`, 2)
	end
end
--- Assert that `actual` is strictly below `expected`.
--- @param actual number Value under test.
--- @param expected number Exclusive upper bound.
--- @error Throws when actual >= expected, or when either argument is not a finite number.
function m.expect.less(actual: number, expected: number)
	check_number(actual, "actual")
	check_number(expected, "expected")
	-- NaN is already rejected above, so `not (a < b)` is exactly `a >= b`.
	if not (actual < expected) then
		error(`Expected {actual} < {expected}`, 2)
	end
end
--- Assert that `actual` does not exceed `expected`.
--- @param actual number Value under test.
--- @param expected number Inclusive upper bound.
--- @error Throws when actual > expected, or when either argument is not a finite number.
function m.expect.less_or_equal(actual: number, expected: number)
	check_number(actual, "actual")
	check_number(expected, "expected")
	if expected < actual then
		error(`Expected {actual} <= {expected}`, 2)
	end
end
--- Assert that string `actual` sorts strictly before `expected` (Lua string `<`).
--- @param actual string Value under test.
--- @param expected string Exclusive upper bound.
--- @error Throws when actual >= expected, or when either argument is not a string.
function m.expect.lex_less(actual: string, expected: string)
	check_string(actual, "actual")
	check_string(expected, "expected")
	if not (actual < expected) then
		error(`Expected '{actual}' < '{expected}' lexicographically`, 2)
	end
end
--- Assert that string `actual` does not sort after `expected` (Lua string `<`).
--- @param actual string Value under test.
--- @param expected string Inclusive upper bound.
--- @error Throws when actual > expected, or when either argument is not a string.
function m.expect.lex_less_or_equal(actual: string, expected: string)
	check_string(actual, "actual")
	check_string(expected, "expected")
	if expected < actual then
		error(`Expected '{actual}' <= '{expected}' lexicographically`, 2)
	end
end
--- Assert that `typeof(value)` is the given type name.
--- @param value unknown Value to inspect.
--- @param expected string Type name such as "string", "number", "table" or "buffer".
--- @error Throws when typeof(value) differs from `expected`.
function m.expect.type(value: unknown, expected: string)
	local actual = typeof(value)
	if actual == expected then
		return
	end
	error(`Expected '{actual}' to be type '{expected}'`, 2)
end
--- Assert deep (structural) equality.
--- Tables are compared recursively with deep_eq; other values are compared with `==`.
--- @param actual T Value under test.
--- @param expected T Reference value.
--- @param strict_mt? boolean When true, the top-level metatables (as returned by
---        getmetatable) must also be the same object. Metatables of nested tables
---        are not checked.
--- @error Throws on a type mismatch, a metatable mismatch, or unequal contents.
function m.expect.deep_equal<T>(actual: T, expected: T, strict_mt: boolean?)
	local actual_type = type(actual)
	local expected_type = type(expected)
	if actual_type ~= expected_type then
		error(`Expected type to be '{expected_type}', got '{actual_type}'`, 2)
	end
	if actual_type ~= "table" then
		if actual ~= expected then
			error(`Expected {actual} to equal {expected}`, 2)
		end
		return
	end
	if strict_mt then
		local actual_mt = getmetatable(actual :: any)
		local expected_mt = getmetatable(expected :: any)
		if actual_mt ~= expected_mt then
			error(`Expected metatable {actual_mt} to equal {expected_mt}`, 2)
		end
	end
	if not deep_eq(actual, expected, {}) then
		error(`Expected {actual} to deep equal {expected}`, 2)
	end
end
--- Assert `actual == expected`, reporting a type mismatch separately.
--- @param actual unknown Value under test.
--- @param expected unknown Reference value.
--- @error Throws when the Lua types differ or the values are not `==`.
function m.expect.equal<T>(actual: T, expected: T)
	local actual_type, expected_type = type(actual), type(expected)
	if actual_type ~= expected_type then
		error(`Expected type to be '{expected_type}', got '{actual_type}'`, 2)
	elseif actual ~= expected then
		error(`Expected {actual} to equal {expected}`, 2)
	end
end
--- Assert that `value` is neither nil nor false. Extra arguments are accepted and
--- ignored, so the results of e.g. `string.find` can be passed straight through.
--- @param value unknown
--- @error Throws when value is nil or false.
function m.expect.truthy(value: unknown, ...)
	if value then
		return
	end
	error(`Expected {value} value to be truthy`, 2)
end
--- Assert that `value` is nil or false. Extra arguments are accepted and ignored.
--- @param value unknown
--- @error Throws when value is anything other than nil or false.
function m.expect.falsy(value: unknown, ...)
	if not value then
		return
	end
	error(`Expected {value} value to be falsy`, 2)
end
--- Assert that `value` is not nil (false is accepted). Extra arguments are ignored.
--- @param value unknown
--- @error Throws when value is nil.
function m.expect.not_nil(value: unknown, ...)
	if value ~= nil then
		return
	end
	error("Expected value not to be nil", 2)
end
--- Assert that calling `thunk` raises an error.
--- When `expected_msg` is given, the error (converted with tostring, with any
--- "<path>:<line>:" prefix removed) must contain it as a plain substring.
--- @param thunk ()->() Function expected to raise.
--- @param expected_msg? string Substring required in the error message.
--- @error Throws when thunk returns normally, or its error lacks expected_msg.
function m.expect.throws(thunk: () -> (), expected_msg: string?)
	local succeeded, raised: any? = pcall(thunk)
	if succeeded then
		error("Expected function to throw an error", 2)
	end
	if not expected_msg then
		return
	end
	local text = tostring(raised)
	local _, stripped = split_error(text)
	local haystack = stripped or text
	-- Plain find: expected_msg is literal text, not a Lua pattern.
	if string.find(haystack, expected_msg, 1, true) == nil then
		error(`Error message {haystack} does not contain expected "{expected_msg}"`, 2)
	end
end
--- Fail the current case unconditionally.
--- @param msg string Reason for the failure; non-strings are passed through tostring.
--- @error Always throws, reported at the caller's position.
function m.expect.fail(msg: string)
	local reason = tostring(msg)
	error(reason, 2)
end
--- Assert that calling `thunk` completes without raising.
--- @param thunk ()->() Function expected to return normally.
--- @error Throws, quoting the formatted original error, when thunk raises.
function m.expect.not_throws(thunk: () -> ())
	local ok, err: any? = pcall(thunk)
	if ok then
		return
	end
	error(`Expected function not to throw, got: {format_error(err)}`, 2)
end
----------------------------- | |
-- constructors | |
----------------------------- | |
--- Create an empty test suite.
--- @param name string Suite name used in reports.
--- @return TestSuite Fresh suite with no tests.
function m.suite(name: string): TestSuite
	local inner: TestSuiteInner = { _name = name, _tests = {} }
	return setmetatable(inner, TestSuite :: TestSuiteImpl)
end

--- Create an empty Test (a named list of cases). Used by TestSuite.test.
--- @param name string Test name used in reports.
--- @return Test Fresh test with no cases.
local function _test(name: string): Test
	local inner: TestInner = { _name = name, _cases = {} }
	return setmetatable(inner, Test :: TestImpl)
end
------------------------------- | |
-- methods | |
------------------------------- | |
--- Register a new, empty test in this suite.
--- @param name string Test name used in reports.
--- @return Test The new test, ready for chained `:case(...)` calls.
function TestSuite.test(self: TestSuite, name: string): Test
	local new_test = _test(name)
	local tests = self._tests
	tests[#tests + 1] = new_test
	return new_test
end
--- Run every test in the suite, in registration order, and report the results.
--- Output:
---   * a "[RUNNING]" header, in verbose mode only;
---   * one result line per test (see Test.run_cases);
---   * a "[SUMMARY]" line, always when something failed, otherwise only in verbose mode.
--- @param opt? RunOption "verbose" also prints per-case PASS lines and the summary.
--- @return boolean true when no case failed.
--- @return number Number of cases run.
--- @return number Number of cases that failed.
--- @error In CI mode (m.set_ci_mode(true)), raises after printing the summary when
---        any case failed, so the host process exits with a non-zero status.
function TestSuite.run(self: TestSuite, opt: RunOption?)
	local verbose = opt == "verbose"
	if verbose then
		PRINTER(string.format("[RUNNING] '%*'", self._name))
	end
	local count, failed = 0, 0
	for _, test in self._tests do
		local c, f = Test.run_cases(test, self._name, opt)
		count += c
		failed += f
	end
	local summary = string.format(
		"[SUMMARY] '%*': total: %d, passed: %d, failed: %d",
		self._name,
		count,
		count - failed,
		failed
	)
	if failed > 0 then
		PRINTER(summary)
		if CI_MODE then
			-- Non-zero exit in CI. A descriptive message tells the CI log which suite
			-- failed, rather than a bare "path:line:" location.
			error(`test suite '{self._name}' failed: {failed} of {count} cases`, 2)
		end
	elseif verbose then
		PRINTER(summary)
	end
	return failed == 0, count, failed
end
--- Append a case to this test; Test.run_cases iterates cases in insertion order.
--- @param name string Case label shown in reports.
--- @param thunk ()->() Case body; raising an error marks the case as failed.
--- @return Test The same test, so calls can be chained.
function Test.case(self: Test, name: string, thunk: () -> ())
	local cases = self._cases
	cases[#cases + 1] = { name = name, thunk = thunk }
	return self
end
--- Run every case of this test and print the outcome.
--- Always prints one header line (✅ or ❌, the suite name in brackets, the test
--- name). When a case failed, every collected line follows: FAIL lines, plus PASS
--- lines in verbose mode. When all passed, PASS lines are printed in verbose mode only.
--- @param suite_name string Label shown in brackets on the header line.
--- @param opt? RunOption "verbose" also collects PASS lines.
--- @return number Number of cases run.
--- @return number Number of cases that failed.
function Test.run_cases(self: Test, suite_name: string, opt: RunOption?)
	local verbose = opt == "verbose"
	local lines = {}
	local failures = 0
	for _, case in self._cases do
		-- The handler turns the raised value into a ready-to-print FAIL line.
		local ok, report: string? = xpcall(case.thunk, function(e)
			return `{INDENT}{CASE_FAIL} {case.name} -- {format_error(e)}`
		end)
		if not ok then
			failures += 1
			table.insert(lines, report :: string)
		elseif verbose then
			table.insert(lines, `{INDENT}{CASE_PASS} {case.name}`)
		end
	end
	if failures > 0 then
		PRINTER(`{TEST_FAIL} [{suite_name}] {self._name}`)
		PRINTER(table.concat(lines, "\n"))
	else
		PRINTER(`{TEST_PASS} [{suite_name}] {self._name}`)
		if verbose and #lines > 0 then
			PRINTER(table.concat(lines, "\n"))
		end
	end
	return #self._cases, failures
end
----------------------------- | |
-- quick bench submodule | |
----------------------------- | |
m.bench = {}
do
	-- Benchmarked routines take one argument and may return anything.
	type Fn = (any) -> any

	-- Stopping rule for _bench. Stop once the two-point fit's residual is below
	-- TARGET_REL_STDDEV of the fitted block time AND at least MIN_BENCH_DURATION
	-- seconds have elapsed. Stop unconditionally after MAX_BENCH_DURATION seconds.
	-- All three can be changed at runtime via m.bench.configure.
	local MIN_BENCH_DURATION = 0.100 -- 100 ms
	local MAX_BENCH_DURATION = 5.0 -- seconds
	local TARGET_REL_STDDEV = 0.02 -- 2%

	-- Each routine result is stored here and read back through _barrier. The
	-- intent is to keep the measured calls observable so they are not elided.
	local _SINK: any = nil
	local function _barrier()
		if _SINK == nil then _SINK = false end
		return _SINK
	end

	-- Ordinary least squares regression through the origin: y = k*x.
	-- With x = iterations and y = seconds, the slope k is seconds per iteration.
	local function _ols_slope(iters: { number }, elapsed: { number }): number
		local sum_x2, sum_xy = 0, 0
		for i = 1, #iters do
			local x, y = iters[i], elapsed[i]
			sum_x2 += x * x
			sum_xy += x * y
		end
		return sum_xy / sum_x2
	end

	-- Time `n` back-to-back calls of routine(param) with os.clock.
	-- Returns elapsed seconds, plus the sink value (ignored by callers).
	local function _sample(routine: Fn, n: number, param: any)
		local start = os.clock()
		for _ = 1, n do
			_SINK = routine(param)
		end
		local dt = os.clock() - start
		return dt, _barrier()
	end

	-- Adaptive measurement. Time a block of n iterations and a block of 2n, fit a
	-- line through the origin, and double n until the stopping rule is met.
	local function _bench(fn: Fn, parameter: any): BenchSummary
		local time_start = os.clock()
		local n = 1
		local total_iter = n
		local t_prev = _sample(fn, n, parameter)
		-- A single call already used up the time budget: report that one sample.
		if (os.clock() - time_start) > MAX_BENCH_DURATION then
			return {
				sec_per_iter = t_prev / n,
				rel_noise = 0,
				total_iter = total_iter,
				warn = "<timeout>",
			}
		end
		-- Main data collection loop
		while true do
			local t_now = _sample(fn, n * 2, parameter)
			total_iter += n * 2
			-- OLS slope through origin from (1, t_prev), (2, t_now): time per block of n.
			local k = (t_prev + 2.0 * t_now) / 5.0
			local stdev = math.sqrt((t_prev - k) ^ 2 + (t_now - 2.0 * k) ^ 2)
			local elapsed_wall = os.clock() - time_start
			local stable = stdev < TARGET_REL_STDDEV * k
			if (stable and elapsed_wall > MIN_BENCH_DURATION) or elapsed_wall > MAX_BENCH_DURATION then
				local iters = { n, n * 2 }
				local elapsed = { t_prev, t_now }
				local per_iter = _ols_slope(iters, elapsed)
				-- Residual norm and relative noise
				local rss = 0
				for i = 1, #iters do
					local r = elapsed[i] - per_iter * iters[i]
					rss += r * r
				end
				-- If both samples fell below clock resolution, per_iter is 0. Report
				-- zero noise rather than 0/0 = NaN.
				local scale = per_iter * iters[1]
				local rel_noise = if scale > 0 then math.sqrt(rss) / scale else 0
				local warn = ""
				if not stable then
					-- The loop ended only because the time budget ran out.
					warn = "<timeout>"
				elseif rel_noise >= TARGET_REL_STDDEV then
					warn = "<unstable>"
				end
				return {
					sec_per_iter = per_iter,
					rel_noise = rel_noise,
					total_iter = total_iter,
					warn = warn,
				}
			end
			n *= 2
			t_prev = t_now
		end
	end

	-- stylua: ignore
	local function _fmt_time(seconds: number): string
		local ns = seconds * 1e9
		if ns < 1e3 then return string.format("%.2f ns", ns) end; ns /= 1000
		--- @note: use 'us', not 'μs' as it may be different width in some terminals
		if ns < 1e3 then return string.format("%.2f us", ns) end; ns /= 1000
		if ns < 1e3 then return string.format("%.2f ms", ns) end; ns /= 1000
		return string.format("%.2f s", ns)
	end

	-- Format a number for a label: exact positive powers of two get a binary
	-- suffix (K = 1024, M, G), other integers print with %d, everything else with %g.
	-- stylua: ignore
	local function _fmt_num(v: number): string
		-- inf/-inf/NaN have no integer representation; never hand them to %d.
		if v ~= v or v == math.huge or v == -math.huge then return string.format("%g", v) end
		if math.floor(v) == v then
			local vlg2 = math.log(v, 2)
			if v > 0 and vlg2 == math.floor(vlg2) then
				if v < 1024 then return string.format("%d ", v) end; v /= 1024
				if v < 1024 then return string.format("%dK", v) end; v /= 1024
				if v < 1024 then return string.format("%dM", v) end; v /= 1024
				return string.format("%dG", v)
			else
				return string.format("%d", v)
			end
		end
		return string.format("%g", v)
	end

	-- Short description of a benchmark argument for report labels.
	--stylua: ignore
	local function _fmt_arg(arg: any): string
		local t = typeof(arg)
		if t == "nil" then return "" end
		if t == "string" then return if #arg > 10 then string.format("%q...", arg:sub(1, 7)) else string.format("%q", arg) end
		if t == "number" then return _fmt_num(arg) end
		if t == "table" then
			local c = 0; for _ in arg do c += 1 end
			return if c == #arg then `array[{_fmt_num(c)}]` else `<table:{_fmt_num(c)}>`
		end
		if t == "buffer" then return `<buffer:{_fmt_num(buffer.len(arg))}>` end
		return tostring(arg)
	end

	--- Summary statistics produced by the benchmark.
	--- @field sec_per_iter number Estimated seconds per iteration
	--- @field rel_noise number Approximate relative standard deviation
	--- @field total_iter number Total iterations executed across samples
	--- @field warn string "" when clean; "<timeout>" or "<unstable>" otherwise
	export type BenchSummary = {
		sec_per_iter: number,
		rel_noise: number,
		total_iter: number,
		warn: string,
	}

	--- Comparison result for a single routine.
	--- @field tag string Label for this routine
	--- @field summary BenchSummary Benchmark results
	export type CompareResult = {
		tag: string,
		summary: BenchSummary,
	}

	-- "<time>/iter ±<noise>%" plus the warning marker, if any.
	local function fmt_summary(s: BenchSummary): string
		local warn = if s.warn ~= "" then " " .. s.warn else ""
		return string.format(
			"%9s/iter ±%.2f%%%s",
			_fmt_time(s.sec_per_iter),
			s.rel_noise * 100.0,
			warn
		)
	end

	--- Run a quick benchmark and print a one-line summary.
	--- @param tag string Label for the benchmark row.
	--- @param routine fun(arg:any):any Function under test; returns are discarded.
	--- @param arg? any Optional argument passed to the routine.
	function m.bench.run(tag: string, routine: Fn, arg: any?)
		local summary = _bench(routine, arg)
		local arg_fmt = _fmt_arg(arg)
		local label = if arg_fmt ~= "" then tag .. " " .. arg_fmt else tag
		PRINTER(string.format("%-40s %*", label, fmt_summary(summary)))
	end

	--- Measure a benchmark without printing results.
	--- @param routine fun(arg:any):any Function under test.
	--- @param arg? any Optional argument passed to the routine.
	--- @return BenchSummary Summary of timing results.
	function m.bench.measure(routine: Fn, arg: any?): BenchSummary
		return _bench(routine, arg)
	end

	--- @deprecated alias for backward compatibility
	m.bench.sample = m.bench.measure

	--- Configure benchmark timing parameters. A nil argument leaves that setting unchanged.
	--- @param min_duration? number Minimum benchmark duration in seconds (default: 0.1).
	--- @param max_duration? number Maximum benchmark duration in seconds (default: 5.0).
	--- @param target_rel_stddev? number Target relative standard deviation (default: 0.02).
	function m.bench.configure(min_duration: number?, max_duration: number?, target_rel_stddev: number?)
		if min_duration then MIN_BENCH_DURATION = min_duration end
		if max_duration then MAX_BENCH_DURATION = max_duration end
		if target_rel_stddev then TARGET_REL_STDDEV = target_rel_stddev end
	end

	--- Get current benchmark configuration.
	--- @return number min_duration Minimum benchmark duration in seconds.
	--- @return number max_duration Maximum benchmark duration in seconds.
	--- @return number target_rel_stddev Target relative standard deviation.
	function m.bench.get_config(): (number, number, number)
		return MIN_BENCH_DURATION, MAX_BENCH_DURATION, TARGET_REL_STDDEV
	end

	--- Measure and compare multiple routines without printing.
	--- @param arg? any Argument passed to all routines.
	--- @param routines {[string]: Fn} Table mapping tags to routines.
	--- @return {CompareResult} Array of results sorted by speed (fastest first).
	--- @error Throws when routines is not a table or any value is not a function.
	function m.bench.measure_compare(arg: any?, routines: { [string]: Fn }): { CompareResult }
		if type(routines) ~= "table" then
			error(`Expected routines to be a table, got {type(routines)}`, 2)
		end
		local results: { CompareResult } = {}
		for tag, routine in routines do
			if type(routine) ~= "function" then
				error(`Expected routine for "{tag}" to be a function, got {type(routine)}`, 2)
			end
			local summary = _bench(routine, arg)
			table.insert(results, { tag = tag, summary = summary })
		end
		-- Sort by speed (fastest first)
		table.sort(results, function(a, b) return a.summary.sec_per_iter < b.summary.sec_per_iter end)
		return results
	end

	--- Compare multiple routines with the same argument and print results.
	--- The fastest routine is the baseline; the others show how many times slower they are.
	--- @param arg? any Argument passed to all routines.
	--- @param routines {[string]: Fn} Table mapping tags to routines.
	function m.bench.compare(arg: any?, routines: { [string]: Fn })
		local results = m.bench.measure_compare(arg, routines)
		if #results == 0 then
			PRINTER("No routines to compare")
			return
		end
		if #results == 1 then
			local result = results[1]
			local arg_fmt = _fmt_arg(arg)
			local label = if arg_fmt ~= "" then result.tag .. " " .. arg_fmt else result.tag
			PRINTER(string.format("%-40s %*", label, fmt_summary(result.summary)))
			return
		end
		local arg_fmt = _fmt_arg(arg)
		local header = if arg_fmt ~= "" then `Comparing with {arg_fmt}:` else "Comparing:"
		PRINTER(header)
		local baseline = results[1].summary.sec_per_iter
		for i, result in results do
			local tag = result.tag
			local summary = result.summary
			local time_str = fmt_summary(summary)
			if i == 1 then
				-- Fastest is baseline
				PRINTER(string.format(" %-38s %* [baseline]", tag, time_str))
			else
				-- A zero baseline (sub-resolution timing) would print an inf/nan ratio.
				local rel_str = if baseline > 0
					then string.format("(%.2fx slower)", summary.sec_per_iter / baseline)
					else "(n/a)"
				PRINTER(string.format(" %-38s %* %s", tag, time_str, rel_str))
			end
		end
	end
end
--[[ quick demo | |
-- benchmark demo | |
local function fib(n: number): number | |
if n <= 1 then return n end | |
return fib(n - 1) + fib(n - 2) | |
end | |
for arg = 20, 30, 2 do | |
m.bench.run("fibonacci of", fib, arg)
end | |
-- test demo | |
local suite = m.suite "example test suite" | |
suite | |
:test("example test") | |
:case("#1", function() m.expect.truthy "A" end) | |
:case("#2", function() m.expect.equal(false, true) end) -- will fail | |
:case("#3", function() m.expect.near_equal(1, 2, 0.1) end) -- will fail | |
suite | |
:test("another example test") | |
:case("#11", function() m.expect.deep_equal({ "A" }, { "A" }) end) | |
:case("#12", function() m.expect.type("A", "string") end) | |
:case("#13", function() assert(nil, "Expect not to be nil") end) -- plain assert is ok | |
suite | |
:test("single test example") | |
:case("#21", function() | |
m.expect.throws(function() error "out of range" end, "range") | |
end) | |
:case("#22", function() m.expect.not_nil(suite) end) | |
suite | |
:test("extras: fail & not_throws") | |
:case("#31 fail throws", function() | |
m.expect.throws(function() m.expect.fail "boom" end, "boom") | |
end) | |
:case("#32 not_throws ok", function() | |
m.expect.not_throws(function() end) | |
end) | |
:case("#33 not_throws fails", function() | |
m.expect.throws(function() | |
m.expect.not_throws(function() error "kaboom" end) | |
end, "not to throw") | |
end) | |
m.set_ci_mode(false) -- true to set CI mode | |
suite:run "normal" | |
--]] | |
-- stylua: ignore end
-- Module table: m.expect (assertions), m.bench (benchmarks), m.suite (test suites),
-- and the m.set_ci_mode / m.set_printer / m.reset_printer configuration hooks.
return m
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment