Last active
July 10, 2024 00:32
-
-
Save martinkozle/37c7eef95f9bbb5ace8bc6e32f379673 to your computer and use it in GitHub Desktop.
Celery Pydantic Task PoC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import celery | |
from tasks import FooTask | |
# Worker-side Celery application, pointed at the broker URL "amqp://".
app = celery.Celery("app", broker="amqp://")

# Attach FooTask to this application.
FooTask.register_for_app(app)

if __name__ == "__main__":
    # Only start the application when this file is executed as a script.
    app.start()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import celery | |
from tasks import D, FooTask, M | |
# Client-side Celery application, pointed at the same broker URL as the worker.
app = celery.Celery("app", broker="amqp://")

# Registration returns the task object that is used to enqueue calls.
foo = FooTask.register_for_app(app)

# Example arguments matching FooTask.run(t, m).
t = (D(m=M(a="ad", b="bd"), p=42), 4.2)
m = M(a="a1", b="b1")

# Enqueue one call of the task with the arguments above.
foo.delay(t, m)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import json | |
import warnings | |
from typing import Any, Generic, TypeVar | |
from pydantic import TypeAdapter | |
from pydantic_core import to_json | |
TypeAdapterType = TypeVar("TypeAdapterType")


class PydanticCelerySerializer(Generic[TypeAdapterType]):
    """Encoder/decoder pair that routes payload values through a type adapter.

    ``dump`` turns an object into a JSON string and ``load`` turns a JSON
    string back into Python data, both delegating value handling to the
    adapter passed to the constructor.
    """

    def __init__(self, type_adapter: "TypeAdapter[TypeAdapterType]") -> None:
        """Store the adapter used to serialize and validate payload values."""
        self.type_adapter = type_adapter

    def _dump_str(self, value: Any) -> str:  # noqa: ANN401
        """Serialize a single value with the adapter and return it as text."""
        return self.type_adapter.dump_json(value).decode()

    def dump(self, object: Any) -> str:  # noqa: A002, ANN401
        """Serialize ``object`` to a JSON string.

        The whole object is first handed to the adapter. If that raises
        ``TypeError`` the object is handled structurally instead:

        * a tuple is unpacked as ``(args, kwargs, config)``; every positional
          and keyword value is serialized individually and ``config`` is
          passed through to ``json.dumps`` untouched;
        * anything else is treated as a mapping with a ``"result"`` key whose
          value is serialized with the adapter.

        The caller's object is never modified.

        Raises:
            ValueError: if a tuple does not have exactly three items.
            KeyError: if a non-tuple mapping has no ``"result"`` key.
        """
        # `suppress` is entered before the warnings context, so a TypeError
        # raised anywhere inside falls through to the structural handling.
        # `simplefilter("ignore")` is the portable spelling of
        # `catch_warnings(action="ignore")`, which only exists on 3.11+.
        with contextlib.suppress(TypeError), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self._dump_str(object)

        # We have a message to the client
        if isinstance(object, tuple):
            args, kwargs, config = object
            # NOTE(review): these values are embedded as JSON *strings*, while
            # `load` passes them to `validate_python` — confirm the round trip
            # works for the types in use.
            return json.dumps(
                (
                    [self._dump_str(arg) for arg in args],
                    {key: self._dump_str(value) for key, value in kwargs.items()},
                    config,
                )
            )

        # We have a result from the worker. Build a copy so the dict owned by
        # the caller keeps its original, unserialized "result" value.
        serialized_result = self._dump_str(object["result"])
        return json.dumps({**object, "result": serialized_result})

    def load(self, message: str) -> Any:  # noqa: ANN401
        """Parse ``message`` and validate task arguments when present.

        A JSON list of exactly three items is interpreted as
        ``[args, kwargs, config]``: every positional and keyword value is
        validated with the adapter and the triple is returned as a tuple.
        Any other JSON document is returned exactly as parsed.
        """
        parsed_message = json.loads(message)
        if not isinstance(parsed_message, list) or len(parsed_message) != 3:
            return parsed_message

        args, kwargs, config = parsed_message
        validate = self.type_adapter.validate_python
        return (
            [validate(arg) for arg in args],
            {key: validate(value) for key, value in kwargs.items()},
            config,
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc | |
from typing import Any, Generic, ParamSpec, Self, TypeVar, Union, cast | |
from celery import Celery, Task | |
from kombu.serialization import register | |
from pydantic import BaseModel, TypeAdapter | |
from pydantic_serializer import PydanticCelerySerializer | |
def _task_class_getitem(cls, *args, **kwargs):
    """Ignore any subscription arguments and hand back the class itself."""
    return cls


# Make ``Task[...]`` subscriptable at runtime so it can appear in a
# parametrized base-class list.
Task.__class_getitem__ = classmethod(  # type: ignore[attr-defined]
    _task_class_getitem
)

# Generic placeholders: the task's call signature and its return type.
TaskParams = ParamSpec("TaskParams")
TaskOutput = TypeVar("TaskOutput")
class BasePydanticTask(
    Generic[TaskParams, TaskOutput], Task[TaskParams, TaskOutput], abc.ABC
):
    """Task base class that registers a pydantic-backed serializer for itself.

    Subclasses subscript the base as ``BasePydanticTask[[P1, P2, ...], R]``;
    the parameter types and the output type are read back from that
    subscription to build the serializer's type adapter.
    """

    def __init__(self) -> None:
        """Build the serializer for this task class and register it with kombu."""
        serializer_name = f"pydantic_serializer_{self.__class__.__name__}"
        self.serializer = serializer_name

        # A single adapter covers the union of every parameter type and the
        # output type.
        accepted_types = (*self.task_param_types(), self.task_output())
        codec = PydanticCelerySerializer(TypeAdapter(Union[accepted_types]))

        # NOTE(review): every task class registers under "application/json";
        # presumably kombu keeps one decoder per content type, so several
        # task classes may overwrite each other's decoder — confirm.
        register(
            serializer_name,
            encoder=codec.dump,
            decoder=codec.load,
            content_type="application/json",
            content_encoding="utf-8",
        )

    @classmethod
    def task_param_types(cls) -> tuple[type, ...]:
        """Return the first subscription argument: the task's parameter types."""
        generic_base = cls.__orig_bases__[0]
        return cast(tuple[type, ...], generic_base.__args__[0])

    @classmethod
    def task_output(cls) -> type:
        """Return the second subscription argument: the task's output type."""
        generic_base = cls.__orig_bases__[0]
        return cast(type, generic_base.__args__[1])

    @classmethod
    def register_for_app(cls, app: Celery, **options: Any) -> Self:  # noqa: ANN401
        """Register this task class on ``app`` and return what ``app`` hands back.

        Extra keyword ``options`` are forwarded to ``app.register_task``.
        """
        registered = app.register_task(cls, **options)
        return cast(Self, registered)
# Example model with two string fields; used directly as a task argument
# and nested inside ``D``.
class M(BaseModel):
    a: str
    b: str
# Example model that nests an ``M`` next to an integer field.
class D(BaseModel):
    m: M  # nested model
    p: int
class FooTask(BasePydanticTask[[tuple[D, float], M], list[D]]):
    """Example task taking a ``(D, float)`` pair and an ``M``, returning ``list[D]``."""

    def run(self, t: tuple[D, float], m: M) -> list[D]:
        """Return ``t[0]`` followed by a new ``D`` built from ``m`` and ``int(t[1])``."""
        passthrough = t[0]
        derived = D(m=m, p=int(t[1]))
        return [passthrough, derived]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For returning results
There seems to be no way to override the app
result_serializer
inside the task class, or I just haven't found one. The current workaround that works for me is to define a separate serializer just for the result and set the
result_serializer
on the worker Celery instance to that: