Last active
October 5, 2024 11:40
-
-
Save antonagestam/798a2dc69b6f7b0ea6dd393f02667b84 to your computer and use it in GitHub Desktop.
Pydantic partial/multi models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Mapping | |
from typing import Any | |
from pydantic import BaseModel | |
from pydantic import root_validator | |
from pydantic.fields import ModelField | |
class PartialModel(BaseModel):
    # Marker base class. MultiModel treats any field whose type subclasses
    # this as a "partial" model that may be filled from the flat, top-level
    # input keys rather than from a nested mapping.
    # It intentionally declares no fields or configuration of its own.
    pass
def _format_partial_model_field(field: ModelField) -> str:
    """Describe a partial-model field as ``"<field name>: <type qualname>"``.

    Used to name the two competing fields in MultiModel's child-child
    collision error message.
    """
    name, type_name = field.name, field.type_.__qualname__
    return "{}: {}".format(name, type_name)
def _is_partial_model_field(field: ModelField) -> bool:
    """Return whether *field* is declared with a ``PartialModel`` subclass as its type.

    ``field.type_`` is not always a class: pydantic can leave a typing construct
    there (e.g. for ``Union[int, str]``, ``Literal[...]``, or ``Any`` before
    Python 3.11), and ``issubclass`` raises ``TypeError`` for such arguments.
    Those fields are simply not partial-model fields.
    """
    try:
        return issubclass(field.type_, PartialModel)
    except TypeError:
        return False


class MultiModel(BaseModel):
    # A model whose PartialModel-typed fields may be populated from one flat
    # mapping: input keys that are not top-level field names are handed to
    # every partial-model field that was not supplied explicitly.

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Reject multi-model subclasses whose flattened field names are ambiguous.

        Raises:
            TypeError: if a partial-model field declares a field name that is
                also a top-level field of ``cls``, or if two partial-model
                fields declare the same field name.
        """
        super().__init_subclass__(**kwargs)
        top_level_field_names = frozenset(cls.__fields__.keys())
        # Maps each child field name to the partial-model field that owns it.
        child_field_mapping: dict[str, ModelField] = {}
        for field_name, field in cls.__fields__.items():
            if not _is_partial_model_field(field):
                continue
            child_field_names = frozenset(field.type_.__fields__)
            parent_child_overlap = top_level_field_names & child_field_names
            if parent_child_overlap:
                raise TypeError(
                    f"Illegal multi-model definition: fields of parent model "
                    f"{cls.__qualname__!r} has collisions in field {field_name!r} in "
                    f"the partial model {field.type_.__qualname__!r}. Overlapping "
                    f"field names: {set(parent_child_overlap)}."
                )
            # Iterate in declaration order (not the frozenset) so that, when
            # several names collide, the reported one is deterministic.
            for name in field.type_.__fields__:
                child_child_conflict = child_field_mapping.get(name)
                if child_child_conflict is not None:
                    raise TypeError(
                        f"Illegal multi-model definition: field {name!r} has multiple "
                        f"definitions in child models of {cls.__qualname__!r} it is "
                        f"defined in both fields "
                        f"{_format_partial_model_field(child_child_conflict)!r} "
                        f"and {_format_partial_model_field(field)!r}."
                    )
                child_field_mapping[name] = field

    @root_validator(pre=True)
    @classmethod
    def flatten_multi_model(cls, values: Mapping[str, object]) -> dict[str, object]:
        """Route flat input keys into the nested partial-model fields.

        Every key of *values* that is not a top-level field name of ``cls`` is
        collected into a dict, and a fresh copy of that dict is assigned to
        each partial-model field missing from *values*. A partial model that
        is already present (given in nested form) is left untouched, and the
        flat keys themselves are kept in the returned dict.

        Args:
            values: Raw input mapping; it is not mutated.

        Returns:
            A new dict with the partial-model entries filled in.
        """
        top_level_field_names = frozenset(cls.__fields__.keys())
        values = dict(values)
        # Loop-invariant: assigning partial-model entries below only adds
        # top-level keys, which this filter excludes anyway.
        flat_values = {
            key: value
            for key, value in values.items()
            if key not in top_level_field_names
        }
        for field_name, field in cls.__fields__.items():
            # Don't overwrite if the submodel is provided as nested.
            if field_name in values:
                continue
            if not _is_partial_model_field(field):
                continue
            # Each partial model gets its own copy of the flat keys.
            values[field_name] = dict(flat_values)
        return values
# Example: two partial models combined into a single multi-model.
class A(PartialModel):
    foo: int


class B(PartialModel):
    bar: str


class C(MultiModel):
    a: A
    b: B


# A flat payload is split across the nested partial models by
# MultiModel.flatten_multi_model.
obj = C.parse_obj(dict(foo=123, bar="barf"))
assert obj == C(a=A(foo=123), b=B(bar="barf"))
# test raises for parent-child overlap
class P(PartialModel):
    coll: int


# Q declares the same field name as P; it backs the child-child overlap test.
class Q(PartialModel):
    coll: int


try:
    class R(MultiModel):
        coll: int  # collides with P.coll
        p: P
except TypeError as e:
    assert str(e) == (
        "Illegal multi-model definition: fields of parent model 'R' has collisions in "
        "field 'p' in the partial model 'P'. Overlapping field names: {'coll'}."
    )
else:
    # Raise explicitly: ``assert False`` is stripped under ``python -O``, which
    # would let a missing TypeError pass silently.
    raise AssertionError("did not raise")
# test raises for child-child overlap
try:
    class S(MultiModel):
        p: P  # P and Q both declare ``coll``
        q: Q
except TypeError as e:
    assert str(e) == (
        "Illegal multi-model definition: field 'coll' has multiple definitions in "
        "child models of 'S' it is defined in both fields 'p: P' and 'q: Q'."
    )
else:
    # Raise explicitly: ``assert False`` is stripped under ``python -O``, which
    # would let a missing TypeError pass silently.
    raise AssertionError("did not raise")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment