Created
February 28, 2025 13:34
-
-
Save cbonesana/c09a72a23603699f15a30c90708aced6 to your computer and use it in GitHub Desktop.
Perform serialization of pydantic objects that contain numpy arrays.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy.typing import NDArray | |
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer | |
from typing_extensions import Annotated | |
import json | |
import numpy as np | |
# Define the validator and serializer for numpy array fields.
def nd_array_custom_before_validator(x: list[float]) -> NDArray[np.float32]:
    """Coerce raw field input into a float32 numpy array.

    Intended for use as a pydantic ``BeforeValidator``, so it receives the raw
    value: a list (for example, parsed from JSON) or an existing array of any
    numeric dtype.

    Args:
        x: Array-like input. Despite the annotation, ``np.ndarray`` values are
            accepted as well.

    Returns:
        A newly allocated float32 array. The input is always copied.
    """
    as_float32 = np.array(x, dtype=np.float32, copy=True)
    return as_float32
def nd_array_to_list(x: NDArray[np.float32]) -> list[float]:
    """Convert a numpy array into plain Python values for serialization.

    Delegates to ``ndarray.tolist``. Elements become built-in Python scalars,
    and an array with more than one dimension becomes nested lists.
    """
    plain = x.tolist()
    return plain
# Define the annotated type wrapper.
# Annotated alias for a numpy-array model field.
# Input: the raw value goes through nd_array_custom_before_validator
#   (becomes a float32 ndarray).
# Output: the value is emitted via nd_array_to_list (plain, possibly nested,
#   Python lists).
NdArray = Annotated[
    np.ndarray,
    BeforeValidator(nd_array_custom_before_validator),
    # ``tolist()`` returns nested lists for arrays with ndim > 1. Declaring
    # ``list[float]`` here would make pydantic warn about unexpected values
    # when serializing such arrays, so declare a generic ``list`` instead.
    # 1-D output is unchanged.
    PlainSerializer(nd_array_to_list, return_type=list),
]
# The model class that uses our annotated type.
class Datastore(BaseModel):
    """Example model holding a single numpy array field ``x``.

    Values assigned to ``x`` pass through the ``NdArray`` before-validator
    (becoming float32 arrays) and are serialized as lists.
    """

    # Required because ``np.ndarray`` is not a type pydantic handles natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # pydantic does not run validators on default values unless
    # ``validate_default`` is enabled. The default must therefore already have
    # the dtype the validator would produce. A bare ``np.zeros((1,))`` would
    # silently be float64.
    x: NdArray = np.zeros((1,), dtype=np.float32)
if __name__ == "__main__":
    # Integer input is coerced to float32 by the NdArray before-validator.
    ds = Datastore(x=np.array([1, 2, 3]))
    print(ds)
    # field 'x' is a proper numpy array
    # OUT:
    # x=array([1., 2., 3.], dtype=float32)

    print(ds.model_dump_json(indent=4))
    # field 'x' is stored as a list of float
    # OUT:
    # {
    #     "x": [
    #         1.0,
    #         2.0,
    #         3.0
    #     ]
    # }

    # Round-trip: parse the JSON string straight back into a model instead of
    # going through json.loads() and ``**`` unpacking.
    print(Datastore.model_validate_json(ds.model_dump_json()))
    # field 'x' is reloaded correctly into a numpy array
    # OUT:
    # x=array([1., 2., 3.], dtype=float32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment