Last active
April 13, 2021 02:07
-
-
Save sheganinans/319c336875f13af6df463222a9a440f3 to your computer and use it in GitHub Desktop.
Python Pattern Matching in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is an implementation of a simple (< 100 LoC) *runtime* pattern matching system for Python.
`pm` is a function that takes two arguments: the object being matched over and some rule for that object.
If the pattern match fails, then `pm` returns `False`.
If the pattern match succeeds and there are no captured `Var`iables, then `pm` returns `True`.
Otherwise, `pm` will return a dict with the captured `Var`iables and their associated values.
This system will match and extract arbitrarily nested patterns on common Python data types.
However, this pattern matching is all done at runtime and has its associated cost.
See the included test suite for details on expected behavior.
It is the responsibility of the caller to ensure that all of the `Var`s given are unique.
-- Aistis Raulinaitis
""" | |
from dataclasses import dataclass, field
from typing import Any, Union
@dataclass
class Var:
    """A named capture variable that can be placed inside a pattern.

    Two ``Var`` objects with the same ``var`` name compare equal (dataclass
    equality) and hash identically, so a ``Var`` can be used as a dict key
    or a set member.
    """

    # Name of the variable; defaults to the empty string.
    var: str = field(default_factory=str)

    def __hash__(self) -> int:
        # Hash by name so that equal variables land in the same dict slot.
        return hash(self.var)
def pm_list(obj, rule):
    """Match the sequence ``obj`` element-by-element against ``rule``.

    Returns ``False`` when the lengths differ or any element fails to match,
    a dict of captured variables when at least one ``Var`` was bound, and
    ``True`` when everything matched without capturing anything.
    """
    if len(rule) != len(obj):
        return False

    bindings = {}
    for item, pattern in zip(obj, rule):
        if type(pattern) is Var:
            # A bare variable captures whatever sits at this position.
            bindings[pattern] = item
            continue

        outcome = pm(item, pattern)
        if not outcome:
            return False
        if type(outcome) is dict:
            # Nested match produced captures; fold them into ours.
            bindings.update(outcome)

    return bindings or True
def no_vars(rule):
    """Return a lazy iterator over the items of ``rule`` that are not ``Var``s."""

    def is_constant(item):
        return type(item) is not Var

    return filter(is_constant, rule)
def pm_dict(obj, rule):
    """Match the dict ``obj`` against the dict pattern ``rule``.

    Both dicts must have the same number of entries. Every non-``Var`` key of
    ``rule`` must be present in ``obj`` and the associated values must match
    (recursively, via ``pm``).

    ``rule`` may contain at most one ``Var`` key. It is bound to the single
    key of ``obj`` that is not claimed by a constant key of ``rule``, and the
    value stored under it is matched against the rule's value as usual. More
    than one ``Var`` key is rejected, because dict keys have no positional
    pairing and the assignment would be ambiguous.

    Returns ``False`` on mismatch, a dict of captured variables when anything
    was captured, and ``True`` otherwise.
    """
    if len(obj) != len(rule):
        return False

    var_keys = [key for key in rule if type(key) is Var]
    if len(var_keys) > 1:
        return False

    const_keys = [key for key in rule if type(key) is not Var]
    if any(key not in obj for key in const_keys):
        return False

    captured = {}
    # (key in obj, key in rule) pairs whose values still have to be matched.
    pairs = [(key, key) for key in const_keys]

    if var_keys:
        var_key = var_keys[0]
        claimed = set(const_keys)
        # Sizes are equal and every constant key is present in obj, so
        # exactly one key of obj is left over for the variable key.
        free_key = next(key for key in obj if key not in claimed)
        captured[var_key] = free_key
        pairs.append((free_key, var_key))

    for obj_key, rule_key in pairs:
        outcome = pm(obj[obj_key], rule[rule_key])
        if not outcome:
            return False
        if type(outcome) is dict:
            captured.update(outcome)

    return captured or True
def pm_set(obj, rule):
    """Match the set ``obj`` against the set pattern ``rule``.

    Both sets must have the same size and every non-``Var`` member of
    ``rule`` must be in ``obj``. At most one ``Var`` is allowed; it captures
    the member of ``obj`` that does not appear in ``rule``. Sets are
    unordered, so more than one ``Var`` cannot be assigned uniquely and is
    treated as a failed match.

    Returns ``False`` on mismatch, a dict with the captured variable when one
    was bound, and ``True`` otherwise.
    """
    if len(rule) != len(obj):
        return False

    constants = [member for member in rule if type(member) is not Var]
    if any(member not in obj for member in constants):
        return False

    variables = [member for member in rule if type(member) is Var]
    if len(variables) > 1:
        return False

    # Pair the (at most one) variable with the member of obj not named in rule.
    leftovers = (member for member in obj if member not in rule)
    bindings = dict(zip(variables, leftovers))
    return bindings if bindings else True
def pm(obj: Any, rule: Any) -> Union[dict, bool]:
    """Match ``obj`` against the pattern ``rule``.

    Dispatch, in order:

    * a ``Var`` rule matches anything and captures it as ``{rule: obj}``;
    * an ``int``, ``float`` or ``str`` rule matches by ``obj == rule``
      (so ``pm(1.0, 1)`` succeeds);
    * otherwise ``obj`` must have exactly the same type as ``rule``;
    * ``bool``, ``bytes`` and ``None`` rules then match by equality;
    * ``list``/``tuple``, ``dict`` and ``set``/``frozenset`` rules are
      matched structurally by ``pm_list``, ``pm_dict`` and ``pm_set``;
    * any other rule type does not match.

    Returns ``False`` when the match fails, a dict of captured variables when
    the match succeeds with captures, and ``True`` when it succeeds without.
    """
    obj_t, rule_t = type(obj), type(rule)
    if rule_t == Var:
        return {rule: obj}
    if rule_t in (int, float, str):
        return obj == rule
    if obj_t != rule_t:
        return False
    # Same-type leaf values: without this branch a rule such as None or True
    # could never match, not even against itself.
    if rule_t in (bool, bytes, type(None)):
        return obj == rule
    # pm_list only relies on len() and zip(), so tuples can share it.
    if rule_t in (list, tuple):
        return pm_list(obj, rule)
    if rule_t == dict:
        return pm_dict(obj, rule)
    # pm_set only relies on len(), iteration and membership tests.
    if rule_t in (set, frozenset):
        return pm_set(obj, rule)
    return False
def test_pm():
    """Exercise ``pm`` on scalars, lists, dicts and sets.

    The cases are grouped by expected outcome: patterns that must match
    without captures, patterns that must not match, and patterns that must
    match and yield a specific dict of captured variables.
    """
    x, a, b, e = Var("x"), Var("a"), Var("b"), Var("e")
    nested = {"a": 1, "b": {"c": "d"}}

    matches = [
        (1, 1),
        (1.0, 1),
        ("a", "a"),
        ([], []),
        ([1], [1]),
        ([1, [2], 3], [1, [2], 3]),
        ({}, {}),
        ({"a": 1}, {"a": 1}),
        ({"b": 2, "a": 1}, {"a": 1, "b": 2}),
        (nested, {"a": 1, "b": {"c": "d"}}),
        ({1}, {1}),
    ]
    for obj, rule in matches:
        assert pm(obj, rule)

    mismatches = [
        (1, 2),
        (1.0, 2),
        ("a", "b"),
        ([1], [2]),
        ([1], [x, 2]),
        ([1, [2], 3], [1, ["2"], 3]),
        ({"a": 1}, {"a": 1, "b": 2}),
        (nested, {"a": 1, "b": {"c": "e"}}),
        ({"b": 2, "a": 1}, {"a": 1, "b": b, "c": 3}),
        ({1}, {b, 2}),
        # Sets are unordered, so thus have no simple way to make a unique
        # match with multiple variables.
        ({1, "abc"}, {b, a}),
    ]
    for obj, rule in mismatches:
        assert not pm(obj, rule)

    captures = [
        (1, x, {x: 1}),
        ([1], [x], {x: 1}),
        # Recursive patterns
        ([1, 2, 3], [1, x, 3], {x: 2}),
        ([1, [], 3], [1, x, 3], {x: []}),
        ([1, [2], 3], [1, x, 3], {x: [2]}),
        ([1, [2], 3], [1, [x], 3], {x: 2}),
        ({"a": 1}, {"a": a}, {a: 1}),
        ({"b": 2, "a": 1}, {"a": 1, "b": b}, {b: 2}),
        (nested, {"a": 1, "b": {"c": e}}, {e: "d"}),
        ({1}, {b}, {b: 1}),
        ({1, 2}, {b, 2}, {b: 1}),
        ({1, "abc"}, {b, "abc"}, {b: 1}),
    ]
    for obj, rule, expected in captures:
        assert pm(obj, rule) == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment