Skip to content

Instantly share code, notes, and snippets.

@draincoder
Last active November 14, 2024 07:50
Show Gist options
  • Save draincoder/2499d6c6bf4299fdab33d73b225e07f7 to your computer and use it in GitHub Desktop.
A typed decorator with optional parameters for sync and async functions
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload
# Captures the parameter list (positional and keyword) of a decorated callable.
F_Spec = ParamSpec("F_Spec")
# Captures the return type of a decorated callable.
F_Return = TypeVar("F_Return")
# Generic alias: "a callable taking F_Spec and returning F_Return".
# NOTE: when this alias is used bare (``Func`` rather than ``Func[P, R]``), its
# type parameters default to Any, i.e. it means ``Callable[..., Any]``.
Func = Callable[F_Spec, F_Return]
def decorator_wrapper(
    arg1: Any, arg2: Any
) -> Callable[[Callable[F_Spec, F_Return]], Callable[F_Spec, F_Return]]:
    """Build a decorator that prints ``arg1`` and ``arg2`` before every call.

    The returned decorator chooses a wrapper matching the decorated function.
    A coroutine function gets an ``async def`` wrapper, so the decorated object
    is still a coroutine function. Anything else gets a plain synchronous
    wrapper. Each wrapper prints ``arg1``, ``arg2`` and a tag (``"sync"`` or
    ``"async"``), then delegates to the original function and returns its
    result.

    Args:
        arg1: First value printed on every call.
        arg2: Second value printed on every call.

    Returns:
        A decorator that keeps the decorated function's parameter and return
        types for type checkers. ``functools.wraps`` also copies its metadata
        (``__name__``, ``__doc__``, ...).
    """

    def real_decorator(func: Callable[F_Spec, F_Return]) -> Callable[F_Spec, F_Return]:
        @wraps(func)
        def sync_wrapper(*args: F_Spec.args, **kwargs: F_Spec.kwargs) -> F_Return:
            print(arg1, arg2, "sync")
            return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args: F_Spec.args, **kwargs: F_Spec.kwargs) -> Any:
            print(arg1, arg2, "async")
            # For a coroutine function, F_Return is the coroutine type itself.
            # The checker cannot derive the awaited type from it, so go via Any.
            return await cast(Any, func(*args, **kwargs))

        # inspect.iscoroutinefunction is the supported check;
        # asyncio.iscoroutinefunction is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):
            # Calling async_wrapper returns a coroutine, exactly like calling func,
            # so the caller-visible signature is unchanged.
            return cast(Callable[F_Spec, F_Return], async_wrapper)
        return sync_wrapper

    return real_decorator
@overload
def schrodinger_decorator(
    call: Callable[F_Spec, F_Return],
    *,
    arg1: None = None,
    arg2: None = None,
) -> Callable[F_Spec, F_Return]: ...


@overload
def schrodinger_decorator(
    call: None = None,
    *,
    arg1: Any = None,
    arg2: Any = None,
) -> Callable[[Callable[F_Spec, F_Return]], Callable[F_Spec, F_Return]]: ...


def schrodinger_decorator(
    call: Callable[F_Spec, F_Return] | None = None,
    *,
    arg1: Any = None,
    arg2: Any = None,
) -> Callable[[Callable[F_Spec, F_Return]], Callable[F_Spec, F_Return]] | Callable[F_Spec, F_Return]:
    """Decorator usable both bare (``@schrodinger_decorator``) and with arguments.

    Bare usage passes the decorated function as ``call`` and wraps it right away.
    Usage with keyword arguments (``@schrodinger_decorator(arg1=..., arg2=...)``)
    leaves ``call`` as ``None`` and returns the decorator for Python to apply.

    Args:
        call: The function being decorated (bare form), or ``None``.
        arg1: Forwarded to ``decorator_wrapper``; keyword-only.
        arg2: Forwarded to ``decorator_wrapper``; keyword-only.

    Returns:
        The wrapped function when ``call`` is given. Otherwise, a decorator that
        wraps a function. The overloads keep the wrapped function's parameter
        and return types visible to type checkers.
    """
    wrap_decorator = decorator_wrapper(arg1, arg2)
    if call is None:
        # Called with arguments (or empty parentheses): hand back the decorator.
        return wrap_decorator
    return wrap_decorator(call)
# Demo: decorator applied with keyword arguments to a coroutine function.
@schrodinger_decorator(arg1=1, arg2=2)
async def async_check_with_args(x: int) -> int:
    """Return ``x`` unchanged (async; decorated with ``arg1=1, arg2=2``)."""
    return x


# Demo: decorator applied with keyword arguments to a plain function.
@schrodinger_decorator(arg1=1, arg2=2)
def check_with_args(x: int) -> int:
    """Return ``x`` unchanged (sync; decorated with ``arg1=1, arg2=2``)."""
    return x


# Demo: bare decorator (no parentheses) on a coroutine function.
@schrodinger_decorator
async def async_check(x: int) -> int:
    """Return ``x`` unchanged (async; bare decorator)."""
    return x


# Demo: bare decorator (no parentheses) on a plain function.
@schrodinger_decorator
def check(x: int) -> int:
    """Return ``x`` unchanged (sync; bare decorator)."""
    return x
@schrodinger_decorator
async def main() -> None:
    """Run each demo function once with ``x = 1`` and print the sum of results.

    ``main`` is itself decorated, which shows the bare form on a coroutine
    function used as an entry point.
    """
    x = 1
    a = await async_check_with_args(x)
    b = check_with_args(x)
    c = await async_check(x)
    d = check(x)
    print(a + b + c + d)


# Run the demo only when executed as a script, so importing this module
# (e.g. to reuse schrodinger_decorator) does not start an event loop.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment