Created
September 30, 2024 21:42
-
-
Save tdiggelm/02c33079ac58a289c4c3f3e935c1120c to your computer and use it in GitHub Desktop.
Pydantic union root model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from types import UnionType
from typing import Annotated, Any, ClassVar, Literal, TypeVar, Union, get_args, get_origin

from pydantic import Discriminator, Field, RootModel
# Type parameter standing for the root value wrapped by ``UnionRootModel``.
T = TypeVar("T")

# Public API of this module: only the model class is exported.
__all__ = ["UnionRootModel"]
def get_type_variants(obj: Any) -> tuple[type[Any], ...]: | |
""" | |
Extracts type variants from the first Union encountered in a type annotation. | |
This function traverses annotated type expressions and returns the variants | |
of the first `Union` encountered. If the type is not a `Union`, it recursively | |
processes the first argument of the annotation (if any), and returns the type | |
itself if no `Union` is found. | |
Args: | |
obj (Any): The type annotation or object to extract variants from. | |
Returns: | |
tuple[type[Any], ...]: A tuple containing the types in the `Union`, or | |
a tuple with the original type if no `Union` is found. | |
Example: | |
For a type `Annotated[Final[Union[A, B]]]`, the function will return `(A, B)`. | |
""" | |
if get_origin(obj) is Union or isinstance(obj, UnionType): | |
return get_args(obj) | |
args = get_args(obj) | |
return get_type_variants(args[0]) if args else (obj,) | |
class UnionRootModel(RootModel[T]):
    """
    A specialized root model that handles union types with discriminators.

    This class extends Pydantic's `RootModel` so that it can be parameterized
    with a union type (e.g. `UnionRootModel[Foo | Bar]`). The members of the
    union are recorded on the parameterized class, and subclasses may name a
    discriminator that is attached to the `root` field's annotation.

    Attributes:
        union_discriminator (ClassVar[str | Discriminator | None]):
            The discriminator applied to the `root` field, or `None` when no
            discriminator is configured. It can be set either as a class
            attribute or through the `discriminator` class keyword.
        union_variants (ClassVar[tuple[type[Any], ...]]):
            The union members extracted from the type parameter by
            `get_type_variants`.

    Methods:
        __init_subclass__(cls, discriminator, **kwargs):
            Resolves and validates the discriminator and attaches it to the
            `root` field annotation.
        __class_getitem__(cls, params):
            Parameterizes the model with exactly one type argument and
            records its union variants.
    """

    union_discriminator: ClassVar[str | Discriminator | None] = None
    union_variants: ClassVar[tuple[type[Any], ...]] = ()

    def __init_subclass__(
        cls, discriminator: str | Discriminator | None = None, **kwargs: Any
    ) -> None:
        """
        Resolve the discriminator for a new subclass and apply it to `root`.

        The `discriminator` keyword takes precedence; when it is omitted the
        `union_discriminator` class attribute (own or inherited) is used. The
        resolved value is stored back on `cls.union_discriminator`, with an
        empty string normalized to `None`. When a discriminator is in effect,
        the `root` field's annotation is wrapped in
        `Annotated[..., Field(discriminator=...)]`.

        Args:
            discriminator (str | Discriminator | None):
                The discriminator used to distinguish between the union
                variants. If not provided, the class attribute
                `union_discriminator` is used.
            **kwargs (Any): Additional keyword arguments passed to the superclass.

        Raises:
            TypeError: If the resolved discriminator is neither a string nor
                a `Discriminator`.
        """
        # An explicitly passed keyword wins over the class attribute, so a
        # subclass can override a discriminator inherited from its base.
        resolved = (
            discriminator if discriminator is not None else cls.union_discriminator
        )

        # Validate anything that is not None; a truthiness test would let
        # falsy non-string values (e.g. ``0``) slip through unchecked.
        if resolved is not None and not isinstance(resolved, (str, Discriminator)):
            raise TypeError(
                f"Expected 'str' or 'Discriminator' for 'union_discriminator', "
                f"got {type(resolved).__name__}"
            )

        # An empty string carries no discriminator; store it as None.
        cls.union_discriminator = resolved or None

        if cls.union_discriminator is not None:
            # NOTE(review): this rewrites the annotation on whichever FieldInfo
            # `cls.model_fields["root"]` resolves to at this point of class
            # construction -- confirm against pydantic's construction order
            # that it is not a FieldInfo shared with the parent class.
            root_field = cls.model_fields["root"]
            root_field.annotation = Annotated[
                root_field.annotation, Field(discriminator=cls.union_discriminator)
            ]

        super().__init_subclass__(**kwargs)

    def __class_getitem__(cls, params: Any) -> type[RootModel[T]]:
        """
        Handles generic type instantiation for union types.

        This method is called when the class is subscripted
        (e.g., `UnionRootModel[Foo | Bar]`). It ensures that exactly one type
        parameter is provided, delegates to the parent `__class_getitem__`,
        and stores the variants computed by `get_type_variants` on the
        resulting class as `union_variants`.

        Args:
            params (Any): The type parameter(s) provided when the class is
                subscripted; either a single type or a tuple of types.

        Returns:
            type[RootModel[T]]: The parameterized class, with `union_variants`
            set from the provided type parameter.

        Raises:
            TypeError: If the number of type parameters is not exactly one.
        """
        # Subscription passes a bare object for one argument and a tuple for
        # several; normalize to a tuple so the arity check is uniform.
        if not isinstance(params, tuple):
            params = (params,)
        if len(params) != 1:
            raise TypeError(f"{cls.__name__} only supports a single type parameter")

        new_cls = super().__class_getitem__(params)
        new_cls.union_variants = get_type_variants(params[0])
        return new_cls
if __name__ == "__main__":
    # Small self-check: build a discriminated union of two models and make
    # sure validation picks the variant named by the "kind" field.
    from typing import Literal

    from pydantic import BaseModel

    # First union variant, tagged with kind == "foo".
    class Foo(BaseModel):
        kind: Literal["foo"] = "foo"

    # Second union variant, tagged with kind == "bar".
    class Bar(BaseModel):
        kind: Literal["bar"] = "bar"

    # Root model over Foo | Bar, discriminated on the shared "kind" field.
    class FooOrBar(UnionRootModel[Foo | Bar]):
        union_discriminator = "kind"

    parsed = FooOrBar.model_validate({"kind": "foo"})
    expected = Foo(kind="foo")
    assert parsed.root == expected

    print(FooOrBar.union_discriminator, FooOrBar.union_variants)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment