# Universal Decorator
# Source: GitHub gist LiutongZhou/1d9f8b1231acd8173b30db3e322c67cc
# (last active May 13, 2024 15:23).
# NOTE: GitHub flagged this file as containing hidden or bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
"""Universal Decorator

Universal decorators can decorate functions, classes, bound methods
(class method / instance method) referenced outside of class definition
and descriptors (class methods, static methods) defined inside class definition.
"""
from __future__ import annotations

import inspect
import logging
import os
from abc import ABC, abstractmethod
from contextlib import redirect_stdout
from dataclasses import dataclass, field
from functools import update_wrapper
from typing import Any, Optional, Type, TypeVar

T = TypeVar("T")  # generic type

__all__ = ["UniversalDecoratorBase"]

# Module logger: INFO level, printed to stderr with the emitting line number.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Attach the handler only once: re-executing this module (e.g. importlib.reload)
# returns the same named logger, and adding another handler would print every
# message twice.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[Line: %(lineno)d] - %(message)s"))
    logger.addHandler(handler)


@dataclass
class UniversalDecoratorBase(ABC):
    """Base Class for Universal Decorators

    Inherit this base class to create universal decorators.

    Unlike ordinary decorators that normally only decorate callables,
    universal decorators can decorate functions, classes, bound methods
    (class method / instance method) referenced outside of class definition
    and descriptors (class methods, static methods, instance methods etc.)
    defined inside class definition. Universal decorators can also be stacked.

    When the decorated object is called with (*args, **kwargs), the __get__ and __call__ methods
    will be called properly based on the type of the original object to decorate.

    If to_decorate is a descriptor (classmethod, staticmethod), instance method defined inside class
    definition, or Class.method wrapped outside of class definition, the __get__ method will be
    called first (on attribute access); it returns a bound copy of the decorator whose __call__
    is then invoked.

    If to_decorate is a function, a class, or instance.bound_method wrapped outside of class
    definition, the __call__ method will be called directly.

    Parameters
    ----------
    to_decorate : Any
        the original object to decorate. A Callable or descriptor.

    Notes
    -----
    ``@dataclass`` generates a field-wise ``__eq__`` and therefore sets ``__hash__``
    to ``None``: decorator instances are not hashable.

    Examples
    --------
    >>> from time import sleep, time
    >>> class timeit(UniversalDecoratorBase):
    ...     def __call__(self, *args, **kwargs):
    ...         tic = time()
    ...         result = self.to_decorate(*args, **kwargs)
    ...         toc = time()
    ...         print(f"Time taken: {toc - tic:0.1f} seconds")
    ...         return result
    >>> class SuperNeuralNetwork:
    ...     @timeit
    ...     @classmethod
    ...     def from_pretrained(cls, path):
    ...         print(f"Loading model from {path}")
    ...         sleep(0.1)
    ...         return cls()
    ...
    ...     @timeit
    ...     def forward(self, x):
    ...         print("Forward pass")
    ...         sleep(0.1)
    ...         return -x
    >>> model = SuperNeuralNetwork.from_pretrained("path/to/model")
    Loading model from path/to/model
    Time taken: 0.1 seconds
    >>> model.forward(1)
    Forward pass
    Time taken: 0.1 seconds
    -1
    """

    # the original object to decorate: a function, class, bound method or descriptor
    to_decorate: Any
    # internal use only, to keep track of the object that called the descriptor:
    # the instance, or the owner class when accessed on the class (set by __get__)
    _bound_to: Any = None
    # internal use only, to keep track of the name of the decorated object, set by __set_name__
    _name: Optional[str] = field(default=None, init=False)

    def __post_init__(self):
        if callable(self.to_decorate):
            # Copy __module__/__name__/__qualname__/__doc__/... and set __wrapped__,
            # but do NOT merge the wrapped object's __dict__ (``updated=()``): when
            # universal decorators are stacked, the inner decorator's instance dict
            # holds its own ``to_decorate``/``_bound_to``, which would overwrite these
            # fields and silently drop the inner decorator. Other attributes of the
            # wrapped object stay reachable through ``__getattr__``.
            update_wrapper(self, self.to_decorate, updated=())
        if self._name is not None:
            self.__name__ = self._name

    def __getattr__(self, name: str) -> Any:
        """Forward lookups of attributes not found on the decorator to ``to_decorate``."""
        # Read the field from the instance dict rather than via ``self.to_decorate``:
        # on an instance whose fields are not set yet (e.g. the bare object that
        # ``copy``/``pickle`` create with ``__new__``) the latter would re-enter
        # ``__getattr__`` forever and raise RecursionError instead of AttributeError.
        try:
            target = vars(self)["to_decorate"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(target, name)

    @staticmethod
    def _takes_self_or_cls(func: Any) -> bool:
        """Return True if ``func`` is bindable and its first positional parameter is self/cls.

        Name-based heuristic used by ``__get__`` to tell methods assigned to a class
        apart from plain functions (e.g. ``Class.static_method`` re-assigned to the class).
        Methods whose first parameter uses another name are not detected.
        """
        if not hasattr(type(func), "__get__"):
            return False
        try:
            parameters = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):  # not callable, or no retrievable signature
            return False
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return (
            bool(parameters)
            and parameters[0].kind in positional_kinds
            and parameters[0].name in ("self", "cls")
        )

    def __get__(self, instance: Optional[T], owner: Type[T]) -> UniversalDecoratorBase:
        """Routing the . operator to __get__

        Called when this decorator is stored as a class attribute, i.e. the original
        to_decorate is a descriptor (staticmethod, classmethod), a method function
        defined inside the class, or a function / bound method (instance method /
        class method) assigned to the class outside of its definition.

        Returns a new decorator of the same type wrapping the resolved callable, with
        ``_bound_to`` set to ``instance`` (or ``owner`` when accessed on the class).
        """
        target = self.to_decorate
        unwrapped_descriptor = False
        # classmethod / staticmethod: let the descriptor produce the actual callable
        # (a bound method for classmethod, the plain function for staticmethod)
        while inspect.ismethoddescriptor(target):
            target = target.__get__(instance, owner)
            unwrapped_descriptor = True
        # A function (or stacked universal decorator) whose first parameter is named
        # self/cls is treated as a method and bound the way Python would bind it.
        # Skipped once a descriptor was unwrapped: a classmethod is already bound and
        # a staticmethod must stay unbound even if it has a ``self``/``cls`` parameter.
        if not unwrapped_descriptor and self._takes_self_or_cls(target):
            target = target.__get__(instance, owner)
        decorated_callable = type(self)(
            to_decorate=target,
            # ``is not None`` rather than ``instance or owner``: falsy instances
            # (e.g. empty containers) must still be recorded as the caller
            _bound_to=instance if instance is not None else owner,
        )
        if self._name is not None:
            decorated_callable.__name__ = decorated_callable._name = self._name
        return decorated_callable

    def __set__(self, instance, value):
        """Support changing the original to_decorate via ``object.decorated = value``.

        NOTE: this replaces ``to_decorate`` on the class-level decorator, i.e. for every
        instance of the owner class, not only ``instance``. Metadata copied by
        ``update_wrapper`` (``__name__``, ``__doc__``, ``__wrapped__``, ...) is not refreshed.
        """
        self.to_decorate = value

    def __set_name__(self, owner, name: str):
        """Called on class definition when the descriptor is assigned to a class attribute."""
        self.__name__ = self._name = name

    def __delete__(self, instance):
        """Support ``del instance.decorated`` by removing the decorator from its class.

        The class defining the decorator is searched along ``type(instance).__mro__``,
        so the attribute disappears for every instance of that class.

        Raises
        ------
        AttributeError
            if the decorator is not found on ``type(instance)`` or its bases.
        """
        for klass in type(instance).__mro__:
            for attr_name, attr_value in vars(klass).items():
                # Match the name recorded by __set_name__; when there is none (the
                # decorator was assigned after class creation) match by identity only.
                if attr_value is self and self._name in (None, attr_name):
                    delattr(klass, attr_name)
                    return
        raise AttributeError(
            f"{type(self).__name__} descriptor not found on "
            f"{type(instance).__name__!r} or its bases"
        )

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Implement your decoration logic here.

        Subclasses must override this method. Calling ``super().__call__(*args, **kwargs)``
        logs (at DEBUG level) how the decorated object was reached and returns
        ``to_decorate(*args, **kwargs)``.

        to_decorate is a function (static method included), class,
        bounded method (instance / class method), or any other callable
        (e.g. another universal decorator when decorators are stacked).

        Raises
        ------
        NotImplementedError
            if a bound method is reached through an unexpected object, or if
            to_decorate is not callable.
        """
        to_decorate = self.to_decorate
        # None if decorating a function, class,
        # or instance.bound_method outside of class definition
        called_from = self._bound_to
        # Lazy %-style logging: the (possibly expensive) reprs are only built
        # when DEBUG is enabled, not on every call.
        if inspect.ismethod(to_decorate):  # bound method: class method or instance method
            method_owner = to_decorate.__self__
            if inspect.isclass(method_owner):
                logger.debug(
                    "decorating a class method: to_decorate=%r of class: class_method_owner=%r",
                    to_decorate,
                    method_owner,
                )
                if called_from is method_owner:
                    logger.debug("class method called from Class")
                elif isinstance(called_from, method_owner):
                    logger.debug("class method called from instance")
                elif called_from is None:
                    logger.debug(
                        "decorating an instance.class_method wrapped outside of class definition "
                        "and called from an instance"
                    )
                else:
                    raise NotImplementedError(
                        f"class method {to_decorate!r} reached through unexpected object "
                        f"{called_from!r}"
                    )
            else:
                logger.debug(
                    "decorating an instance method: to_decorate=%r "
                    "bound to instance: instance_method_owner=%r",
                    to_decorate,
                    method_owner,
                )
                if called_from is method_owner:
                    logger.debug("decorating an instance method called from an instance")
                elif called_from is None:
                    logger.debug(
                        "decorating an instance.method referenced outside of class definition "
                        "and called from an instance"
                    )
                else:
                    raise NotImplementedError(
                        f"instance method {to_decorate!r} reached through unexpected object "
                        f"{called_from!r}"
                    )
        elif inspect.isclass(to_decorate):
            logger.debug("decorating a class: to_decorate=%r", to_decorate)
        elif inspect.isfunction(to_decorate):
            if called_from is None:
                logger.debug("decorating a function to_decorate=%r", to_decorate)
            else:
                logger.debug(
                    "decorating a static method "
                    "or Class.instance_method called from _called_from=%r",
                    called_from,
                )
        elif callable(to_decorate):
            # e.g. a stacked universal decorator or a functools.partial
            logger.debug("decorating a callable object: to_decorate=%r", to_decorate)
        else:
            raise NotImplementedError(f"cannot call non-callable object {to_decorate!r}")
        return to_decorate(*args, **kwargs)


class UniversalContextDecorator(UniversalDecoratorBase):
    """Base class / mixin that lets a context manager act as a universal decorator.

    A subclass provides ``__enter__`` and ``__exit__``; every call of the decorated
    object then executes inside that context.

    Examples
    --------
    see the `mute` class below
    """

    def _recreate_cm(self):
        """Return the context manager to enter around one call of the decorated object.

        This default hands back ``self``, so the very same object is entered on every
        call (subclasses that keep per-entry state on ``self`` must cope with nested
        or recursive entries). The hook has the same name and role as the private
        ``contextlib.ContextDecorator._recreate_cm``: a subclass wrapping a one-shot
        context manager can override it to build a fresh one per call.
        """
        return self

    def __call__(self, *args, **kwargs):
        """Invoke ``to_decorate(*args, **kwargs)`` inside the context manager."""
        context = self._recreate_cm()
        with context:
            wrapped = self.to_decorate
            return wrapped(*args, **kwargs)


class mute(UniversalContextDecorator):
    """A universal decorator that mutes the standard output temporarily.

    Also usable directly as a context manager. While active, ``sys.stdout`` is
    redirected to ``os.devnull``. Entries may nest or recurse (e.g. a muted
    function calling itself re-enters the same instance): each ``__enter__`` pushes
    its own devnull handle and ``redirect_stdout`` context onto a per-instance
    stack, and ``__exit__`` pops them in LIFO order, restoring the previous
    stdout and closing the handle.

    Examples
    --------
    >>> class SomeClass:
    ...     @mute
    ...     def some_process(self):
    ...         print("This message won't be printed.")
    >>> SomeClass().some_process()
    >>> with mute(None):
    ...     print("This message won't be printed either")
    """

    def __enter__(self):
        """Redirect stdout to a fresh devnull handle and return ``self``."""
        # Explicit utf-8 rather than the locale encoding, so muted text cannot raise
        # UnicodeEncodeError on non-UTF-8 locales.
        devnull = open(os.devnull, "w", encoding="utf-8")
        redirect = redirect_stdout(devnull)
        redirect.__enter__()
        # A stack instead of single attributes: the same instance is re-entered when
        # the decorated callable recurses, and single attributes would be overwritten.
        # The instance dict is used directly because a missing attribute would
        # otherwise be forwarded to ``to_decorate`` by ``__getattr__``.
        vars(self).setdefault("_mute_stack", []).append((devnull, redirect))
        return self

    def __exit__(self, exc_type, exc_val, exc_traceback):
        """Undo the most recent redirection and close its devnull handle."""
        devnull, redirect = vars(self)["_mute_stack"].pop()
        try:
            redirect.__exit__(exc_type, exc_val, exc_traceback)
        finally:
            devnull.close()


class TestUniversalDecorator:
    """pytest Test Suite for Universal Decorator"""

    # NOTE(review): `pytest` is used below but is not imported in this module's
    # visible import block, so defining this class raises NameError unless it is
    # imported elsewhere - TODO confirm (add `import pytest`, or move this suite
    # into its own test module).

    @pytest.fixture(scope="class")
    def custom_decorator(self):
        """Return a universal decorator that relies on the base-class ``__call__``."""

        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                return super().__call__(*args, **kwargs)

        return custom_decorator

    @pytest.fixture
    def MyClass(self, custom_decorator):
        """Return a fresh class exercising every kind of decorated member.

        Function-scoped on purpose: several tests mutate the class (re-assign or
        delete attributes), so sharing one class across tests would make the
        outcome depend on test order.
        """

        class MyClass:
            @custom_decorator
            def instance_method(self, x) -> int:
                """My instance method"""
                logger.debug(self)
                logger.debug("instance_method called")
                return x

            def instance_method2(self, x):
                logger.debug("instance_method2 called")
                return x

            def instance_method3(self, x):
                logger.debug("instance_method3 called")
                return x

            @custom_decorator
            @classmethod
            def class_method(cls, x):
                logger.debug("class_method called")
                return cls.__name__ + f"-{x=}"

            @classmethod
            def class_method2(cls, x):
                logger.debug(f"class_method2 called {cls=}")
                return cls.__name__ + f"-{x=}"

            @classmethod
            def class_method3(cls, x):
                logger.debug(f"class_method3 called {cls=}")
                return cls.__name__ + f"-{x=}"

            @custom_decorator
            @staticmethod
            def static_method(x):
                logger.debug("static_method called")
                return x

            @staticmethod
            def static_method2(x):
                logger.debug("static_method2 called")
                return x

            @staticmethod
            def static_method3(x):
                logger.debug("static_method3 called")
                return x

            @custom_decorator
            @classmethod
            def class_method_for_deletion(cls):
                return "Delete me"

        return MyClass

    def test_decorating_class(self):
        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                def __init__(self, *args, **kwargs):
                    self.args = args
                    for key, value in kwargs.items():
                        setattr(self, key, value)

                # give the decorated class an __init__ accepting arbitrary arguments
                self.to_decorate.__init__ = __init__
                return super().__call__(*args, **kwargs)

        @custom_decorator
        class MyClass:
            """Empty Class"""

        my_object = MyClass(1, 2, 3, foo="bar")  # type: ignore[call-arg]
        assert my_object.args == (1, 2, 3)  # type: ignore[attr-defined]
        assert my_object.foo == "bar"  # type: ignore[attr-defined]
        assert my_object.__class__.__name__ == "MyClass"
        assert MyClass.__name__ == "MyClass"
        print(MyClass)

    def test_decorating_function(self):
        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                first, *rest = args
                args = (first + 1, *rest)
                return super().__call__(*args, **kwargs)

        @custom_decorator
        def my_function(x):
            return x

        assert my_function(1) == 2

    def test_decorating_instance_method_within_class_definition(self, MyClass):
        logger.info("Testing decorating class.instance_method inside class definition")
        logger.info(f"{MyClass.instance_method=}")
        my_object = MyClass()
        logger.info(f"{my_object.instance_method=}")
        assert my_object.instance_method(x=1) == 1
        assert my_object.instance_method(x=2) == 2
        assert MyClass.instance_method(my_object, x=3) == 3

    def test_deleting_decorated_method(self, MyClass):
        my_object = MyClass()
        del my_object.instance_method
        assert not hasattr(my_object, "instance_method")
        assert hasattr(MyClass, "class_method_for_deletion")
        assert MyClass.class_method_for_deletion() == "Delete me"
        del MyClass.class_method_for_deletion
        assert not hasattr(MyClass, "class_method_for_deletion")

    def test_decorating_instance_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        logger.info("Testing decorating class.instance_method outside class definition")
        MyClass.instance_method2 = custom_decorator(MyClass.instance_method2)
        logger.info(MyClass.instance_method2)
        my_object = MyClass()
        assert my_object.instance_method2(x=1) == 1
        assert my_object.instance_method2(x=2) == 2
        assert MyClass.instance_method2(my_object, x=3) == 3
        logger.info(
            "Testing decorating object.instance_method outside class definition"
        )
        my_object.instance_method3 = custom_decorator(my_object.instance_method3)
        assert my_object.instance_method3(x=1) == 1
        assert my_object.instance_method3(x=2) == 2
        assert MyClass.instance_method3(my_object, x=3) == 3

    def test_decorating_class_method_within_class_definition(self, MyClass):
        logger.debug(f"{MyClass.class_method=}")
        logger.debug(f"{MyClass().class_method=}")
        assert MyClass.class_method(x=1) == "MyClass-x=1"
        assert MyClass().class_method(x=1) == "MyClass-x=1"

    def test_decorating_class_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        MyClass.class_method2 = custom_decorator(MyClass.class_method2)
        assert MyClass.class_method2(x=1) == "MyClass-x=1"
        assert MyClass().class_method2(x=1) == "MyClass-x=1"
        my_class = MyClass()
        my_class.class_method3 = custom_decorator(my_class.class_method3)
        assert MyClass.class_method3(x=1) == "MyClass-x=1"
        assert my_class.class_method3(x=1) == "MyClass-x=1"

    def test_decorating_static_method_within_class_definition(self, MyClass):
        logger.debug(f"{MyClass.static_method=}")
        logger.debug(f"{MyClass().static_method=}")
        assert MyClass.static_method(x=1) == 1
        assert MyClass().static_method(x=1) == 1

    def test_decorating_static_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        MyClass.static_method2 = custom_decorator(MyClass.static_method2)
        assert MyClass.static_method2(x=1) == 1
        assert MyClass().static_method2(x=1) == 1
        my_class = MyClass()
        my_class.static_method3 = custom_decorator(my_class.static_method3)
        assert MyClass.static_method3(x=1) == 1
        assert my_class.static_method3(x=1) == 1


# End of the GitHub gist page (footer: "Sign up for free to join this conversation
# on GitHub. Already have an account? Sign in to comment").