Created
January 26, 2025 06:58
-
-
Save mahenzon/2961dd084f12e7a1ac7978977e26849a to your computer and use it in GitHub Desktop.
overload annotation in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
# Layout of every log line: name of the function that emitted the record
# (right-aligned to 10 chars), module:line, level (left-aligned to 8 chars),
# then the message. A more verbose, timestamped alternative would be:
#   "[%(asctime)s.%(msecs)03d] %(funcName)20s %(module)s:%(lineno)d %(levelname)-8s - %(message)s"
DEFAULT_LOG_FORMAT = (
    "%(funcName)10s %(module)s:%(lineno)d "
    "%(levelname)-8s - %(message)s"
)
def configure_logging(level: int = logging.INFO) -> None:
    """Configure root logging with ``DEFAULT_LOG_FORMAT``.

    Args:
        level: Minimum level the root logger will emit (defaults to INFO).

    This delegates to ``logging.basicConfig``, which (per its documentation)
    does nothing if the root logger already has handlers configured.
    """
    # NOTE(review): datefmt only affects %(asctime)s, which DEFAULT_LOG_FORMAT
    # does not currently include, so this setting has no visible effect.
    logging.basicConfig(
        format=DEFAULT_LOG_FORMAT,
        datefmt="%Y-%m-%d %H:%M:%S",
        level=level,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from dataclasses import dataclass | |
from functools import wraps | |
from typing import ( | |
Callable, | |
ParamSpec, | |
TypeVar, | |
Union, | |
overload, | |
) | |
from common import configure_logging | |
log = logging.getLogger(__name__) | |
P = ParamSpec("P") | |
T = TypeVar("T") | |
@dataclass
class Counts:
    """A mutable integer counter.

    ``TraceMaker`` keeps one per decorated function to track how deeply
    nested the current call is.
    """

    count: int = 0

    def inc(self) -> None:
        """Add one to ``count``."""
        self.count = self.count + 1

    def dec(self) -> None:
        """Subtract one from ``count``."""
        self.count = self.count - 1
class TraceMaker:
    """Decorator that logs every call of the wrapped function and its result.

    Each decorated function gets its own call-depth counter. Every log line
    is prefixed with ``sep`` repeated ``depth * size`` times, so nested
    (e.g. recursive) calls appear progressively indented.
    """

    def __init__(
        self,
        sep: str = "*",
        size: int = 1,
    ) -> None:
        """Store the prefix character(s) and how many repeats per depth level.

        Args:
            sep: String repeated to build the per-line prefix.
            size: Number of ``sep`` repetitions per level of nesting.
        """
        self.sep = sep
        self.size = size

    def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
        """Return ``func`` wrapped so each call and its result are logged at INFO.

        The wrapper keeps ``func``'s signature (via ``P``/``T``) and metadata
        (via ``functools.wraps``). Exceptions raised by ``func`` propagate
        unchanged; no result line is logged for such a call.
        """
        count = Counts()

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            count.inc()
            try:
                prefix = self.sep * count.count * self.size
                log.info(
                    "%s call %r with args %r and kwargs %r",
                    prefix,
                    func.__name__,
                    args,
                    kwargs,
                )
                result = func(*args, **kwargs)
                log.info(
                    "%s call %r result: %r",
                    prefix,
                    func.__name__,
                    result,
                )
                return result
            finally:
                # Always restore the depth, even when func raises. Otherwise
                # the counter leaks and every later call is logged with an
                # ever-growing prefix.
                count.dec()

        return wrapper
@overload
def trace(
    function: Callable[P, T],
) -> Callable[P, T]: ...


@overload
def trace(
    *,
    sep: str = "*",
    size: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


def trace(
    function: Callable[P, T] | None = None,
    *,
    sep: str = "*",
    size: int = 1,
) -> Union[
    Callable[[Callable[P, T]], Callable[P, T]],
    Callable[P, T],
]:
    """Trace calls of a function; usable bare or with keyword options.

    ``@trace`` wraps the decorated function directly, using the default
    prefix. ``@trace(sep=..., size=...)`` returns a decorator configured
    with those options. The two ``@overload`` stubs above let a type
    checker tell the two forms apart.
    """
    decorator = TraceMaker(sep=sep, size=size)
    if function is None:
        # Called with keyword options only: hand back the configured decorator.
        return decorator
    return decorator(function)
@trace
def fac(n: int) -> int:
    """Return ``n!`` computed recursively.

    The recursive call goes through the module-level name ``fac``, which is
    the ``@trace`` wrapper, so every level of the recursion is logged.

    Raises:
        ValueError: If ``n`` is negative (the factorial is undefined there).
    """
    if n < 0:
        raise ValueError(f"factorial is not defined for negative values: {n}")
    if n <= 2:
        # 0! == 1! == 1 and 2! == 2. Returning n directly would wrongly give
        # fac(0) == 0. Keeping 2 as a base case preserves the trace depth.
        return max(n, 1)
    return n * fac(n - 1)
@trace(sep="#", size=2)
def fib(n: int) -> int:
    """Return the n-th Fibonacci number by naive double recursion.

    Values of ``n`` below 2 are returned unchanged. The exponential call
    tree is presumably intentional here: it yields a deep, branching trace
    (prefix ``#`` repeated ``2 * depth`` times).
    """
    return n if n < 2 else fib(n - 1) + fib(n - 2)
def main() -> None:
    """Run the tracing demo: two traced factorials, then two traced Fibonacci calls."""
    configure_logging()
    # One entry per demo run, in order:
    # (traced function, argument, result message, separator logged afterwards).
    # The last run has no trailing separator.
    runs = (
        (fac, 5, "fac(5) = %s", "space"),
        (fac, 7, "fac(7) = %s", "space bigger"),
        (fib, 3, "fib(3) = %s", "space"),
        (fib, 5, "fib(5) = %s", None),
    )
    for func, arg, message, separator in runs:
        log.info(message, func(arg))
        if separator is not None:
            # These warnings appear to serve only as visual breaks between runs.
            log.warning(separator)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# mypy settings: type-check with all of mypy's strict-mode flags enabled.
[tool.mypy]
strict = true
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pinned tooling used alongside the [tool.mypy] strict configuration.
mypy==1.13.0
mypy-extensions==1.0.0
typing_extensions==4.12.2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment