Skip to content

Instantly share code, notes, and snippets.

@eriknw
Created September 2, 2024 10:52
Show Gist options
  • Save eriknw/542201420efa45e954435bf86fb2a55d to your computer and use it in GitHub Desktop.
Save eriknw/542201420efa45e954435bf86fb2a55d to your computer and use it in GitHub Desktop.
Dispatch by types example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d8840480-60dc-4a92-9a8c-54dfc5373677",
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"import numbers\n",
"from functools import partial\n",
"\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1f694dd8-befd-4dd6-9d31-dbfb618fffae",
"metadata": {},
"outputs": [],
"source": [
"# Mock backend information\n",
"def foo_match_ducktypes(func_name, dispatchable_objects):\n",
" return all(isinstance(val, numbers.Integral) for val in dispatchable_objects.values())\n",
"\n",
"\n",
"def foo_match_generic(func_name, args, kwargs):\n",
" return False, \"give a reason here\"\n",
"\n",
"\n",
"backend_info = {\n",
" \"foo\": {\n",
" # Should we allow backends to define priority?\n",
" \"default_priority\": 1,\n",
" \"types\": {\"__main__:A\"},\n",
" \"secondary_types\": {\"builtins:int\"},\n",
" # `match_ducktypes` and `match_generic` should try to be fast and avoid unnecessary imports\n",
" \"match_ducktypes\": foo_match_ducktypes,\n",
" \"match_generic\": foo_match_generic,\n",
" \"functions\": {\"f1\", \"f2\"},\n",
" # \"url\", etc\n",
" },\n",
" \"bar\": {\n",
" \"default_priority\": 2,\n",
" \"types\": {\"__main__:B\", \"builtins:float\"},\n",
" \"secondary_types\": {\"builtins:complex\"},\n",
" \"functions\": {\"f1\", \"f3\"},\n",
" },\n",
"}\n",
"\n",
"# TODO: try to get from environment variable\n",
"default_backend_priority = tuple(\n",
" sorted(backend_info, key=lambda x: (backend_info[x].get(\"default_priority\", 0), x))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "05337643-3df9-47f2-b2d6-76292a05a1b1",
"metadata": {},
"outputs": [],
"source": [
"# {frozenset(type_strings): [backend_names]}\n",
"class inner_function_priority_dict(dict):\n",
" def __init__(self, func_name):\n",
" self.func_name = func_name\n",
"\n",
" def __missing__(self, type_strings):\n",
" val = self[type_strings] = tuple(\n",
" backend\n",
" for backend in default_backend_priority\n",
" if self.func_name in (info := backend_info[backend])[\"functions\"]\n",
" and type_strings.issubset(info[\"types\"])\n",
" )\n",
" return val\n",
"\n",
"\n",
"# {func_name: {frozenset(type_strings): [backend_names]}}\n",
"class function_priority_dict_secondary(dict):\n",
" def __missing__(self, func_name):\n",
" val = self[func_name] = inner_function_priority_dict_secondary(func_name)\n",
" return val\n",
"\n",
"\n",
"# {frozenset(type_strings): [backend_names]}\n",
"class inner_function_priority_dict_secondary(dict):\n",
" def __init__(self, func_name):\n",
" self.func_name = func_name\n",
"\n",
" def __missing__(self, type_strings):\n",
" val = self[type_strings] = tuple(\n",
" backend\n",
" for backend in default_backend_priority\n",
" if self.func_name in (info := backend_info[backend])[\"functions\"]\n",
" and not type_strings.issubset(info[\"types\"])\n",
" and type_strings.issubset(info[\"types\"] | info[\"secondary_types\"])\n",
" )\n",
" return val\n",
"\n",
"# {func_name: {frozenset(type_strings): [backend_names]}}\n",
"class function_priority_dict(dict):\n",
" def __missing__(self, func_name):\n",
" val = self[func_name] = inner_function_priority_dict(func_name)\n",
" return val\n",
"\n",
"\n",
"# {func_name: {backend_name: match_ducktypes}}\n",
"class function_match_ducktypes_dict(dict):\n",
" def __missing__(self, func_name):\n",
" # The iteration order of `val` dict is in priority order\n",
" val = self[func_name] = {\n",
" backend: info[\"match_ducktypes\"]\n",
" for backend in default_backend_priority\n",
" if \"match_ducktypes\" in (info := backend_info[backend])\n",
" and func_name in info[\"functions\"]\n",
" }\n",
" return val\n",
"\n",
"\n",
"func_to_backend_priority = function_priority_dict()\n",
"func_to_backend_priority_secondary = function_priority_dict_secondary()\n",
"func_to_backend_match_ducktypes = function_match_ducktypes_dict()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7cbfba94-aafa-48df-b9da-dd13ef0c7ec9",
"metadata": {},
"outputs": [],
"source": [
"# Allow users to interact with priority of backends\n",
"# This is super-basic. Perhaps improve.\n",
"_user_priority = default_backend_priority\n",
"\n",
"\n",
"def get_backend_priority():\n",
" return _user_priority\n",
"\n",
"\n",
"def set_backend_priority(backend_priority):\n",
" global _user_priority\n",
" if not isinstance(backend_priority, tuple):\n",
" if isinstance(backend_priority, str):\n",
" backend_priority = (backend_priority,)\n",
" else:\n",
" backend_priority = tuple(backend_priority)\n",
" if invalid_backends := {backend for backend in backend_priority if backend not in backend_info}:\n",
" invalid_str = \", \".join(sorted(invalid_backends))\n",
" raise ValueError(f\"Invalid backend (perhaps the backend is not installed): {invalid_str}\")\n",
" if (\n",
" len(backend_priority) == len(default_backend_priority)\n",
" and backend_priority == default_backend_priority\n",
" ):\n",
" prev, _user_priority = _user_priority, default_backend_priority\n",
" else:\n",
" prev, _user_priority = _user_priority, backend_priority\n",
" return prev"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0b130222-8d79-40e0-bb1b-40afe19b679c",
"metadata": {},
"outputs": [],
"source": [
"dispatchable_functions = {} # {func_name: dispatchable_function}\n",
"\n",
"class dispatchable:\n",
" _default_backend = \"default\"\n",
"\n",
" def __new__(cls, dispatchable_objects, *, name=None):\n",
" return partial(cls._from_func, dispatchable_objects, name=name)\n",
"\n",
" @classmethod\n",
" def _from_func(cls, dispatchable_objects, default_func, *, name=None):\n",
" self = object.__new__(cls)\n",
" if name is None:\n",
" name = default_func.__name__\n",
" if isinstance(dispatchable_objects, str):\n",
" dispatchable_objects = {dispatchable_objects: 0}\n",
" elif isinstance(dispatchable_objects, list | tuple):\n",
" dispatchable_objects = {val: i for i, val in enumerate(dispatchable_objects)}\n",
" self.dispatchable_objects = dispatchable_objects\n",
" self.name = name\n",
" self.default_func = default_func\n",
" self.backends = frozenset(\n",
" backend for backend, info in backend_info.items() if name in info[\"functions\"]\n",
" )\n",
" self._sig = None # Signature will be cached\n",
"\n",
" # Let's return a normal Python function to be nicer\n",
" if backend_info:\n",
"\n",
" def dispatchable_func(*args, **kwargs):\n",
" return self(*args, **kwargs)\n",
"\n",
" # TODO(?): experiment hiding dispatching from the traceback\n",
" # def dispatchable_func(*args, **kwargs):\n",
" # try:\n",
" # return self(*args, **kwargs)\n",
" # except Exception as exc:\n",
" # try:\n",
" # exc = exc.with_traceback(exc.__traceback__.tb_next.tb_next)\n",
" # except Exception:\n",
" # pass\n",
" # raise exc\n",
"\n",
" else:\n",
" # Use a simpler, faster `__call__` if there are no backends installed.\n",
" # A library may also return `default_func` here if they prefer.\n",
" def dispatchable_func(*args, backend=None, **kwargs):\n",
" if backend is not None and backend != self._default_backend:\n",
" raise ImportError(f\"'{backend}' backend is not installed\")\n",
" return default_func(*args, **kwargs)\n",
"\n",
" # Standard function-wrapping stuff; faster than calling `functools.update_wrapper`\n",
" # What about __annotations__?\n",
" self.__dict__.update(default_func.__dict__)\n",
" dispatchable_func.__dict__.update(self.__dict__)\n",
" dispatchable_func.__defaults__ = self.__defaults__ = default_func.__defaults__\n",
" # We \"magically\" add `backend=` keyword argument to allow backend to be specified\n",
" if dispatchable_func.__kwdefaults__:\n",
" dispatchable_func.__kwdefaults__ = self.__kwdefaults__ = {\n",
" **default_func.__kwdefaults__,\n",
" \"backend\": None,\n",
" }\n",
" else:\n",
" dispatchable_func.__kwdefaults__ = self.__kwdefaults__ = {\"backend\": None}\n",
" dispatchable_func.__module__ = self.__module__ = default_func.__module__\n",
" dispatchable_func.__name__ = self.__name__ = default_func.__name__\n",
" dispatchable_func.__qualname__ = self.__qualname__ = default_func.__qualname__\n",
" self.__wrapped__ = default_func\n",
" dispatchable_func.__wrapped__ = self\n",
"\n",
" # TODO: __doc__\n",
"\n",
" # To expose methods to users, assign them to `dispatchable_func`:\n",
" dispatchable_func.call_with_backend = self.call_with_backend\n",
" dispatchable_func.match_backends = self.match_backends\n",
" dispatchable_func.can_backend_run = self.can_backend_run\n",
" dispatchable_func.should_backend_run = self.should_backend_run\n",
" dispatchable_func.get_dispatchable_objects = self.get_dispatchable_objects\n",
" dispatchable_func.get_dispatchable_types = self.get_dispatchable_types\n",
"\n",
" if name in dispatchable_functions:\n",
" raise KeyError(f\"'{name}' already registered as a dispatchable function\")\n",
" dispatchable_functions[name] = dispatchable_func\n",
" return dispatchable_func\n",
"\n",
" @property\n",
" def __signature__(self):\n",
" if self._sig is None:\n",
" sig = inspect.signature(self.default_func)\n",
" # `backend` is now a reserved argument used by dispatching.\n",
" # assert \"backend\" not in sig.parameters\n",
" if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):\n",
" sig = sig.replace(\n",
" parameters=[\n",
" *sig.parameters.values(),\n",
" inspect.Parameter(\"backend\", inspect.Parameter.KEYWORD_ONLY, default=None),\n",
" ]\n",
" )\n",
" else:\n",
" *parameters, var_keyword = sig.parameters.values()\n",
" sig = sig.replace(\n",
" parameters=[\n",
" *parameters,\n",
" inspect.Parameter(\"backend\", inspect.Parameter.KEYWORD_ONLY, default=None),\n",
" var_keyword,\n",
" ]\n",
" )\n",
" self._sig = sig\n",
" return self._sig\n",
"\n",
" def get_dispatchable_types(self, *args, **kwargs):\n",
" return frozenset(\n",
" f\"{val.__class__.__module__}:{val.__class__.__qualname__}\"\n",
" for name, pos in self.dispatchable_objects.items()\n",
" if (val := args[pos] if pos < len(args) else kwargs.get(name)) is not None\n",
" )\n",
"\n",
" def get_dispatchable_objects(self, *args, **kwargs):\n",
" return {\n",
" name: val\n",
" for name, pos in self.dispatchable_objects.items()\n",
" if (val := args[pos] if pos < len(args) else kwargs.get(name))\n",
" }\n",
"\n",
" def match_backends(self, args, kwargs, priorities=None):\n",
" \"\"\"Yield ``(name, how)`` tuple of backends that match call with given arguments.\n",
"\n",
" ``how`` may be \"primary\", \"secondary\", \"duck\", \"generic\", or \"default\".\n",
"\n",
" This first matches by exact types (\"primary\" then \"secondary\"), then by using\n",
" functions provided by backends (\"duck\" then \"generic\").\n",
" \"\"\"\n",
" # This handles dispatchable arguments passed as positional or keyword,\n",
" # but this doesn't catch the (hopefully unlikely) case of an argument\n",
" # passed as *both* positional and keyword arguments (will fail later).\n",
" if priorities is not None and priorities is default_backend_priority:\n",
" priorities = None\n",
" dispatchable_objects = self.get_dispatchable_objects(*args, **kwargs)\n",
" dispatchable_types = frozenset(\n",
" f\"{val.__class__.__module__}:{val.__class__.__qualname__}\"\n",
" for val in dispatchable_objects.values()\n",
" if val is not None\n",
" )\n",
" # First try to match on exact types\n",
" attempted_backends = set()\n",
" if dispatchable_types:\n",
" backend_priority = func_to_backend_priority[self.name][dispatchable_types]\n",
" if priorities is not None and backend_priority:\n",
" backend_priority_set = set(backend_priority)\n",
" backend_priority = (\n",
" backend for backend in priorities if backend in backend_priority_set\n",
" )\n",
" for backend in backend_priority:\n",
" yield backend, \"primary\"\n",
" attempted_backends.add(backend)\n",
"\n",
" backend_priority = func_to_backend_priority_secondary[self.name][dispatchable_types]\n",
" if priorities is not None and backend_priority:\n",
" backend_priority_set = set(backend_priority)\n",
" backend_priority = (\n",
" backend for backend in priorities if backend in backend_priority_set\n",
" )\n",
" for backend in backend_priority:\n",
" yield backend, \"secondary\"\n",
" attempted_backends.add(backend)\n",
"\n",
" # Then use generic matching function and is primarily for handling duck-types\n",
" # `match_ducktypes` from the backend should try to be fast and avoid unnecessary imports\n",
" backend_match_ducktypes = func_to_backend_match_ducktypes[self.name]\n",
" if priorities is not None and backend_match_ducktypes:\n",
" backend_priority_set = set(backend_priority) - attempted_backends\n",
" backend_priority = (\n",
" backend for backend in priorities if backend in backend_priority_set\n",
" )\n",
" else:\n",
" backend_priority = (\n",
" backend for backend in backend_match_ducktypes if backend not in attempted_backends\n",
" )\n",
" for backend in backend_priority:\n",
" if backend_match_ducktypes[backend](self.name, dispatchable_objects):\n",
" yield backend, \"duck\"\n",
" attempted_backends.add(backend)\n",
"\n",
" backend_priority = default_backend_priority if priorities is None else priorities\n",
" # bound = None\n",
" for backend in backend_priority:\n",
" if (\n",
" backend not in attempted_backends\n",
" and \"match_generic\" in (info := backend_info[backend])\n",
" and self.name in info[\"functions\"]\n",
" ):\n",
" # if bound is None:\n",
" # bound = self.__signature__.bind(*args, **kwargs)\n",
" # bound.apply_defaults()\n",
" # del bound.kwargs[\"backend\"]\n",
" # matches, reason = info[\"match_generic\"](self.name, bound.args, bound.kwargs)\n",
" matches, reason = info[\"match_generic\"](self.name, args, kwargs)\n",
" if matches:\n",
" yield backend, \"generic\"\n",
" attempted_backends.add(backend)\n",
"\n",
" if self._default_backend not in attempted_backends:\n",
" yield self._default_backend, \"default\"\n",
"\n",
" def __call__(self, /, *args, backend=None, **kwargs):\n",
" if backend is not None:\n",
" return self.call_with_backend(backend, *args, **kwargs)\n",
"\n",
" priorities = get_backend_priority()\n",
" backends_that_shouldnt_run = []\n",
" reasons = []\n",
" for backend, how in self.match_backends(args, kwargs, priorities):\n",
" try:\n",
" # What do we think about `should_backend_run`?\n",
" reason = self.should_backend_run(backend, *args, **kwargs)\n",
" if reason is True:\n",
" return self.call_with_backend(backend, *args, **kwargs)\n",
" if isinstance(reason, str):\n",
" reasons.append(\n",
" f\"Backend '{backend}' matched by '{how}', \"\n",
" f\"but should not run because: {reason}\"\n",
" )\n",
" else:\n",
" reasons.append(\n",
" f\"Backend '{backend}' matched by '{how}', \"\n",
" \"but the backend chose not to run.\"\n",
" )\n",
" backends_that_shouldnt_run.append((backend, how))\n",
" except NotImplementedError as exc:\n",
" reasons.append(f\"Backend '{backend}' matched by '{how}', but it raised: {exc}\")\n",
" # Instead of failing right now, try backends that can, but shouldn't, run\n",
" for backend, how in backends_that_shouldnt_run:\n",
" try:\n",
" return self.call_with_backend(backend, *args, **kwargs)\n",
" except NotImplementedError as exc:\n",
" reasons.append(\n",
" f\"Backend '{backend}' matched by '{how}' \"\n",
" f\"(and it initially chose not to run), but it raised: {exc}\"\n",
" )\n",
"\n",
" if reasons:\n",
" explanation = \"Reasons for failures:\\n - \" + \"\\n - \".join(reasons)\n",
" else:\n",
" explanation = f\"None of the backends {self.backends} matched\"\n",
" raise NotImplementedError(\n",
" f\"No installed backend is able to run '{self.name}' with these inputs. \"\n",
" f\"{explanation}\"\n",
" )\n",
"\n",
" def call_with_backend(self, backend, /, *args, **kwargs):\n",
" print(\"Calling with backend\", backend)\n",
" return self.default_func(*args, **kwargs)\n",
"\n",
" def can_backend_run(self, backend, /, *args, **kwargs):\n",
" info = backend_info[backend]\n",
" if self.name not in info[\"functions\"]:\n",
" return False\n",
" dispatchable_types = self.get_dispatchable_types(*args, **kwargs)\n",
" if backend in func_to_backend_priority[self.name][dispatchable_types]:\n",
" return \"primary\"\n",
" if backend in func_to_backend_priority_secondary[self.name][dispatchable_types]:\n",
" return \"secondary\"\n",
" if backend in func_to_backend_match_ducktypes[self.name]:\n",
" dispatchable_objects = self.get_dispatchable_objects(*args, **kwargs)\n",
" match_ducktypes = func_to_backend_match_ducktypes[self.name][backend]\n",
" if match_ducktypes(self.name, dispatchable_objects):\n",
" return \"duck\"\n",
" if \"match_generic\" in info:\n",
" # bound = self.__signature__.bind(*args, **kwargs)\n",
" # bound.apply_defaults()\n",
" # del bound.kwargs[\"backend\"]\n",
" # matches, reason = info[\"match_generic\"](self.name, bound.args, bound.kwargs)\n",
" matches, reason = info[\"match_generic\"](self.name, args, kwargs)\n",
" if matches:\n",
" return \"generic\"\n",
" return False\n",
"\n",
" def should_backend_run(self, backend, /, *args, **kwargs):\n",
" # Is this function generally useful? Can we assume `can_backend_run` is true?\n",
" return True"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d7448e60-f414-44c5-8810-c4a9b4a63843",
"metadata": {},
"outputs": [],
"source": [
"@dispatchable({\"x\": 0})\n",
"def f1(x):\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f0f5b43e-85c5-49ca-a2dd-40452a741f2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.f1(x, *, backend=None)>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8374c952-305e-4941-8f8e-7b64451e4f9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling with backend foo\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1(1)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7392cdb7-d128-41da-9c00-91d314630574",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling with backend bar\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1(1.0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d8aceaf7-caa5-44bf-ab5c-9cc252f2c6cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling with backend default\n"
]
},
{
"data": {
"text/plain": [
"'1'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1(\"1\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e69dbc64-2009-417b-92e6-4663c7157dc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling with backend foo\n"
]
},
{
"data": {
"text/plain": [
"np.int64(1)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1(np.int64(1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a0e7f5ea-360f-42c7-969f-d8831b0d54c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> [('foo', 'secondary'), ('default', 'default')]\n",
"<class 'float'> [('bar', 'primary'), ('default', 'default')]\n",
"<class 'str'> [('default', 'default')]\n",
"<class 'numpy.int64'> [('foo', 'duck'), ('default', 'default')]\n"
]
}
],
"source": [
"kwargs = {}\n",
"for arg in [1, 1.0, \"1\", np.int64(1)]:\n",
" args = (arg,)\n",
" print(type(arg), list(f1.match_backends(args, kwargs)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment