Skip to content

Instantly share code, notes, and snippets.

@martinkozle
Last active July 10, 2024 00:32
Show Gist options
  • Save martinkozle/37c7eef95f9bbb5ace8bc6e32f379673 to your computer and use it in GitHub Desktop.
Save martinkozle/37c7eef95f9bbb5ace8bc6e32f379673 to your computer and use it in GitHub Desktop.
Celery Pydantic Task PoC
import celery
from tasks import FooTask
# Worker-side Celery application, connected to the broker at "amqp://".
app = celery.Celery(main="app", broker="amqp://")

# Register FooTask with this app so the worker can execute it.
FooTask.register_for_app(app)

if __name__ == "__main__":
    # Delegate to Celery's command-line interface (e.g. `worker`).
    app.start()
import celery
from tasks import D, FooTask, M
# Client-side Celery application used only to enqueue work.
app = celery.Celery(main="app", broker="amqp://")
foo = FooTask.register_for_app(app)

# Arguments mirroring FooTask.run(t: tuple[D, float], m: M).
t = (
    D(m=M(a="ad", b="bd"), p=42),
    4.2,
)
m = M(a="a1", b="b1")

# Enqueue one FooTask invocation; the return value is discarded.
foo.delay(t, m)
import contextlib
import json
import warnings
from typing import Any, Generic, TypeVar
from pydantic import TypeAdapter
from pydantic_core import to_json
TypeAdapterType = TypeVar("TypeAdapterType")


class PydanticCelerySerializer(Generic[TypeAdapterType]):
    """Kombu encoder/decoder pair backed by a single Pydantic ``TypeAdapter``.

    ``dump`` and ``load`` are intended to be passed to
    ``kombu.serialization.register`` as ``encoder`` and ``decoder``. Task
    messages are handled as ``(args, kwargs, config)`` triples. Every arg /
    kwarg value is converted with the adapter when dumped and validated with
    it when loaded, so the adapter's type should cover every argument type.
    """

    def __init__(self, type_adapter: "TypeAdapter[TypeAdapterType]") -> None:
        """Store the adapter used for every (de)serialization call."""
        self.type_adapter = type_adapter

    def _to_jsonable(self, value: Any) -> Any:  # noqa: ANN401
        """Convert *value* to JSON-compatible Python data via the adapter."""
        return self.type_adapter.dump_python(value, mode="json")

    def dump(self, object: Any) -> str:  # noqa: ANN401
        """Serialize *object* to a JSON string.

        First the whole object is dumped through the adapter with warnings
        silenced. If that raises ``TypeError`` or ``ValueError`` (assumed to
        cover pydantic_core's ``PydanticSerializationError``, a ``ValueError``
        subclass; confirm for the installed version), the pieces are converted
        individually instead:

        * a tuple is treated as a task message ``(args, kwargs, config)``;
          each arg / kwarg value is converted and ``config`` is kept as is;
        * anything else is treated as a result mapping whose ``"result"``
          entry is converted. The caller's mapping is not modified.

        Converted values are embedded as JSON structures, not as nested JSON
        strings, so ``load`` can validate them directly.
        """
        # The warnings filter is set inside the suppressed block so that a
        # failure to configure it can never mask the primary path.
        with contextlib.suppress(TypeError, ValueError), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self.type_adapter.dump_json(object).decode()

        # We have a message to the client
        if isinstance(object, tuple):
            args, kwargs, config = object
            return json.dumps((
                [self._to_jsonable(arg) for arg in args],
                {key: self._to_jsonable(value) for key, value in kwargs.items()},
                config,
            ))

        # We have a result from the worker: build a copy instead of
        # assigning into the caller's mapping.
        return json.dumps({**object, "result": self._to_jsonable(object["result"])})

    def load(self, message: str) -> Any:  # noqa: ANN401
        """Decode a JSON *message*, validating task arguments with the adapter.

        A 3-element list whose first item is a list and second item is a dict
        is treated as a task message and returned as the tuple
        ``(args, kwargs, config)``. Every arg / kwarg value is passed through
        ``validate_python``. Any other payload is returned exactly as parsed.
        """
        parsed_message = json.loads(message)
        if not (isinstance(parsed_message, list) and len(parsed_message) == 3):
            return parsed_message
        args, kwargs, config = parsed_message
        if not (isinstance(args, list) and isinstance(kwargs, dict)):
            # Three items, but not the (args, kwargs, config) shape.
            return parsed_message
        return (
            [self.type_adapter.validate_python(arg) for arg in args],
            {
                key: self.type_adapter.validate_python(value)
                for key, value in kwargs.items()
            },
            config,
        )
import abc
from typing import (
    Any,
    Generic,
    ParamSpec,
    Self,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

from celery import Celery, Task
from kombu.serialization import register
from pydantic import BaseModel, TypeAdapter

from pydantic_serializer import PydanticCelerySerializer
def _task_class_getitem(cls: type, *args: object, **kwargs: object) -> type:
    """Ignore any subscription arguments and hand back *cls* itself."""
    return cls


# celery's ``Task`` is presumably generic only in its type stubs; make
# ``Task[...]`` (used in BasePydanticTask's bases) evaluate to ``Task`` at runtime.
Task.__class_getitem__ = classmethod(_task_class_getitem)  # type: ignore[attr-defined]
TaskParams = ParamSpec("TaskParams")
TaskOutput = TypeVar("TaskOutput")
class BasePydanticTask(
Generic[TaskParams, TaskOutput], Task[TaskParams, TaskOutput], abc.ABC
):
def __init__(self) -> None:
self.serializer = f"pydantic_serializer_{self.__class__.__name__}"
type_adapter = TypeAdapter(Union[*self.task_param_types(), self.task_output()])
pcs = PydanticCelerySerializer(type_adapter)
register(
self.serializer,
encoder=pcs.dump,
decoder=pcs.load,
content_type="application/json",
content_encoding="utf-8",
)
@classmethod
def task_param_types(cls) -> tuple[type, ...]:
return cast(tuple[type, ...], cls.__orig_bases__[0].__args__[0])
@classmethod
def task_output(cls) -> type:
return cast(type, cls.__orig_bases__[0].__args__[1])
@classmethod
def register_for_app(cls, app: Celery, **options: Any) -> Self: # noqa: ANN401
return cast(Self, app.register_task(cls, **options))
class M(BaseModel):
    """Example Pydantic model with two required string fields."""

    # Both fields are plain required strings.
    a: str
    b: str
class D(BaseModel):
    """Example model nesting an ``M`` next to an integer."""

    m: M  # nested model, validated as ``M``
    p: int
class FooTask(BasePydanticTask[[tuple[D, float], M], list[D]]):
    """Example task taking a ``(D, float)`` pair and an ``M``, returning ``D``s."""

    def run(self, t: tuple[D, float], m: M) -> list[D]:
        """Return the pair's ``D`` followed by a new ``D`` built from *m*.

        The new ``D`` gets ``p`` from the pair's float, truncated by ``int``.
        """
        original, number = t[0], t[1]
        derived = D(m=m, p=int(number))
        return [original, derived]
@martinkozle
Copy link
Author

martinkozle commented Feb 18, 2024

For returning results

There seems to be no way to override the app-level `result_serializer` from inside the task class (or at least I haven't found one).

The workaround I found is to define a separate serializer just for results and set it as the `result_serializer` of the worker's Celery instance:

import celery
from kombu.serialization import register
from pydantic_core import to_json

# Result-only serializer: encodes the whole result payload with pydantic_core's
# to_json, so task return values containing Pydantic models can be encoded.
# decoder=None registers no decoder of its own; presumably the decoder already
# registered for "application/json" is used when results are read back (confirm).
register(
    "pydantic_result_serializer",
    encoder=lambda x: to_json(x).decode(),
    decoder=None,
    content_type="application/json",
    content_encoding="utf8",
)

app = celery.Celery(
    "app",
    broker="amqp://",
    backend="rpc://",
    # Use the serializer registered above for task results.
    result_serializer="pydantic_result_serializer",
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment