Skip to content

Instantly share code, notes, and snippets.

@sheganinans
Last active April 13, 2021 02:07
Show Gist options
  • Save sheganinans/319c336875f13af6df463222a9a440f3 to your computer and use it in GitHub Desktop.
Save sheganinans/319c336875f13af6df463222a9a440f3 to your computer and use it in GitHub Desktop.
Python Pattern Matching in Python
"""
This is an implementation of a simple (< 100LoC) *runtime* pattern matching system for Python.
`pm` is a function that takes two arguments, the object being matched over and some rule for that object.
If the pattern match fails, then `pm` returns `False`.
If the pattern match succeeds and there are no captured `Var`iables, then `pm` returns `True`.
Otherwise, `pm` will return a dict with the captured `Var`iables and their associated values.
This system will match and extract arbitrarily nested patterns on common Python data types.
However, this pattern matching is all done at runtime and has its associated cost.
See the included test suite for details on expected behavior.
It is the caller's responsibility to ensure that all of the `Var`s in a pattern are unique.
-- Aistis Raulinaitis
"""
from dataclasses import dataclass, field
from typing import Any, Union
@dataclass
class Var:
    """A named capture slot placed inside a pattern passed to ``pm``.

    Two ``Var``s with the same name compare equal (dataclass-generated
    ``__eq__``) and hash identically, so they can be used as keys of the
    capture dict that ``pm`` returns.
    """

    # Name of the captured variable; defaults to the empty string.
    var: str = field(default_factory=str)

    def __hash__(self) -> int:
        # Defined explicitly: with eq=True and frozen=False, @dataclass would
        # otherwise set __hash__ to None and make Var unhashable.
        return hash(self.var)
def pm_list(obj, rule):
    """Positionally match the sequence ``obj`` against the sequence ``rule``.

    Each element of ``rule`` is either a ``Var`` (captures the element at the
    same index of ``obj``) or a sub-pattern matched recursively with ``pm``.

    Returns False on a length mismatch or on the first failing element; a
    dict of captured ``Var``s if anything was captured; otherwise True.
    """
    if len(obj) == len(rule):
        captured = {}
        for item, pattern in zip(obj, rule):
            if type(pattern) == Var:
                captured[pattern] = item
                continue
            result = pm(item, pattern)
            if not result:
                return False
            if type(result) == dict:
                captured.update(result)
        return captured or True
    return False
def no_vars(rule):
    """Return a lazy iterator over the elements of ``rule`` that are not ``Var``s."""
    return (element for element in rule if type(element) != Var)
def pm_dict(obj, rule):
    """Match the dict ``obj`` against the dict pattern ``rule``.

    ``obj`` and ``rule`` must have the same number of keys. Every non-``Var``
    key of ``rule`` must be present in ``obj``, and the value stored under it
    must match the rule's value pattern (recursively, via ``pm``).

    At most one key of ``rule`` may itself be a ``Var``. It captures the one
    key of ``obj`` that the constant keys do not name, and its value pattern
    is matched against that key's value. As with sets, several ``Var`` keys
    cannot be bound unambiguously, so such patterns never match.

    Returns False on failure, a dict of captured ``Var``s if any were bound,
    otherwise True.
    """
    if len(obj) != len(rule):
        return False
    const_keys = list(no_vars(rule))
    if not all(k in obj for k in const_keys):
        return False
    var_keys = [k for k in rule if type(k) == Var]
    if len(var_keys) > 1:
        return False

    ret = {}
    # (value from obj, pattern from rule) pairs that still must match.
    pairs = [(obj[k], rule[k]) for k in const_keys]
    if var_keys:
        key_var = var_keys[0]
        const_set = set(const_keys)
        # Equal sizes plus all constant keys present leave exactly one obj
        # key unaccounted for; bind the Var key to it.
        (leftover,) = [k for k in obj if k not in const_set]
        ret[key_var] = leftover
        pairs.append((obj[leftover], rule[key_var]))

    for value, pattern in pairs:
        pm_ret = pm(value, pattern)
        if not pm_ret:
            return False
        if type(pm_ret) == dict:
            ret.update(pm_ret)
    return ret if ret else True
def pm_set(obj, rule):
    """Match the set ``obj`` against the set pattern ``rule``.

    All constant (non-``Var``) elements of ``rule`` must be members of
    ``obj``. Because sets are unordered, at most one ``Var`` may appear in
    ``rule``; it captures the element of ``obj`` that ``rule`` does not
    contain. Elements are compared by plain membership, not recursively.

    Returns False on failure, ``{var: element}`` when a ``Var`` was bound,
    otherwise True.
    """
    if len(obj) != len(rule):
        return False
    constants = [x for x in rule if type(x) != Var]
    if not all(c in obj for c in constants):
        return False
    variables = [x for x in rule if type(x) == Var]
    if len(variables) > 1:
        return False
    unmatched = (x for x in obj if x not in rule)
    bindings = dict(zip(variables, unmatched))
    return bindings if bindings else True
def pm(obj: Any, rule: Any) -> Union[dict, bool]:
    """Match ``obj`` against the pattern ``rule``.

    Dispatch rules:

    * A ``Var`` pattern captures ``obj`` unconditionally.
    * ``int``/``float``/``str`` patterns match by ``==`` regardless of the
      type of ``obj`` (so ``pm(1.0, 1)`` holds).
    * Otherwise ``obj`` must have exactly the same type as ``rule``.
      Lists and tuples are matched element-wise (``pm_list``), dicts
      key-wise (``pm_dict``), sets and frozensets by membership (``pm_set``).
    * Any other pattern (``bool``, ``None``, ``bytes``, user objects, ...)
      matches when ``obj == rule`` evaluates to the bool True.

    Returns False if the match fails, True if it succeeds without capturing
    anything, otherwise a dict mapping each captured ``Var`` to its value.
    """
    rule_t = type(rule)
    if rule_t == Var:
        return {rule: obj}
    if rule_t in (int, float, str):
        return obj == rule
    if type(obj) != rule_t:
        return False
    if isinstance(rule, (list, tuple)):
        return pm_list(obj, rule)
    if isinstance(rule, dict):
        return pm_dict(obj, rule)
    if isinstance(rule, (set, frozenset)):
        return pm_set(obj, rule)
    # Remaining same-typed scalars (bool, None, bytes, ...) match by
    # equality. Only a genuine bool verdict is trusted, because some types
    # implement element-wise __eq__ whose result is not a yes/no answer.
    equal = obj == rule
    return equal if isinstance(equal, bool) else False
def test_pm():
    """Check ``pm`` on scalars, lists, dicts and sets, case by case in order.

    Each case is ``(obj, rule, expected)``: ``True`` means the result must be
    truthy, ``False`` means it must be falsy, and a dict means the result must
    equal exactly those captures.
    """
    cases = [
        (1, 1, True),
        (1, 2, False),
        (1.0, 1, True),
        (1.0, 2, False),
        ("a", "a", True),
        ("a", "b", False),
        (1, Var("x"), {Var("x"): 1}),
        ([], [], True),
        ([1], [1], True),
        ([1], [2], False),
        ([1], [Var("x")], {Var("x"): 1}),
        ([1], [Var("x"), 2], False),
        # Recursive patterns
        ([1, 2, 3], [1, Var("x"), 3], {Var("x"): 2}),
        ([1, [], 3], [1, Var("x"), 3], {Var("x"): []}),
        ([1, [2], 3], [1, Var("x"), 3], {Var("x"): [2]}),
        ([1, [2], 3], [1, [Var("x")], 3], {Var("x"): 2}),
        ([1, [2], 3], [1, [2], 3], True),
        ([1, [2], 3], [1, ["2"], 3], False),
        ({}, {}, True),
        ({"a": 1}, {"a": 1}, True),
        ({"a": 1}, {"a": Var("a")}, {Var("a"): 1}),
        ({"a": 1}, {"a": 1, "b": 2}, False),
        ({"b": 2, "a": 1}, {"a": 1, "b": 2}, True),
        ({"b": 2, "a": 1}, {"a": 1, "b": Var("b")}, {Var("b"): 2}),
        ({"a": 1, "b": {"c": "d"}}, {"a": 1, "b": {"c": "d"}}, True),
        ({"a": 1, "b": {"c": "d"}}, {"a": 1, "b": {"c": "e"}}, False),
        ({"a": 1, "b": {"c": "d"}}, {"a": 1, "b": {"c": Var("e")}}, {Var("e"): "d"}),
        ({"b": 2, "a": 1}, {"a": 1, "b": Var("b"), "c": 3}, False),
        ({1}, {1}, True),
        ({1}, {Var("b")}, {Var("b"): 1}),
        ({1}, {Var("b"), 2}, False),
        ({1, 2}, {Var("b"), 2}, {Var("b"): 1}),
        ({1, "abc"}, {Var("b"), "abc"}, {Var("b"): 1}),
        # Sets are unordered, so thus have no simple way to make a unique match with multiple variables.
        ({1, "abc"}, {Var("b"), Var("a")}, False),
    ]
    for obj, rule, expected in cases:
        outcome = pm(obj, rule)
        if isinstance(expected, dict):
            assert outcome == expected, (obj, rule, outcome)
        elif expected:
            assert outcome, (obj, rule, outcome)
        else:
            assert not outcome, (obj, rule, outcome)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment