Skip to content

Instantly share code, notes, and snippets.

@mahenzon
Created January 26, 2025 06:58
Show Gist options
  • Save mahenzon/2961dd084f12e7a1ac7978977e26849a to your computer and use it in GitHub Desktop.
Save mahenzon/2961dd084f12e7a1ac7978977e26849a to your computer and use it in GitHub Desktop.
overload annotation in Python
import logging
# Per-record layout: calling function, module:line, level, then the message.
# A timestamped variant (e.g. one starting with "[%(asctime)s.%(msecs)03d] ")
# is what would make the ``datefmt`` passed below actually take effect.
DEFAULT_LOG_FORMAT = (
    "%(funcName)10s %(module)s:%(lineno)d %(levelname)-8s - %(message)s"
)


def configure_logging(level: int = logging.INFO) -> None:
    """Configure the root logger through ``logging.basicConfig``.

    Args:
        level: minimum severity that gets emitted (``logging.INFO`` by default).

    Note:
        ``basicConfig`` is a no-op when the root logger already has handlers.
        ``datefmt`` only matters for formats containing ``%(asctime)s``;
        ``DEFAULT_LOG_FORMAT`` does not contain it.
    """
    logging.basicConfig(
        format=DEFAULT_LOG_FORMAT,
        datefmt="%Y-%m-%d %H:%M:%S",
        level=level,
    )
import logging
from dataclasses import dataclass
from functools import wraps
from typing import (
Callable,
ParamSpec,
TypeVar,
Union,
overload,
)
from common import configure_logging
log = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
@dataclass
class Counts:
count: int = 0
def inc(self) -> None:
self.count += 1
def dec(self) -> None:
self.count -= 1
class TraceMaker:
    """Decorator that logs each call to a function and the value it returns.

    Every log line is prefixed with ``sep`` repeated ``depth * size`` times,
    where ``depth`` is the number of calls to the decorated function that are
    currently in progress, so recursive calls appear visually nested. The
    depth counter is per decorated function: each ``__call__`` creates its
    own ``Counts``.

    Args:
        sep: string used to build the nesting prefix.
        size: how many copies of ``sep`` represent one nesting level.
    """

    def __init__(
        self,
        sep: str = "*",
        size: int = 1,
    ) -> None:
        self.sep = sep
        self.size = size

    def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
        """Return a wrapper around ``func`` that logs its arguments and result.

        The wrapper keeps ``func``'s metadata (via ``functools.wraps``) and
        lets any exception raised by ``func`` propagate unchanged.
        """
        depth = Counts()

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            depth.inc()
            # Decrement in ``finally``: if ``func`` raises, the depth must
            # still drop back, otherwise every later prefix would be too long.
            try:
                # Computed once so the "call" and "result" lines always share
                # the same prefix.
                prefix = self.sep * depth.count * self.size
                log.info(
                    "%s call %r with args %r and kwargs %r",
                    prefix,
                    func.__name__,
                    args,
                    kwargs,
                )
                result = func(*args, **kwargs)
                log.info(
                    "%s call %r result: %r",
                    prefix,
                    func.__name__,
                    result,
                )
                return result
            finally:
                depth.dec()

        return wrapper
@overload
def trace(
    function: Callable[P, T],
) -> Callable[P, T]: ...


@overload
def trace(
    *,
    sep: str = "*",
    size: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


def trace(
    function: Callable[P, T] | None = None,
    *,
    sep: str = "*",
    size: int = 1,
) -> Union[
    Callable[[Callable[P, T]], Callable[P, T]],
    Callable[P, T],
]:
    """Log calls of a function; usable bare or with options.

    ``@trace`` decorates immediately with the defaults, while
    ``@trace(sep="#", size=2)`` returns a configured decorator.

    Args:
        function: the function to decorate when used bare, else ``None``.
        sep: string repeated to build the nesting prefix.
        size: number of ``sep`` copies per nesting level.

    Returns:
        The traced function (bare form) or a ``TraceMaker`` decorator.
    """
    decorator = TraceMaker(sep=sep, size=size)
    if function is None:
        # Called with keyword options only: hand back the decorator itself.
        return decorator
    # Applied directly as ``@trace``: wrap the function right away.
    return decorator(function)
@trace
def fac(n: int) -> int:
    """Return ``n!`` computed recursively; every call is logged by ``@trace``.

    Recursion stops at ``n <= 2``, so e.g. ``fac(5)`` traces calls for
    5, 4, 3 and 2.

    Raises:
        ValueError: if ``n`` is negative (the factorial is undefined there).
    """
    if n < 0:
        raise ValueError(f"factorial is not defined for negative values: {n}")
    if n <= 2:
        # 0! == 1! == 1 and 2! == 2, i.e. max(n, 1) for n in {0, 1, 2}.
        return max(n, 1)
    return n * fac(n - 1)
@trace(sep="#", size=2)
def fib(n: int) -> int:
    """Return the n-th Fibonacci number, with F(0) = 0 and F(1) = 1.

    Uses naive double recursion, which yields a branching call tree and
    therefore a deeply nested ``@trace`` output.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"fibonacci is not defined for negative values: {n}")
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
def main() -> None:
    """Configure logging, then run the traced ``fac``/``fib`` demos.

    Each result is logged at INFO; a WARNING line separates consecutive runs.
    """
    configure_logging()
    # (function, argument, result message, separator logged afterwards)
    runs: list[tuple[Callable[[int], int], int, str, str | None]] = [
        (fac, 5, "fac(5) = %s", "space"),
        (fac, 7, "fac(7) = %s", "space bigger"),
        (fib, 3, "fib(3) = %s", "space"),
        (fib, 5, "fib(5) = %s", None),
    ]
    for func, arg, template, separator in runs:
        outcome = func(arg)
        log.info(template, outcome)
        if separator is not None:
            log.warning(separator)


if __name__ == "__main__":
    main()
[tool.mypy]
strict = true
mypy==1.13.0
mypy-extensions==1.0.0
typing_extensions==4.12.2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment