Skip to content

Instantly share code, notes, and snippets.

@cbonesana
Created February 28, 2025 13:34
Show Gist options
  • Save cbonesana/c09a72a23603699f15a30c90708aced6 to your computer and use it in GitHub Desktop.
Save cbonesana/c09a72a23603699f15a30c90708aced6 to your computer and use it in GitHub Desktop.
Serialize (and deserialize) pydantic models that contain numpy arrays.
import json

import numpy as np
from numpy.typing import ArrayLike, NDArray
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer
from typing_extensions import Annotated
# define serializers and validator
def nd_array_custom_before_validator(x: ArrayLike) -> NDArray[np.float32]:
    """Coerce raw field input into a ``float32`` numpy array.

    Used as a pydantic ``BeforeValidator``, so ``x`` is whatever the caller
    supplied. Examples are a list of numbers (e.g. after a JSON round-trip)
    or an existing ``np.ndarray`` of any numeric dtype. That is why the
    parameter is typed as ``ArrayLike`` rather than ``list[float]``.

    Args:
        x: Array-like input to convert.

    Returns:
        A new ``float32`` array. ``np.array`` copies by default, so the
        returned array never shares memory with the caller's array.

    Raises:
        ValueError: from numpy if an element cannot be converted to float.
    """
    return np.array(x, dtype=np.float32)
def nd_array_to_list(x: NDArray[np.float32]) -> list[float]:
    """Turn a numpy array into plain Python values for JSON output.

    Used as a pydantic ``PlainSerializer``. ``ndarray.tolist`` converts every
    element to the closest built-in Python scalar (``float`` for float32 data).
    """
    as_python = x.tolist()  # nested lists for n-d input, a bare scalar for 0-d
    return as_python
# define type wrapper
# Annotated wrapper around a bare np.ndarray.
# On input, pydantic runs the before-validator (conversion to float32).
# On dump, it replaces normal serialization with the plain serializer
# (conversion to Python lists).
# np.ndarray is not a pydantic-native type, so models using this alias need
# arbitrary_types_allowed=True.
# NOTE(review): return_type=list[float] presumes 1-D arrays; n-d arrays
# serialize to nested lists that do not match it — confirm this is intended.
NdArray = Annotated[
    np.ndarray,
    BeforeValidator(func=nd_array_custom_before_validator),
    PlainSerializer(func=nd_array_to_list, return_type=list[float]),
]
# the class that will use our annotated type
class Datastore(BaseModel):
    """Example model holding a single numpy array field ``x``.

    Values assigned to ``x`` go through the ``NdArray`` annotation.
    They are converted to ``float32`` arrays on input and dumped as lists.
    """

    # np.ndarray is not a type pydantic can validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Pydantic does not run validators on default values, so the default must
    # already carry the float32 dtype that every validated value gets.
    # A plain np.zeros((1,)) would be float64 and make Datastore().x
    # inconsistent with instances built from input.
    x: NdArray = np.zeros((1,), dtype=np.float32)
if __name__ == "__main__":
    # Validation: the integer input array is converted to float32 on the way in.
    store = Datastore(x=np.array([1, 2, 3]))
    print(store)
    # OUT:
    # x=array([1., 2., 3.], dtype=float32)

    # Serialization: the array field is emitted as a JSON list of floats.
    print(store.model_dump_json(indent=4))
    # OUT:
    # {
    #     "x": [
    #         1.0,
    #         2.0,
    #         3.0
    #     ]
    # }

    # Round trip: parsing the compact JSON back yields a float32 array again.
    compact = store.model_dump_json()
    restored = Datastore(**json.loads(compact))
    print(restored)
    # OUT:
    # x=array([1., 2., 3.], dtype=float32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment