Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save gutzbenj/f8be9fa1dd254b0c4ada582ffe0ed3d3 to your computer and use it in GitHub Desktop.
Save gutzbenj/f8be9fa1dd254b0c4ada582ffe0ed3d3 to your computer and use it in GitHub Desktop.
A Pydantic representation of Google's Vertex AI Search for Retail product data model (collections are not included)
"""
Below you find a Pydantic model that closely represents Google's product data model for Vertex AI Search for Retail.
Google's model is described in detail
- in the REST API reference: https://cloud.google.com/retail/docs/reference/rest/v2/projects.locations.catalogs.branches.products
- as a proto message in code: https://github.com/googleapis/google-cloud-python/blob/main/packages/google-cloud-retail/google/cloud/retail_v2/types/product.py
So far the model only supports the product types "variant" and "primary", where a primary product groups its variants (e.g. one productId with many skuIds).
Collections are currently not supported!
After validation, product data has to be dumped with `by_alias=True`:
product_data.model_dump_json(by_alias=True)
To import data into Vertex AI Search for Retail, provide it as newline-delimited JSON (one product per line):
"\n".join(product.model_dump_json(by_alias=True) for product in product_data)
Missing pieces:
- The types of the currency and language codes are not fully correct; they should be restricted to literals of actually existing codes.
- expire_time and ttl need to be fixed: only one of them should be set, and ttl should be a duration string.
"""
import datetime as dt
from typing import Literal, Self

from pydantic import (
    BaseModel,
    Field,
    RootModel,
    ValidationInfo,
    constr,
    field_validator,
    model_validator,
)
class PriceInterval(BaseModel):
    """Closed price interval ``[minimum, maximum]``."""

    minimum: float
    maximum: float

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Reject intervals whose lower bound lies above the upper bound."""
        inverted = self.maximum < self.minimum
        if not inverted:
            return self
        raise ValueError(f"minimum ({self.minimum}) must be less or equal to maximum ({self.maximum})")
class PriceIntervalExclusiveMinimum(BaseModel):
    """Half-open price interval ``(exclusive_minimum, maximum]``."""

    exclusive_minimum: float
    maximum: float

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Reject empty intervals, i.e. require ``exclusive_minimum < maximum``."""
        empty = self.maximum <= self.exclusive_minimum
        if not empty:
            return self
        raise ValueError(f"exclusive_minimum ({self.exclusive_minimum}) must be less than maximum ({self.maximum})")
class PriceIntervalExclusiveMaximum(BaseModel):
    """Half-open price interval ``[minimum, exclusive_maximum)``."""

    minimum: float
    exclusive_maximum: float

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Reject empty intervals, i.e. require ``minimum < exclusive_maximum``."""
        empty = self.exclusive_maximum <= self.minimum
        if not empty:
            return self
        raise ValueError(f"minimum ({self.minimum}) must be less than exclusive_maximum ({self.exclusive_maximum})")
class PriceIntervalExclusive(BaseModel):
    """Open price interval ``(exclusive_minimum, exclusive_maximum)``."""

    exclusive_minimum: float
    exclusive_maximum: float

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Reject empty intervals, i.e. require ``exclusive_minimum < exclusive_maximum``."""
        empty = self.exclusive_maximum <= self.exclusive_minimum
        if not empty:
            return self
        message = (
            f"exclusive_minimum ({self.exclusive_minimum}) must be less than "
            f"exclusive_maximum ({self.exclusive_maximum})"
        )
        raise ValueError(message)
# Any of the four bound combinations (inclusive/exclusive minimum x maximum).
# Order is kept as-is because pydantic tries union members in this order.
_AnyPriceInterval = PriceInterval | PriceIntervalExclusiveMinimum | PriceIntervalExclusiveMaximum | PriceIntervalExclusive


class PriceRange(BaseModel):
    """Ranges spanned by ``price`` and ``original_price``."""

    price: _AnyPriceInterval
    original_price: _AnyPriceInterval
class PriceInfo(BaseModel):
    """Pricing information of a product.

    ``original_price`` falls back to ``price`` when it is omitted or given as
    ``0.0``. Otherwise it must be greater than or equal to ``price``.
    """

    currency_code: str
    price: float
    # 0.0 acts as "not set". ``validate_default=True`` makes the validator below run
    # for the default as well, so an omitted original_price becomes ``price``.
    original_price: float = Field(default=0.0, validate_default=True)
    cost: float
    price_effective_time: dt.datetime | None = None
    price_expire_time: dt.datetime | None = None
    price_range: PriceRange

    @field_validator("original_price", mode="after")
    @classmethod
    def validate_original_price(cls, v, info: ValidationInfo):
        """Default ``original_price`` to ``price`` and forbid it being lower.

        Raises:
            ValueError: if ``original_price`` is below ``price``.
        """
        price = info.data.get("price")
        if price is None:
            # ``price`` failed its own validation and is absent from info.data;
            # that error is already reported, so don't mask it with a KeyError.
            return v
        if v == 0.0:
            return price
        if v < price:
            raise ValueError("original_price must be greater than or equal to price")
        return v
class Rating(BaseModel):
    """Aggregated customer rating of a product.

    Constraints follow Google's Retail ``Rating`` message: a nonnegative
    ``rating_count``, an ``average_rating`` on a 1-5 scale, and a
    ``rating_histogram`` that is either empty or holds exactly five counts.
    """

    rating_count: int = Field(ge=0)
    average_rating: float = Field(ge=1, le=5)
    # Counts per rating value (index = rating - 1), or empty when there are no ratings.
    rating_histogram: list[int]

    @field_validator("rating_histogram", mode="after")
    @classmethod
    def validate_rating_histogram(cls, v: list[int]) -> list[int]:
        """Require the histogram to be empty or to have exactly five entries.

        Raises:
            ValueError: for any other length.
        """
        if len(v) not in (0, 5):
            raise ValueError(f"rating_histogram must be empty or contain exactly 5 elements, got {len(v)}")
        return v
class FulfillmentInfo(BaseModel):
    """Fulfillment option of a product and the places offering it.

    ``type_`` has a trailing underscore only to avoid the ``type`` builtin. It is
    serialized as ``type`` when dumping with ``by_alias=True``, matching Google's
    field name and the ``type_`` handling of the product models.
    """

    type_: Literal[
        "pickup-in-store",
        "ship-to-store",
        "same-day-delivery",
        "next-day-delivery",
        "custom-type-1",
        "custom-type-2",
        "custom-type-3",
        "custom-type-4",
        "custom-type-5",
    ] = Field(serialization_alias="type")
    # IDs for the given type, e.g. store IDs for "pickup-in-store".
    place_ids: list[str]
class Image(BaseModel):
    """Product image reference with its pixel dimensions."""

    uri: str
    # Both dimensions are required and must be strictly positive.
    height: int = Field(..., gt=0)
    width: int = Field(..., gt=0)
_Gender = Literal["male", "female", "unisex"]
_AgeGroup = Literal["newborn", "infant", "toddler", "kids", "adult"]


class Audience(BaseModel):
    """Target audience of a product, as sets of gender and age-group labels."""

    genders: set[_Gender]
    age_groups: set[_AgeGroup]
_ColorFamily = Literal[
    "red",
    "pink",
    "orange",
    "yellow",
    "purple",
    "green",
    "cyan",
    "blue",
    "brown",
    "white",
    "gray",
    "black",
    "mixed",
]


class ColorInfo(BaseModel):
    """Color information of a product.

    ``color_families`` holds at most five of the standard color groups above.
    ``colors`` holds arbitrary color strings; no vocabulary is enforced here.
    """

    color_families: set[_ColorFamily] = Field(..., max_length=5)
    colors: set[str]
# A single non-empty text value of at most 256 characters.
_AttributeText = constr(min_length=1, max_length=256)


class CustomAttributeText(BaseModel):
    """Textual custom attribute: up to 400 distinct text values."""

    text: set[_AttributeText] = Field(..., max_length=400)
class CustomAttributeNumbers(BaseModel):
    """Numeric custom attribute: up to 400 distinct numeric values."""

    numbers: set[float] = Field(..., max_length=400)
class Promotion(BaseModel, frozen=True):
    """A promotion applied to a product, referenced by its ID.

    The model is frozen so that its instances are hashable: products store
    promotions in a ``set[Promotion]``. A non-frozen pydantic model defines
    ``__eq__`` without ``__hash__`` and therefore cannot be a set member.
    """

    promotion_id: str
# Integer values of Google's Product.Type enum, keyed by name.
_PRODUCT_TYPES = {
    "TYPE_UNSPECIFIED": 0,
    "PRIMARY": 1,
    "VARIANT": 2,
    "COLLECTION": 3,
}

# Integer values of Google's Product.Availability enum, keyed by name.
_AVAILABILITIES = {
    "AVAILABILITY_UNSPECIFIED": 0,
    "IN_STOCK": 1,
    "OUT_OF_STOCK": 2,
    "PREORDER": 3,
    "BACKORDER": 4,
}


def _enum_to_int(value, mapping: dict[str, int], field: str):
    """Normalize an enum given by name or integer value to its integer value.

    Strings are looked up in *mapping*. Integers are range-checked against
    ``0..max(mapping.values())``. Any other value is returned unchanged, so that
    pydantic's own type validation reports it.

    Raises:
        ValueError: for an unknown name or an out-of-range integer.
    """
    if isinstance(value, str):
        try:
            return mapping[value]
        except KeyError:
            raise ValueError(f"{field} must be one of {list(mapping)}, got {value!r}") from None
    if isinstance(value, int):
        upper = max(mapping.values())
        if value < 0 or value > upper:
            raise ValueError(f"{field} must be between 0 and {upper}")
    return value


class _Product(BaseModel):
    """Fields shared by primary and variant products.

    Closely follows Google's Vertex AI Search for Retail ``Product`` message.
    Field declaration order matters: ``expire_time`` is checked against
    ``publish_time`` and ``available_time`` through ``info.data``. That mapping
    only contains fields declared before it that validated successfully.
    """

    publish_time: dt.datetime
    available_time: dt.datetime
    expire_time: dt.datetime
    # NOTE(review): only one of expire_time / ttl should be set, and Google expects
    # ttl as a duration string; neither rule is enforced here yet.
    ttl: int | None = None
    id: constr(max_length=128)
    # Enum name or integer value; ``validate_type_`` normalizes names to ints.
    # The default is not passed through that validator, so it stays a string.
    type_: Literal["TYPE_UNSPECIFIED", "PRIMARY", "VARIANT", "COLLECTION"] | int = Field(
        default="TYPE_UNSPECIFIED", serialization_alias="type"
    )
    gtin: constr(max_length=128) | None = None
    title: constr(max_length=1_000)
    brands: set[constr(max_length=1_000)] = Field(max_length=30)
    description: constr(max_length=5_000) | None = None
    # Anchored with ^...$ so the entire key must match, not just part of it.
    attributes: (
        dict[
            constr(pattern=r"^[a-zA-Z0-9][a-zA-Z0-9_]*$", max_length=128),
            CustomAttributeText | CustomAttributeNumbers,
        ]
        | None
    ) = Field(None, max_length=200)
    price_info: PriceInfo
    rating: Rating
    # Enum name or integer value; ``validate_availability`` normalizes names to ints.
    availability: Literal["AVAILABILITY_UNSPECIFIED", "IN_STOCK", "OUT_OF_STOCK", "PREORDER", "BACKORDER"] | int
    available_quantity: int = Field(ge=0)
    # NOTE(review): Google's fulfillmentInfo is a repeated field; a single object is
    # accepted here. Confirm whether this should be a list.
    fulfillment_info: FulfillmentInfo | None = None
    uri: constr(max_length=5_000)
    images: list[Image] | None = Field(None, max_length=300)
    audience: Audience
    color_info: ColorInfo
    sizes: set[constr(max_length=128)] | None = Field(None, max_length=20)
    materials: set[constr(max_length=200)] | None = Field(None, max_length=20)
    patterns: set[constr(max_length=128)] | None = Field(None, max_length=20)
    conditions: set[Literal["new", "refurbished", "used"]] = Field(max_length=1)
    promotions: set[Promotion] | None = Field(None, max_length=10)

    @field_validator("expire_time", mode="after")
    @classmethod
    def validate_expire_time(cls, v, info: ValidationInfo):
        """Require ``expire_time`` not to precede ``available_time`` or ``publish_time``.

        A reference time that failed its own validation is absent from
        ``info.data`` and is skipped, so its original error is not masked.

        Raises:
            ValueError: if ``expire_time`` is earlier than a reference time, or if
                naive and timezone-aware datetimes are mixed.
        """
        for other in ("available_time", "publish_time"):
            reference = info.data.get(other)
            if reference is None:
                continue
            try:
                too_early = v < reference
            except TypeError as err:
                # Comparing naive with aware datetimes raises TypeError, which
                # pydantic would not turn into a ValidationError.
                raise ValueError(f"expire_time and {other} must both be timezone-aware or both naive") from err
            if too_early:
                raise ValueError(f"expire_time must be after {other}")
        return v

    @field_validator("type_", mode="before")
    @classmethod
    def validate_type_(cls, v):
        """Map a product type name to its integer value and range-check integers."""
        return _enum_to_int(v, _PRODUCT_TYPES, "type_")

    @field_validator("availability", mode="before")
    @classmethod
    def validate_availability(cls, v):
        """Map an availability name to its integer value and range-check integers."""
        return _enum_to_int(v, _AVAILABILITIES, "availability")
class ProductVariant(_Product):
    """A variant product (e.g. a single SKU) belonging to a primary product."""

    type_: Literal["VARIANT"] | int = Field(default="VARIANT", serialization_alias="type")
    # Must equal the ``id`` of the owning primary product.
    primary_product_id: constr(max_length=128)

    @field_validator("type_", mode="after")
    @classmethod
    def validate_variant_type(cls, v):
        """Pin the product type to VARIANT.

        The inherited before-validator turns type names into ints, so without this
        check the ``| int`` arm would accept any product type.

        Raises:
            ValueError: for any type other than "VARIANT" / 2.
        """
        if v not in ("VARIANT", 2):
            raise ValueError("type_ of a ProductVariant must be VARIANT (2)")
        return v
class ProductPrimary(_Product):
    """A primary product that groups its variants (one productId, many SKUs)."""

    type_: Literal["PRIMARY"] | int = Field(default="PRIMARY", serialization_alias="type")
    # Category paths whose levels are separated by " > ", e.g. "Shoes & Accessories > Shoes".
    # The hyphen is placed last in the character class so that it is a literal "-".
    # Written as "&-ä" it formed the range U+0026..U+00E4, which includes ">" and most
    # ASCII punctuation and so defeated the hierarchy check.
    categories: set[constr(pattern=r"^(?:[A-Za-z0-9&äöüÄÖÜß -]+)(?: > [A-Za-z0-9&äöüÄÖÜß -]+)*$")] = Field(
        max_length=250
    )
    language_code: str
    tags: set[constr(max_length=1_000)] | None = Field(None, max_length=250)
    variants: list[ProductVariant] = Field(max_length=2000)

    @field_validator("type_", mode="after")
    @classmethod
    def validate_primary_type(cls, v):
        """Pin the product type to PRIMARY.

        The inherited before-validator turns type names into ints, so without this
        check the ``| int`` arm would accept any product type.

        Raises:
            ValueError: for any type other than "PRIMARY" / 1.
        """
        if v not in ("PRIMARY", 1):
            raise ValueError("type_ of a ProductPrimary must be PRIMARY (1)")
        return v

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Require every variant to reference this product as its primary.

        Raises:
            ValueError: listing the ids of variants whose ``primary_product_id``
                differs from ``self.id``.
        """
        mismatched = [variant.id for variant in self.variants if variant.primary_product_id != self.id]
        if mismatched:
            raise ValueError(
                f"all variants must have primary_product_id equal to the primary product id {self.id!r}; "
                f"mismatched variant ids: {mismatched}"
            )
        return self
class ProductData(RootModel):
    """A list of primary products, validated as a single root model.

    Iteration and indexing delegate to the wrapped list. As a result,
    ``for product in product_data`` yields ``ProductPrimary`` instances rather
    than pydantic's default ``(field_name, value)`` pairs.
    """

    root: list[ProductPrimary]

    def __iter__(self):
        # Overrides BaseModel.__iter__, which would yield ("root", [...]).
        return iter(self.root)

    def __getitem__(self, item):
        return self.root[item]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment