Last active
November 26, 2024 11:17
-
-
Save gutzbenj/f8be9fa1dd254b0c4ada582ffe0ed3d3 to your computer and use it in GitHub Desktop.
A pydantic representation of Google's Vertex AI Search for Retail product data model (collections are not included)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Below you find a pydantic model that closely represents Google's product data model for Vertex AI Search for Retail.
Google's model is described
- in the REST API reference: https://cloud.google.com/retail/docs/reference/rest/v2/projects.locations.catalogs.branches.products
- as a proto message in code: https://github.com/googleapis/google-cloud-python/blob/main/packages/google-cloud-retail/google/cloud/retail_v2/types/product.py
So far the model only supports the product types "variant" and "primary" - a primary product groups variants, e.g. one productId with many skuIds.
Collections are currently not supported!
After validation, product data has to be dumped `by_alias`:
    product_data.model_dump_json(by_alias=True)
To import data into Search for Retail, provide it as newline-delimited JSON (one product per line):
    "\n".join(product.model_dump_json(by_alias=True) for product in product_data.root)
Missing pieces:
- The types of currency and language code are not fully correct - you need to use actually existing literals there.
- expire_time and ttl need to be fixed: only one of them should be set, and ttl should be a duration string.
""" | |
import datetime as dt
from typing import Literal, Self

from pydantic import (
    BaseModel,
    Field,
    RootModel,
    ValidationInfo,
    constr,
    field_validator,
    model_validator,
)
class PriceInterval(BaseModel):
    """Closed price interval ``[minimum, maximum]``.

    Both bounds are inclusive, so a single-point interval (``minimum == maximum``)
    is accepted.
    """

    minimum: float  # inclusive lower bound
    maximum: float  # inclusive upper bound

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Reject the interval when its lower bound lies above its upper bound."""
        lower, upper = self.minimum, self.maximum
        if lower > upper:
            message = f"minimum ({lower}) must be less or equal to maximum ({upper})"
            raise ValueError(message)
        return self
class PriceIntervalExclusiveMinimum(BaseModel):
    """Half-open price interval ``(exclusive_minimum, maximum]``."""

    exclusive_minimum: float  # lower bound, excluded from the interval
    maximum: float  # upper bound, included in the interval

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Require the excluded lower bound to be strictly below the upper bound."""
        low, high = self.exclusive_minimum, self.maximum
        if low >= high:
            message = f"exclusive_minimum ({low}) must be less than maximum ({high})"
            raise ValueError(message)
        return self
class PriceIntervalExclusiveMaximum(BaseModel):
    """Half-open price interval ``[minimum, exclusive_maximum)``."""

    minimum: float  # lower bound, included in the interval
    exclusive_maximum: float  # upper bound, excluded from the interval

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Require the lower bound to be strictly below the excluded upper bound."""
        low, high = self.minimum, self.exclusive_maximum
        if low >= high:
            message = f"minimum ({low}) must be less than exclusive_maximum ({high})"
            raise ValueError(message)
        return self
class PriceIntervalExclusive(BaseModel):
    """Open price interval ``(exclusive_minimum, exclusive_maximum)``."""

    exclusive_minimum: float  # lower bound, excluded from the interval
    exclusive_maximum: float  # upper bound, excluded from the interval

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Require the excluded lower bound to be strictly below the excluded upper bound."""
        low, high = self.exclusive_minimum, self.exclusive_maximum
        if low >= high:
            message = (
                f"exclusive_minimum ({low}) must be less than "
                f"exclusive_maximum ({high})"
            )
            raise ValueError(message)
        return self
class PriceRange(BaseModel):
    """Range of current and original prices.

    Each bound pair may be given in any of the four interval shapes defined above
    (inclusive/exclusive minimum combined with inclusive/exclusive maximum).
    """

    price: (
        PriceInterval
        | PriceIntervalExclusiveMinimum
        | PriceIntervalExclusiveMaximum
        | PriceIntervalExclusive
    )
    original_price: (
        PriceInterval
        | PriceIntervalExclusiveMinimum
        | PriceIntervalExclusiveMaximum
        | PriceIntervalExclusive
    )
class PriceInfo(BaseModel):
    """Pricing details of a product.

    ``original_price`` defaults to ``price`` when it is omitted or given as ``0.0``,
    and it may never be lower than ``price``.
    """

    currency_code: str
    price: float
    # 0.0 is a "not provided" sentinel that ``validate_original_price`` replaces with
    # ``price``. ``validate_default=True`` is required for that: pydantic does not
    # run validators on default values otherwise, so an omitted original_price
    # would silently stay 0.0.
    original_price: float = Field(default=0.0, validate_default=True)
    cost: float
    price_effective_time: dt.datetime | None = None
    price_expire_time: dt.datetime | None = None
    price_range: PriceRange

    @field_validator("original_price", mode="after")
    @classmethod
    def validate_original_price(cls, v: float, info: ValidationInfo) -> float:
        """Default ``original_price`` to ``price`` and reject values below ``price``.

        Fields are validated in declaration order, so ``price`` is normally already
        in ``info.data``. If ``price`` failed its own validation it is absent; that
        error is reported by pydantic and this cross-field check is skipped instead
        of raising a raw KeyError.

        Raises:
            ValueError: if ``original_price`` is lower than ``price``.
        """
        price = info.data.get("price")
        if price is None:
            return v
        if v == 0.0:
            return price
        if v < price:
            raise ValueError("original_price must be greater than or equal to price")
        return v
class Rating(BaseModel):
    """Aggregated customer rating of a product."""

    # Total number of ratings; the Retail API requires it to be nonnegative.
    rating_count: int = Field(ge=0)
    # Average rating on a 1-5 scale.
    average_rating: float = Field(ge=1, le=5)
    # Rating counts per star value (index = stars - 1). The Retail API accepts an
    # empty list (no ratings) or exactly five entries.
    rating_histogram: list[int] = Field(max_length=5)

    @field_validator("rating_histogram", mode="after")
    @classmethod
    def validate_rating_histogram(cls, v: list[int]) -> list[int]:
        """Accept an empty histogram or one with exactly five counts.

        Raises:
            ValueError: if the histogram is non-empty but not of length 5.
        """
        if v and len(v) != 5:
            raise ValueError(
                f"rating_histogram must be empty or contain exactly 5 counts, got {len(v)}"
            )
        return v
class FulfillmentInfo(BaseModel):
    """A fulfillment option of a product and the places where it is offered."""

    # The API field is named ``type``. ``type_`` avoids the builtin, so the alias is
    # needed for ``model_dump(_json)(by_alias=True)`` to emit the expected key
    # (the same approach as ``_Product.type_``).
    type_: Literal[
        "pickup-in-store",
        "ship-to-store",
        "same-day-delivery",
        "next-day-delivery",
        "custom-type-1",
        "custom-type-2",
        "custom-type-3",
        "custom-type-4",
        "custom-type-5",
    ] = Field(serialization_alias="type")
    # IDs for this type, e.g. store IDs or region IDs. The Retail API allows at most
    # 3000 values, each at most 30 characters matching [a-zA-Z0-9_-]+.
    place_ids: list[constr(pattern=r"^[a-zA-Z0-9_-]+$", max_length=30)] = Field(
        max_length=3000
    )
class Image(BaseModel):
    """An image of a product, referenced by URI, with its pixel dimensions."""

    uri: str
    height: int = Field(..., gt=0)  # pixels, strictly positive
    width: int = Field(..., gt=0)  # pixels, strictly positive
class Audience(BaseModel):
    """Target audience of a product, as sets of standard gender and age-group values."""

    # Duplicates collapse because the values are stored in sets.
    genders: set[Literal["male", "female", "unisex"]] = Field(...)
    age_groups: set[Literal["newborn", "infant", "toddler", "kids", "adult"]] = Field(...)
class ColorInfo(BaseModel):
    """Color information of a product."""

    # Standard color families; at most 5 per product.
    color_families: set[
        Literal[
            "red",
            "pink",
            "orange",
            "yellow",
            "purple",
            "green",
            "cyan",
            "blue",
            "brown",
            "white",
            "gray",
            "black",
            "mixed",
        ]
    ] = Field(max_length=5)
    # Free-form color names. The Retail API allows at most 75 values of up to
    # 128 characters each, encoded the same way as the other limited fields.
    colors: set[constr(max_length=128)] = Field(max_length=75)
class CustomAttributeText(BaseModel):
    """Text-valued custom attribute.

    Holds up to 400 distinct strings, each 1-256 characters long.
    """

    text: set[constr(min_length=1, max_length=256)] = Field(..., max_length=400)
class CustomAttributeNumbers(BaseModel):
    """Number-valued custom attribute holding up to 400 distinct numbers."""

    numbers: set[float] = Field(..., max_length=400)
class Promotion(BaseModel, frozen=True):
    """A promotion the product takes part in.

    The model is frozen so that instances are hashable. ``_Product.promotions`` is
    a ``set[Promotion]``, and non-frozen pydantic models define ``__eq__`` without
    ``__hash__``, so they could not be stored in a set.
    """

    promotion_id: str
class _Product(BaseModel):
    """Fields and validation shared by all product types of the Retail product model.

    ``type_`` and ``availability`` accept either the enum name or its number; values
    that are passed in are normalised to the number. Dump with ``by_alias=True`` so
    that ``type_`` is written as ``type``.
    """

    publish_time: dt.datetime
    available_time: dt.datetime
    expire_time: dt.datetime
    # NOTE(review): expire_time and ttl are meant to be alternatives, and ttl should
    # be a duration string (see "Missing pieces" in the module docstring).
    ttl: int | None = None
    id: constr(max_length=128)
    type_: Literal["TYPE_UNSPECIFIED", "PRIMARY", "VARIANT", "COLLECTION"] | int = Field(
        default="TYPE_UNSPECIFIED", serialization_alias="type"
    )
    gtin: constr(max_length=128) | None = None
    title: constr(max_length=1_000)
    brands: set[constr(max_length=1_000)] = Field(max_length=30)
    # Optional: may be omitted or null.
    description: constr(max_length=5_000) | None = None
    # pydantic v2 applies ``pattern`` as a search rather than a full match, so the
    # anchors are needed for the key to match [a-zA-Z0-9][a-zA-Z0-9_]* as a whole.
    attributes: (
        dict[
            constr(pattern=r"^[a-zA-Z0-9][a-zA-Z0-9_]*$", max_length=128),
            CustomAttributeText | CustomAttributeNumbers,
        ]
        | None
    ) = Field(None, max_length=200)
    price_info: PriceInfo
    rating: Rating
    availability: Literal[
        "AVAILABILITY_UNSPECIFIED", "IN_STOCK", "OUT_OF_STOCK", "PREORDER", "BACKORDER"
    ] | int
    available_quantity: int = Field(ge=0)
    # Repeated in the Retail API. A single FulfillmentInfo (or dict) is still
    # accepted and wrapped into a one-element list by ``wrap_fulfillment_info``.
    fulfillment_info: list[FulfillmentInfo] | None = None
    uri: constr(max_length=5_000)
    images: list[Image] | None = Field(None, max_length=300)
    audience: Audience
    color_info: ColorInfo
    sizes: set[constr(max_length=128)] | None = Field(None, max_length=20)
    materials: set[constr(max_length=200)] | None = Field(None, max_length=20)
    patterns: set[constr(max_length=128)] | None = Field(None, max_length=20)
    conditions: set[Literal["new", "refurbished", "used"]] = Field(max_length=1)
    promotions: set[Promotion] | None = Field(None, max_length=10)

    @field_validator("expire_time", mode="after")
    @classmethod
    def validate_expire_time(cls, v: dt.datetime, info: ValidationInfo) -> dt.datetime:
        """Require ``expire_time`` not to precede ``available_time`` or ``publish_time``.

        A field that failed its own validation is missing from ``info.data``; that
        comparison is skipped (pydantic reports the original error) instead of
        raising a raw KeyError.

        Raises:
            ValueError: if ``expire_time`` lies before either timestamp.
        """
        available_time = info.data.get("available_time")
        if available_time is not None and v < available_time:
            raise ValueError("expire_time must be after available_time")
        publish_time = info.data.get("publish_time")
        if publish_time is not None and v < publish_time:
            raise ValueError("expire_time must be after publish_time")
        return v

    @field_validator("type_", mode="before")
    @classmethod
    def validate_type_(cls, v):
        """Map a product type name to its number and range-check numbers.

        Raises:
            ValueError: for an unknown name or a number outside 0..3. A ValueError,
                unlike the KeyError a plain dict lookup raises, is turned into a
                regular ValidationError by pydantic.
        """
        type_map = {
            "TYPE_UNSPECIFIED": 0,
            "PRIMARY": 1,
            "VARIANT": 2,
            "COLLECTION": 3,
        }
        if isinstance(v, str):
            if v not in type_map:
                raise ValueError(f"type_ must be one of {list(type_map)}, got {v!r}")
            return type_map[v]
        if isinstance(v, int):
            if v < 0 or v > 3:
                raise ValueError("type_ must be between 0 and 3")
        return v

    @field_validator("availability", mode="before")
    @classmethod
    def validate_availability(cls, v):
        """Map an availability name to its number and range-check numbers.

        Raises:
            ValueError: for an unknown name or a number outside 0..4.
        """
        availability_map = {
            "AVAILABILITY_UNSPECIFIED": 0,
            "IN_STOCK": 1,
            "OUT_OF_STOCK": 2,
            "PREORDER": 3,
            "BACKORDER": 4,
        }
        if isinstance(v, str):
            if v not in availability_map:
                raise ValueError(
                    f"availability must be one of {list(availability_map)}, got {v!r}"
                )
            return availability_map[v]
        if isinstance(v, int):
            if v < 0 or v > 4:
                raise ValueError("availability must be between 0 and 4")
        return v

    @field_validator("fulfillment_info", mode="before")
    @classmethod
    def wrap_fulfillment_info(cls, v):
        """Wrap a single fulfillment entry into a list; pass lists and None through."""
        if isinstance(v, (dict, FulfillmentInfo)):
            return [v]
        return v
class ProductVariant(_Product):
    """A ``VARIANT`` product: one item (e.g. one SKU) belonging to a primary product."""

    type_: Literal["VARIANT"] | int = Field(default="VARIANT", serialization_alias="type")
    # ID of the primary product this variant belongs to.
    primary_product_id: constr(max_length=128)

    @field_validator("type_", mode="after")
    @classmethod
    def validate_variant_type(cls, v):
        """Only allow the VARIANT type (name or number 2).

        The inherited ``validate_type_`` converts names to numbers before the
        annotation is checked, so without this check e.g. ``"PRIMARY"`` would
        pass through the ``int`` branch as ``1``.

        Raises:
            ValueError: if the type is anything other than VARIANT.
        """
        if v not in ("VARIANT", 2):
            raise ValueError(f"type_ of a ProductVariant must be 'VARIANT' (2), got {v!r}")
        return v
class ProductPrimary(_Product):
    """A ``PRIMARY`` product grouping its variants (e.g. one product ID with many SKU IDs)."""

    type_: Literal["PRIMARY"] | int = Field(default="PRIMARY", serialization_alias="type")
    # Category paths such as "Shoes > Sneakers", each at most 5,000 characters.
    # The hyphen is placed first in the character class so that it is literal.
    # Written as "&-ä", it formed the range U+0026..U+00E4, which also admitted ">"
    # and most ASCII punctuation and defeated the " > " separator structure.
    categories: set[
        constr(
            pattern=r"^(?:[-A-Za-z0-9&äöüÄÖÜß ]+)(?: > [-A-Za-z0-9&äöüÄÖÜß ]+)*$",
            max_length=5_000,
        )
    ] = Field(max_length=250)
    language_code: str
    tags: set[constr(max_length=1_000)] | None = Field(None, max_length=250)
    variants: list[ProductVariant] = Field(max_length=2000)

    @field_validator("type_", mode="after")
    @classmethod
    def validate_primary_type(cls, v):
        """Only allow the PRIMARY type (name or number 1).

        The inherited ``validate_type_`` converts names to numbers before the
        annotation is checked, so without this check e.g. ``"VARIANT"`` would
        pass through the ``int`` branch as ``2``.

        Raises:
            ValueError: if the type is anything other than PRIMARY.
        """
        if v not in ("PRIMARY", 1):
            raise ValueError(f"type_ of a ProductPrimary must be 'PRIMARY' (1), got {v!r}")
        return v

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Require every variant's ``primary_product_id`` to equal this product's ``id``.

        Raises:
            ValueError: listing the IDs of the variants that point elsewhere.
        """
        mismatched = [
            variant.id for variant in self.variants if variant.primary_product_id != self.id
        ]
        if mismatched:
            raise ValueError(
                f"all variants must have primary_product_id equal to {self.id!r}; "
                f"mismatching variant ids: {mismatched}"
            )
        return self
class ProductData(RootModel):
    """Root model wrapping the list of primary products, each carrying its variants.

    The validated products are available as ``.root``.
    """

    root: list[ProductPrimary]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment