Skip to content

Instantly share code, notes, and snippets.

@tdiggelm
Created September 30, 2024 21:42
Show Gist options
  • Save tdiggelm/02c33079ac58a289c4c3f3e935c1120c to your computer and use it in GitHub Desktop.
Save tdiggelm/02c33079ac58a289c4c3f3e935c1120c to your computer and use it in GitHub Desktop.
Pydantic union root model
from types import UnionType
from typing import (
    Annotated,
    Any,
    ClassVar,
    Final,
    TypeVar,
    Union,
    get_args,
    get_origin,
)

from pydantic import Discriminator, Field, RootModel
# Only the root-model class is exported by `from ... import *`;
# get_type_variants is deliberately left out of the public API list.
__all__ = ["UnionRootModel"]

# Type variable for the root value of UnionRootModel (intended to be
# parameterized with a union of model types).
T = TypeVar("T")
def get_type_variants(obj: Any) -> tuple[type[Any], ...]:
"""
Extracts type variants from the first Union encountered in a type annotation.
This function traverses annotated type expressions and returns the variants
of the first `Union` encountered. If the type is not a `Union`, it recursively
processes the first argument of the annotation (if any), and returns the type
itself if no `Union` is found.
Args:
obj (Any): The type annotation or object to extract variants from.
Returns:
tuple[type[Any], ...]: A tuple containing the types in the `Union`, or
a tuple with the original type if no `Union` is found.
Example:
For a type `Annotated[Final[Union[A, B]]]`, the function will return `(A, B)`.
"""
if get_origin(obj) is Union or isinstance(obj, UnionType):
return get_args(obj)
args = get_args(obj)
return get_type_variants(args[0]) if args else (obj,)
class UnionRootModel(RootModel[T]):
    """
    A specialized root model that handles union types with discriminators.

    This class extends Pydantic's `RootModel` to support union types by extracting
    and validating types based on a discriminator field. It allows models to be
    parameterized with union types, where the discriminator determines which
    variant of the union should be instantiated during validation.

    The discriminator can be supplied either as a class attribute::

        class FooOrBar(UnionRootModel[Foo | Bar]):
            union_discriminator = "kind"

    or as a class keyword (which takes precedence over the attribute)::

        class FooOrBar(UnionRootModel[Foo | Bar], discriminator="kind"): ...

    Attributes:
        union_discriminator (ClassVar[str | Discriminator | None]):
            The name of the discriminator field (or a Pydantic `Discriminator`).
            It is used to determine which union variant to instantiate based
            on the input data.
        union_variants (ClassVar[tuple[type[Any], ...]]):
            A tuple containing the types of the union variants extracted from
            the type parameter.

    Methods:
        __init_subclass__(cls, discriminator, **kwargs):
            Resolves and validates the discriminator for the new subclass.
        __pydantic_init_subclass__(cls, **kwargs):
            Applies the discriminator to the subclass's own `root` field after
            Pydantic has finished creating the subclass, then rebuilds it.
        __class_getitem__(cls, params):
            Handles generic type parameters and calculates the union variants
            based on the provided type parameter.
    """

    union_discriminator: ClassVar[str | Discriminator | None] = None
    union_variants: ClassVar[tuple[type[Any], ...]] = ()

    def __init_subclass__(
        cls, discriminator: str | Discriminator | None = None, **kwargs: Any
    ) -> None:
        """
        Resolves and validates the discriminator of a new subclass.

        An explicit `discriminator` class keyword overrides the
        `union_discriminator` class attribute (whether declared on the subclass
        or inherited). When the keyword is omitted, the attribute is used.

        The `root` field is deliberately NOT touched here: Pydantic invokes
        `__init_subclass__` from `type.__new__`, before the subclass's own
        fields exist, so `cls.model_fields` would still be the parent's and any
        mutation would leak into the parent. See `__pydantic_init_subclass__`.

        Args:
            discriminator (str | Discriminator | None):
                The discriminator for this subclass, or None to use the class
                attribute `union_discriminator`.
            **kwargs (Any): Additional keyword arguments passed to the superclass.

        Raises:
            TypeError: If the resolved discriminator is neither a string nor
                a `Discriminator`.
        """
        if discriminator is not None:
            cls.union_discriminator = discriminator
        if cls.union_discriminator is not None and not isinstance(
            cls.union_discriminator, (str, Discriminator)
        ):
            raise TypeError(
                f"Expected 'str' or 'Discriminator' for 'union_discriminator', "
                f"got {type(cls.union_discriminator).__name__}"
            )
        super().__init_subclass__(**kwargs)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        """
        Applies `union_discriminator` to this subclass's own `root` field.

        Pydantic calls this hook once the subclass is fully created, so
        `cls.model_fields["root"]` is the subclass's own FieldInfo rather than
        the parent's. The discriminator is assigned (idempotently, so repeated
        or nested subclassing does not stack it), and the model is rebuilt
        so the new discriminator takes effect.

        Args:
            **kwargs (Any): Class keywords. `discriminator` has already been
                consumed by `__init_subclass__` and is not forwarded.
        """
        kwargs.pop("discriminator", None)
        discriminator = cls.union_discriminator
        root_field = cls.model_fields["root"]
        if discriminator is not None and root_field.discriminator != discriminator:
            root_field.discriminator = discriminator
            if cls.__pydantic_complete__:
                # The schema was built before this hook ran; regenerate it.
                cls.model_rebuild(force=True)
            # NOTE(review): when Pydantic deferred the build (forward refs or
            # `defer_build`), the updated FieldInfo is presumably used when the
            # model is completed later -- confirm for the Pydantic version used.
        super().__pydantic_init_subclass__(**kwargs)

    def __class_getitem__(cls, params: Any) -> type[RootModel[T]]:
        """
        Handles generic type instantiation for union types.

        Called when the class is parameterized (e.g. `UnionRootModel[Foo | Bar]`).
        It ensures that exactly one type parameter is provided, delegates to
        Pydantic to create the parameterized class, and stores the variants
        extracted by `get_type_variants` on it as `union_variants`.

        Args:
            params (Any): The type parameter(s) given in the subscription.

        Returns:
            type[RootModel[T]]: The parameterized class, with `union_variants` set.

        Raises:
            TypeError: If not exactly one type parameter is provided.
        """
        params = params if isinstance(params, tuple) else (params,)
        if len(params) != 1:
            raise TypeError(f"{cls.__name__} only supports a single type parameter")
        new_cls = super().__class_getitem__(params)
        new_cls.union_variants = get_type_variants(params[0])
        return new_cls
if __name__ == "__main__":
    # Small smoke test / usage example; only runs when executed as a script.
    from typing import Literal
    from pydantic import BaseModel

    class Foo(BaseModel):
        kind: Literal["foo"] = "foo"

    class Bar(BaseModel):
        kind: Literal["bar"] = "bar"

    class FooOrBar(UnionRootModel[Foo | Bar]):
        # Declared via class attribute rather than the `discriminator=` keyword.
        union_discriminator = "kind"

    parsed = FooOrBar.model_validate({"kind": "foo"})
    expected = Foo(kind="foo")
    assert parsed.root == expected

    discriminator, variants = FooOrBar.union_discriminator, FooOrBar.union_variants
    print(discriminator, variants)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment