from numpy.typing import NDArray
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer
from typing_extensions import Annotated

import json
import numpy as np


# define serializers and validator
def nd_array_custom_before_validator(x: list[float]) -> NDArray[np.float32]:
    """Convert the incoming value into a float32 numpy array.

    Used as the "before" validation hook of the annotated array type, so it
    sees the raw input. ``np.array`` copies by default, so the result is a new
    array even when ``x`` is already a float32 array.
    """
    return np.array(x, dtype=np.float32)


def nd_array_to_list(x: NDArray[np.float32]) -> list[float]:
    """Turn a numpy array into plain Python values for serialization.

    Delegates to ``ndarray.tolist()``, so the elements become native Python
    scalars that a JSON encoder can handle.
    """
    as_python = x.tolist()
    return as_python

# Define the type wrapper: an annotated ``np.ndarray`` that pydantic can both
# validate and serialize.
#   * BeforeValidator: runs nd_array_custom_before_validator on the raw input,
#     converting it to a float32 array.
#   * PlainSerializer: runs nd_array_to_list when dumping, declaring the
#     serialized form as ``list[float]``.
# NOTE(review): ``return_type=list[float]`` describes a flat list; arrays with
# more than one dimension serialize to nested lists -- confirm that only 1-D
# arrays are intended here.
NdArray = Annotated[
    np.ndarray,
    BeforeValidator(nd_array_custom_before_validator),
    PlainSerializer(nd_array_to_list, return_type=list[float]),
]

# The model that uses our annotated array type.
class Datastore(BaseModel):
    # np.ndarray is not a type pydantic knows natively, so arbitrary types
    # must be allowed for the field annotation to be accepted.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Array payload; any supplied value is converted to float32 by the
    # NdArray before-validator. The default is built as float32 explicitly:
    # pydantic does not run validators on default values unless
    # validate_default is enabled, so a bare np.zeros((1,)) would leave the
    # default as float64, inconsistent with every validated value.
    x: NdArray = np.zeros((1,), dtype=np.float32)

if __name__ == "__main__":
    # Build a model from an integer array; validation converts it to float32.
    store = Datastore(x=np.array([1, 2, 3]))

    # Field 'x' is held as a proper numpy array.
    # Expected output:
    #   x=array([1., 2., 3.], dtype=float32)
    print(store)

    # Field 'x' is serialized as a list of floats.
    # Expected output:
    #   {
    #       "x": [
    #           1.0,
    #           2.0,
    #           3.0
    #       ]
    #   }
    pretty_json = store.model_dump_json(indent=4)
    print(pretty_json)

    # Round trip: dump to JSON, parse it back, and rebuild the model.
    # Field 'x' is reloaded into a numpy array.
    # Expected output:
    #   x=array([1., 2., 3.], dtype=float32)
    payload = json.loads(store.model_dump_json())
    restored = Datastore(**payload)
    print(restored)