Created
January 8, 2024 20:04
-
-
Save bendichter/257051674f8d6ea97fbb732f778dda35 to your computer and use it in GitHub Desktop.
constrainedarray
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, Type, Any, Union, Optional | |
from pydantic import BaseModel | |
import numpy as np | |
def get_shape(data):
    """
    Return the shape of an array-like object or a nested list.

    Anything exposing a ``shape`` attribute (for example a NumPy array, an
    h5py dataset or a Zarr array) is answered from that attribute. Plain
    Python lists are measured by walking their nesting.

    Parameters
    ----------
    data : {np.ndarray, h5py.Dataset, zarr.core.Array, list}
        The object to measure.

    Returns
    -------
    tuple
        The ``shape`` attribute of ``data`` when it has one, otherwise the
        shape computed from the nested list.

    Raises
    ------
    TypeError
        If ``data`` has no ``shape`` attribute and is not a list.

    Examples
    --------
    >>> import numpy as np
    >>> get_shape(np.array([[1, 2], [3, 4]]))
    (2, 2)
    >>> get_shape([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    (2, 2, 2)
    """
    # Array-likes take priority: trust whatever they report as their shape.
    if hasattr(data, "shape"):
        return data.shape

    # Everything that is neither array-like nor a list is rejected up front.
    if not isinstance(data, list):
        raise TypeError("Unsupported data type for getting shape.")

    return _get_list_shape(data)
def _get_list_shape(lst): | |
""" | |
Get the shape of a nested list of arbitrary depth. | |
Parameters | |
---------- | |
lst : list | |
The nested list whose shape is to be determined. | |
Returns | |
------- | |
tuple | |
A tuple representing the shape of the nested list. | |
Notes | |
----- | |
This function assumes that the nested list is regular, i.e., each sub-list at each level has the same length. | |
""" | |
if not lst: | |
return () | |
if isinstance(lst[0], list): | |
return (len(lst),) + _get_list_shape(lst[0]) | |
else: | |
return (len(lst),) | |
class _InstanceOrClassMethod:
    """Descriptor that binds a function to the instance when accessed on one, otherwise to the class."""

    def __init__(self, func):
        self._func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, objtype=None):
        owner = objtype if obj is None else obj
        return self._func.__get__(owner)


class ConstrainedArray:
    """
    A custom Pydantic type for validating arrays with specific shapes and data types.

    The constraints are stored on the instance created by the constructor.
    ``validate`` and ``__get_validators__`` therefore bind to that instance
    when accessed on one; when accessed on the class (or a subclass that sets
    ``shapes`` / ``dtype`` as class attributes) they bind to the class.

    Attributes
    ----------
    shapes : tuple or list, optional
        Either a single shape such as ``(2, None)`` or a sequence of allowed
        shapes such as ``[(2, 2), (3, 3)]``. ``None`` in any position within a
        shape allows any size for that dimension. ``None`` for the whole
        attribute disables shape checking.
    dtype : type or tuple of types, optional
        The expected or allowed data type(s) for the array's elements, checked
        with ``np.issubdtype``. ``None`` disables dtype checking.

    Examples
    --------
    >>> class Model(BaseModel):
    ...     array: ConstrainedArray(shapes=[(2, 2), (3, 3)], dtype=(np.float32, np.float64))
    ...
    >>> model = Model(array=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    """

    # Class-level defaults so the validators also work when looked up on the class itself.
    shapes = None
    dtype = None

    def __init__(
        self,
        shapes: Optional[Tuple[Tuple[Union[int, None], ...]]] = None,
        dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
    ):
        self.shapes = shapes
        self.dtype = dtype

    @_InstanceOrClassMethod
    def __get_validators__(self_or_cls):
        """Yield the validator callables used for this type (Pydantic's ``__get_validators__`` hook)."""
        yield self_or_cls.validate

    @_InstanceOrClassMethod
    def validate(self_or_cls, value: Any) -> Any:
        """
        Check ``value`` against the configured dtype and shape constraints.

        Parameters
        ----------
        value : Any
            The array-like object to validate.

        Returns
        -------
        Any
            ``value`` itself, unchanged, when every configured check passes.

        Raises
        ------
        TypeError
            If the dtype of ``value`` is not a sub-dtype of any allowed dtype.
        ValueError
            If the shape of ``value`` matches none of the allowed shapes.
        """
        allowed_dtype = self_or_cls.dtype
        if allowed_dtype is not None:
            # Objects without a ``dtype`` (e.g. nested lists) are inspected through NumPy.
            actual_dtype = value.dtype if hasattr(value, "dtype") else np.asarray(value).dtype
            if isinstance(allowed_dtype, (tuple, list)):
                if not any(np.issubdtype(actual_dtype, dt) for dt in allowed_dtype):
                    allowed_dtypes = ', '.join(self_or_cls._dtype_label(dt) for dt in allowed_dtype)
                    raise TypeError(f'Expected array with dtype(s) {allowed_dtypes}, got {actual_dtype.name}')
            elif not np.issubdtype(actual_dtype, allowed_dtype):
                raise TypeError(
                    f'Expected array with dtype {self_or_cls._dtype_label(allowed_dtype)}, got {actual_dtype.name}'
                )

        allowed_shapes = self_or_cls.shapes
        if allowed_shapes is not None:
            this_shape = tuple(get_shape(value))
            # A flat sequence of ints/None is one shape; anything else is a collection of shapes.
            # An empty sequence is treated as the single zero-dimensional shape ``()``.
            is_single = len(allowed_shapes) == 0 or (
                allowed_shapes[0] is None or isinstance(allowed_shapes[0], (int, np.integer))
            )
            if is_single:
                if not self_or_cls._check_single_shape(this_shape, allowed_shapes):
                    raise ValueError(
                        f'The provided array shape: {this_shape} does not match the expected shape: {allowed_shapes}'
                    )
            elif not self_or_cls._check_multi_shape(this_shape, allowed_shapes):
                shape_list = ', '.join(str(s) for s in allowed_shapes)
                raise ValueError(
                    f'None of the expected shapes {shape_list} match the provided array shape {this_shape}'
                )
        return value

    @staticmethod
    def _dtype_label(dt: Any) -> str:
        """Return a printable name for a dtype spec, whether it is a class, an ``np.dtype`` or a string."""
        return getattr(dt, '__name__', None) or str(dt)

    @classmethod
    def _check_multi_shape(cls, shape: Tuple[int, ...], list_of_allowable_shapes: Tuple[Tuple[Union[int, None], ...]]):
        """Return True when ``shape`` matches at least one of the allowable shapes."""
        return any(cls._check_single_shape(shape, allowed_shape) for allowed_shape in list_of_allowable_shapes)

    @classmethod
    def _check_single_shape(cls, shape: Tuple[int, ...], allowed_shape: Tuple[Union[int, None], ...]):
        """Return True when ``shape`` has the same rank as ``allowed_shape`` and agrees on every non-``None`` entry."""
        # ``zip`` stops at the shorter sequence, so the rank must be compared explicitly.
        if len(shape) != len(allowed_shape):
            return False
        return all(s2 is None or s1 == s2 for s1, s2 in zip(shape, allowed_shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment