Skip to content

Instantly share code, notes, and snippets.

@antonagestam
Last active October 5, 2024 11:40
Show Gist options
  • Save antonagestam/798a2dc69b6f7b0ea6dd393f02667b84 to your computer and use it in GitHub Desktop.
Save antonagestam/798a2dc69b6f7b0ea6dd393f02667b84 to your computer and use it in GitHub Desktop.
Pydantic partial/multi models
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
from pydantic.fields import ModelField
# Marker base class. MultiModel treats every field whose type derives from
# PartialModel as a sub-model that may be populated from flat input.
class PartialModel(BaseModel):
    pass
def _format_partial_model_field(field: ModelField) -> str:
    """Render *field* as ``"<field name>: <qualified name of the field type>"``.

    Used to point at the offending fields in multi-model definition errors.
    """
    type_name = field.type_.__qualname__
    return "{}: {}".format(field.name, type_name)
def _is_partial_model_type(type_: object) -> bool:
    """Return True when *type_* is a class deriving from ``PartialModel``.

    ``issubclass`` raises ``TypeError`` when its first argument is not a class,
    which happens for annotations such as ``Union[int, str]`` or
    ``Literal["x"]``. Such annotations are simply not partial models, so the
    error is converted to ``False`` instead of aborting class creation or
    validation with an unrelated-looking ``TypeError``.
    """
    try:
        return issubclass(type_, PartialModel)  # type: ignore[arg-type]
    except TypeError:
        return False


# Base class for models composed of several PartialModel-typed fields.
#
# Input may be given "flat": keys that are not fields of the multi-model
# itself are nested under every partial-model field that was not provided
# explicitly (see ``flatten_multi_model``). To keep that unambiguous,
# subclass definitions are rejected when a field name appears both on the
# multi-model and in one of its partial models, or in two partial models.
class MultiModel(BaseModel):
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate that flattening the subclass's partial models is unambiguous.

        Raises:
            TypeError: If a field name of ``cls`` is also a field name of one of
                its partial models, or if two of its partial models declare a
                field with the same name.
        """
        super().__init_subclass__(**kwargs)
        top_level_field_names = frozenset(cls.__fields__)
        # Maps each partial-model field name to the parent field that owns it.
        child_field_mapping: dict[str, ModelField] = {}

        for field_name, field in cls.__fields__.items():
            if not _is_partial_model_type(field.type_):
                continue

            child_fields = field.type_.__fields__
            parent_child_overlap = top_level_field_names.intersection(child_fields)
            if parent_child_overlap:
                # Same rendering as ``str(set(...))`` but with a stable order
                # when more than one name collides.
                overlap_repr = (
                    "{" + ", ".join(map(repr, sorted(parent_child_overlap))) + "}"
                )
                raise TypeError(
                    f"Illegal multi-model definition: fields of parent model "
                    f"{cls.__qualname__!r} has collisions in field {field_name!r} in "
                    f"the partial model {field.type_.__qualname__!r}. Overlapping "
                    f"field names: {overlap_repr}."
                )

            # Walk the child's fields in declaration order so that the conflict
            # reported first does not depend on set iteration order.
            for name in child_fields:
                child_child_conflict = child_field_mapping.get(name)
                if child_child_conflict is not None:
                    raise TypeError(
                        f"Illegal multi-model definition: field {name!r} has multiple "
                        f"definitions in child models of {cls.__qualname__!r} it is "
                        f"defined in both fields "
                        f"{_format_partial_model_field(child_child_conflict)!r} "
                        f"and {_format_partial_model_field(field)!r}."
                    )
                child_field_mapping[name] = field

    @root_validator(pre=True)
    @classmethod
    def flatten_multi_model(cls, values: Mapping[str, object]) -> dict[str, object]:
        """Nest flat input keys under each partial-model field that is missing.

        Args:
            values: The raw input mapping passed to the model.

        Returns:
            A new dict with the same entries as ``values`` plus, for every
            partial-model field not present in ``values``, a dict holding all
            entries of ``values`` whose key is not a field name of ``cls``.
        """
        top_level_field_names = frozenset(cls.__fields__)
        values = dict(values)

        # The loop below only ever adds keys that are top-level field names, so
        # the set of "flat" entries is the same for every field: compute it once.
        # NOTE(review): every partial model receives all flat keys, including
        # those meant for sibling partial models; presumably this relies on the
        # partial models ignoring unknown keys - confirm against their config.
        flat_values = {
            key: value
            for key, value in values.items()
            if key not in top_level_field_names
        }

        for field_name, field in cls.__fields__.items():
            # Don't overwrite if the submodel is provided as nested.
            if field_name in values:
                continue
            if not _is_partial_model_type(field.type_):
                continue
            # Each field gets its own copy so the nested dicts are independent.
            values[field_name] = dict(flat_values)
        return values
# Example: two partial models with disjoint field names, combined in one
# multi-model.
class A(PartialModel):
    foo: int


class B(PartialModel):
    bar: str


class C(MultiModel):
    a: A
    b: B


# Flat input is accepted: "foo" and "bar" are not fields of C, so they are
# nested under "a" and "b" before validation.
obj = C.parse_obj(dict(foo=123, bar="barf"))
assert obj == C(a=A(foo=123), b=B(bar="barf"))
# Parent-child overlap must be rejected: "coll" is declared both on the
# multi-model R itself and on its partial model P.
class P(PartialModel):
    coll: int


# Q declares the same field name as P; it is used by the child-child overlap
# check that follows.
class Q(PartialModel):
    coll: int


try:

    class R(MultiModel):
        coll: int
        p: P

except TypeError as error:
    # Defining R must fail while the class is being created, with this message.
    assert str(error) == (
        "Illegal multi-model definition: fields of parent model 'R' has collisions in "
        "field 'p' in the partial model 'P'. Overlapping field names: {'coll'}."
    )
else:
    assert False, "did not raise"
# Child-child overlap must be rejected: both partial models P and Q declare a
# field named "coll", so flat input for S would be ambiguous.
try:

    class S(MultiModel):
        p: P
        q: Q

except TypeError as error:
    # Defining S must fail while the class is being created, with this message.
    assert str(error) == (
        "Illegal multi-model definition: field 'coll' has multiple definitions in "
        "child models of 'S' it is defined in both fields 'p: P' and 'q: Q'."
    )
else:
    assert False, "did not raise"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment