Last active
April 29, 2021 15:55
-
-
Save DDoSolitary/0a1fdcd1a3a5f715f429ed7bc11b6d6b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import functools | |
import json | |
import math | |
import multiprocessing | |
import random | |
import re | |
import subprocess | |
import sys | |
from sympy import sin, cos, diff, lambdify, expand | |
from sympy.abc import x | |
from sympy.parsing import parse_expr | |
# Largest number of digits gen_normal_const puts into a generated integer literal.
MAX_DIGIT_COUNT = 2
# Largest number of factors gen_expr distributes over the terms of one expression.
MAX_FACTOR_COUNT = 3
# gen_factor only offers a parenthesised sub-expression while depth is below this value.
MAX_RECURSION_DEPTH = 4
# Matches a run of zeros that begins a digit sequence (at the start of the
# string or right after a non-digit) and is followed by at least one more
# digit, i.e. the redundant leading zeros of a number such as "007".
PATTERN_LEADING_ZEROS = re.compile(r'(?:^|(?<=[^0-9]))0+(?=[0-9])')
def gen_ws():
    """Return a random whitespace filler.

    The empty string is listed twice, so it is drawn twice as often as a
    single space or a single tab.
    """
    fillers = ('', '', ' ', '\t')
    return random.choice(fillers)
def gen_opt_sign():
    """Return an optional sign: '', '+' or '-', each equally likely."""
    signs = ('', '+', '-')
    return random.choice(signs)
def gen_special_const():
    """Pick one of the edge-case constants 0, -1 and 1.

    Returns:
        A ``(value, text)`` pair, where ``text`` is one of the signed or
        unsigned spellings of the value. The weights give each of the three
        values the same total probability.
    """
    # (weight, (value, spelling)) rows; order matters for reproducibility.
    weighted = (
        (2, (0, '-0')),
        (2, (0, '0')),
        (2, (0, '+0')),
        (6, (-1, '-1')),
        (3, (1, '1')),
        (3, (1, '+1')),
    )
    candidates = [pair for _, pair in weighted]
    weights = [weight for weight, _ in weighted]
    picked = random.choices(candidates, weights=weights)
    return picked[0]
def gen_normal_const():
    """Generate a random signed integer literal.

    The literal has an optional sign followed by 1 to MAX_DIGIT_COUNT random
    decimal digits, so it may carry leading zeros.

    Returns:
        A ``(value, text)`` pair: the integer value and the exact literal.
    """
    sign = gen_opt_sign()
    n_digits = random.randint(1, MAX_DIGIT_COUNT)
    digits = [str(random.randrange(10)) for _ in range(n_digits)]
    literal = sign + ''.join(digits)
    return int(literal), literal
def gen_const():
    """Generate a constant factor, special or ordinary with equal odds.

    Returns:
        The ``(value, text)`` pair produced by the chosen generator.
    """
    maker = random.choice((gen_special_const, gen_normal_const))
    return maker()
def gen_pow():
    """Return the bare variable as an ``(expression, text)`` pair.

    The exponent is not applied here; gen_var raises the result to a power.
    """
    text = 'x'
    return x, text
def gen_sin(expr, expr_str):
    """Wrap an expression in ``sin``.

    Args:
        expr: The argument expression.
        expr_str: The source text of the argument.

    Returns:
        ``(sin(expr), text)``, where the text may have random whitespace
        between the function name and the opening parenthesis.
    """
    value = sin(expr)
    text = 'sin' + gen_ws() + '(' + expr_str + ')'
    return value, text
def gen_cos(expr, expr_str):
    """Wrap an expression in ``cos``.

    Args:
        expr: The argument expression.
        expr_str: The source text of the argument.

    Returns:
        ``(cos(expr), text)``, where the text may have random whitespace
        between the function name and the opening parenthesis.
    """
    value = cos(expr)
    text = 'cos' + gen_ws() + '(' + expr_str + ')'
    return value, text
def gen_trig(depth):
    """Generate ``sin(...)`` or ``cos(...)`` around a random factor.

    Args:
        depth: Current nesting depth; the inner factor is generated one
            level deeper.

    Returns:
        An ``(expression, text)`` pair.
    """
    # The argument is generated first, then the function is chosen.
    inner, inner_str = gen_factor(depth + 1)
    wrap = random.choice([gen_sin, gen_cos])
    return wrap(inner, inner_str)
def gen_var(depth):
    """Generate a variable factor: ``x`` or a trig call raised to a power.

    The exponent is a random integer in [-50, 50]. It is always written out,
    except that an exponent of exactly 1 is omitted half of the time.

    Args:
        depth: Current nesting depth, forwarded to gen_trig.

    Returns:
        An ``(expression, text)`` pair.
    """
    make_base = random.choice([gen_pow, functools.partial(gen_trig, depth=depth)])
    base, base_str = make_base()
    exponent = random.randrange(-50, 51)
    powered = base ** exponent
    # The coin is only flipped when the exponent is 1 (short-circuit).
    if exponent == 1 and random.randrange(2) != 0:
        return powered, base_str
    return powered, f'{base_str}{gen_ws()}**{gen_ws()}{exponent}'
def gen_nested_expr(depth):
    """Generate a parenthesised sub-expression one level deeper.

    Args:
        depth: Current nesting depth.

    Returns:
        An ``(expression, text)`` pair whose text is wrapped in parentheses.
    """
    inner, inner_str = gen_expr(depth=depth + 1)
    return inner, '(' + inner_str + ')'
def gen_factor(depth):
    """Generate one factor: a constant, a variable factor or a nested expression.

    Nested expressions are only offered while ``depth`` is below
    MAX_RECURSION_DEPTH.

    Args:
        depth: Current nesting depth.

    Returns:
        The ``(expression, text)`` pair of the chosen generator.
    """
    candidates = [gen_const, functools.partial(gen_var, depth=depth)]
    can_nest = depth < MAX_RECURSION_DEPTH
    if can_nest:
        candidates.append(functools.partial(gen_nested_expr, depth=depth))
    chosen = random.choice(candidates)
    return chosen()
def gen_expr(depth=1):
    """Generate a random expression as a sum of products of factors.

    Between 1 and MAX_FACTOR_COUNT factors are split at random cut points
    into one or more non-empty terms. Each term may carry its own leading
    sign, and terms are joined with '+' or '-' (the first term's joining
    sign is optional). Random whitespace is sprinkled between tokens.

    Args:
        depth: Current nesting depth, forwarded to gen_factor.

    Returns:
        An ``(expression, text)`` pair: the accumulated value of the
        expression and the source text that spells it.
    """
    n_factors = random.randint(1, MAX_FACTOR_COUNT)
    n_terms = random.randint(1, n_factors)
    # Distinct interior cut points guarantee every term gets >= 1 factor.
    cuts = sorted(random.sample(range(1, n_factors), n_terms - 1))
    bounds = [0, *cuts, n_factors]

    total = 0
    text = ''
    for lo, hi in zip(bounds, bounds[1:]):
        # Optional sign belonging to the term itself.
        term_text = gen_opt_sign()
        term = -1 if term_text == '-' else 1
        if term_text:
            term_text += gen_ws()

        for pos in range(hi - lo):
            factor, factor_text = gen_factor(depth)
            term *= factor
            if pos:
                term_text += gen_ws() + '*' + gen_ws()
            term_text += factor_text

        # Only the very first term may omit the joining sign.
        joiner = random.choice(('+', '-')) if text else gen_opt_sign()
        if joiner == '-':
            total -= term
        else:
            total += term

        if joiner:
            text += gen_ws() + joiner + gen_ws()
        else:
            text += gen_ws()
        text += term_text

    return total, text + gen_ws()
def check_equal(f1, f2):
    """Compare two single-argument numeric functions at random points.

    Runs 100 rounds. Each round draws up to 5 points from [-10, 10] and uses
    the first one where both functions evaluate without ZeroDivisionError or
    OverflowError and both results are finite; a round with no usable point
    is skipped. Values are compared with a relative tolerance of 1e-3.

    Args:
        f1: First function.
        f2: Second function.

    Returns:
        ``(True, None)`` if no mismatch was found, otherwise
        ``(False, point)`` with the point where the values differ.
    """
    for _ in range(100):
        attempts = 0
        while attempts < 5:
            attempts += 1
            point = random.uniform(-10, 10)
            try:
                left = f1(point)
                right = f2(point)
            except (ZeroDivisionError, OverflowError):
                continue
            if not (math.isfinite(left) and math.isfinite(right)):
                continue
            if not math.isclose(left, right, rel_tol=1e-3):
                return False, point
            # One successful comparison per round is enough.
            break
    return True, None
def remove_leading_zeros(s):
    """Return ``s`` with the redundant leading zeros of every number removed.

    Uses PATTERN_LEADING_ZEROS, so a lone ``0`` and zeros inside a number are
    left alone.
    """
    return re.sub(PATTERN_LEADING_ZEROS, '', s)
def do_fuzz(_, config):
    """Run one fuzzing round and collect the failures of every subject.

    An input expression is read from stdin (``config['manual']``) or
    generated randomly. Its expanded derivative is the reference answer.
    Each subject command is fed the input text; its output is checked by the
    validator command and then compared numerically with the reference.

    Args:
        _: Ignored; present so the function can be mapped over a range.
        config: Dict read with the keys ``subjects`` (items with ``name`` and
            ``cmd``) and ``validator``, plus the optional flags ``manual``
            and ``debug``.

    Returns:
        A list of error-record dicts (possibly empty). Inputs whose sympy
        form contains ``nan`` or ``zoo`` are skipped and yield ``[]``.
    """
    debug = config.get('debug')

    if config.get('manual'):
        input_expr_str = input()
        input_expr = parse_expr(remove_leading_zeros(input_expr_str))
    else:
        input_expr, input_expr_str = gen_expr()

    # Skip inputs that sympy could only simplify to an undefined value.
    sympy_text = str(input_expr)
    if 'nan' in sympy_text or 'zoo' in sympy_text:
        return []

    ans_expr = expand(diff(input_expr))
    ans_func = lambdify(x, ans_expr, 'numpy')
    ans_text = str(ans_expr)
    if debug:
        print(input_expr_str)
        print(ans_expr)

    errors = []
    for subject in config['subjects']:
        try:
            proc = subprocess.run(
                subject['cmd'],
                input=input_expr_str,
                capture_output=True,
                text=True,
                timeout=5
            )
        except subprocess.TimeoutExpired:
            errors.append({
                'name': subject['name'],
                'reason': 'Time Limit Exceeded',
                'stdin': input_expr_str,
                'ans': ans_text,
            })
            continue

        if proc.returncode != 0:
            errors.append({
                'name': subject['name'],
                'reason': 'Runtime Error',
                'stdin': input_expr_str,
                'stdout': proc.stdout,
                'stderr': proc.stderr,
                'ans': ans_text,
            })
            continue

        if debug:
            print(f'{subject["name"]}: {proc.stdout}')

        # The validator decides whether the output is well-formed at all.
        validator_proc = subprocess.run(
            config['validator'],
            input=proc.stdout,
            capture_output=True,
            text=True
        )
        if validator_proc.returncode != 0:
            errors.append({
                'name': subject['name'],
                'reason': 'Wrong Answer (invalid output)',
                'stdin': input_expr_str,
                'stdout': proc.stdout,
                'ans': ans_text,
                'validator_stdout': validator_proc.stdout,
            })
            continue

        output_expr = expand(parse_expr(remove_leading_zeros(proc.stdout)))
        output_func = lambdify(x, output_expr, 'numpy')
        same, witness = check_equal(output_func, ans_func)
        if same:
            continue
        errors.append({
            'name': subject['name'],
            'reason': 'Wrong Answer',
            'stdin': input_expr_str,
            'stdout': proc.stdout,
            'ans': ans_text,
            'var': witness,
        })

    return errors
def compile_rule_item(k, v):
    """Compile one entry of a filter rule.

    Args:
        k: The rule key.
        v: The rule value: an action name for the ``'action'`` key, a regex
            source string for every other key.

    Returns:
        ``(k, v)`` unchanged for the ``'action'`` key, otherwise
        ``(k, compiled_pattern)``.
    """
    compiled = v if k == 'action' else re.compile(v)
    return k, compiled
def compile_rule(rule):
    """Return a copy of a filter rule with its regex strings compiled.

    Every item is passed through compile_rule_item, which leaves the
    ``'action'`` entry as it is.
    """
    compiled = {}
    for key, raw in rule.items():
        new_key, new_value = compile_rule_item(key, raw)
        compiled[new_key] = new_value
    return compiled
def filter_error(err, rules):
    """Decide whether an error record stays in the final report.

    Rules are tried in order. A rule matches when every key other than
    ``'action'`` names a field that is present in ``err`` and whose value
    the rule's compiled pattern finds a match in. The first matching rule
    whose action is ``'accept'`` or ``'ignore'`` decides the outcome; a
    matching rule with any other action is skipped.

    Non-string field values (for example the float stored under ``var``) are
    matched against their ``str()`` form instead of raising ``TypeError``.

    Args:
        err: Error-record dict.
        rules: Iterable of compiled rules: dicts with an ``'action'`` entry
            and compiled regex patterns for the other keys.

    Returns:
        ``True`` to keep the error, ``False`` to drop it. Errors that no
        rule decides on are kept.

    Raises:
        KeyError: If a matching rule has no ``'action'`` entry.
    """
    def field_matches(key, pattern):
        value = err.get(key)
        if value is None:
            return False
        # Patterns can only search text; stringify numbers and the like.
        text = value if isinstance(value, str) else str(value)
        return pattern.search(text) is not None

    for rule in rules:
        matched = all(
            field_matches(key, pattern)
            for key, pattern in rule.items()
            if key != 'action'
        )
        if not matched:
            continue
        action = rule['action']
        if action == 'accept':
            return True
        if action == 'ignore':
            return False
    return True
def main():
    """Load the JSON config named on the command line and run the fuzzer.

    In manual mode (``config['manual']``) a single round is run in this
    process. Otherwise ``config['count']`` rounds are distributed over a
    process pool and a progress line is printed as each round finishes.
    The collected errors are filtered through ``config['filters']`` (if any)
    and printed as indented JSON.

    Exits with a usage message when no config path is given.
    """
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else 'fuzz'
        sys.exit(f'usage: {prog} CONFIG.json')

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(sys.argv[1], encoding='utf-8') as f:
        config = json.load(f)

    if config.get('manual'):
        errors = do_fuzz(None, config=config)
    else:
        errors = []
        run_round = functools.partial(do_fuzz, config=config)
        with multiprocessing.Pool() as pool:
            # Results arrive in completion order, so idx counts finished
            # rounds rather than identifying a particular round.
            rounds = pool.imap_unordered(run_round, range(config['count']))
            for idx, res in enumerate(rounds):
                print(f'#{idx}: {len(res)}')
                errors.extend(res)

    filter_rules = [compile_rule(rule) for rule in config.get('filters', [])]
    errors = [err for err in errors if filter_error(err, rules=filter_rules)]
    print(json.dumps(errors, indent=2))
# Run the fuzzer only when this file is executed as a script. The guard also
# matters for multiprocessing: start methods that re-import the main module
# in worker processes must not trigger another call to main().
if __name__ == '__main__':
    main()
# vim: ts=4:sw=4:noet |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment