Created
February 18, 2022 21:21
-
-
Save heolin/cda0d8b0b4d1e227dce04cc02baba82c to your computer and use it in GitHub Desktop.
Test typed datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass, field
from typing import Dict, Generic, List, Protocol, Text, TypeVar

import numpy as np
import pandas as pd
# Type variable standing for the kind of item a DataSet holds.
T = TypeVar("T")
class ItemPotential(Protocol):
    """Structural type describing the attributes a DataSet item exposes.

    This is a ``typing.Protocol``: any object with these attributes (for
    example a pandas row ``Series`` with matching columns) satisfies it. It is
    used only as a type argument and is never instantiated, so it is not a
    dataclass.
    """

    # Identifier of the item.
    item_id: Text
    gmv_eur: float
    estimated_transactions: float
@dataclass
class DataSet(Generic[T]):
    """Row-per-item NumPy matrix built from selected item attributes.

    Each added item contributes one row. The row's columns are the item's
    attributes named in ``fields``, read with ``getattr`` in that order. Items
    are keyed by the attribute named ``item_id_column`` so they can be removed
    again. A whole column is read with ``dataset[field_name]``.

    NOTE: adding a second item whose id is already present appends a new row
    and re-points the id at it. The earlier row stays in the data but can no
    longer be removed by id.
    """

    # Attribute names to extract from each item, in column order.
    fields: List[Text]
    # Name of the attribute used as the item's id.
    item_id_column: Text = "item_id"
    # Added items in row order: ``_items[i]`` produced row ``_data[i]``.
    _items: List[T] = field(default_factory=list)
    # Item id -> row index into ``_data`` / ``_items``.
    _mapping: Dict[Text, int] = field(default_factory=dict)
    # (len(self), len(fields)) matrix; (re)initialised in ``__post_init__``.
    _data: np.ndarray = field(default=None)

    def __post_init__(self):
        self._data = np.empty((0, len(self.fields)))
        # Field name -> column index, used by ``__getitem__``.
        self._field_to_index = {
            name: col for col, name in enumerate(self.fields)
        }

    def add(self, item: T) -> int:
        """Append ``item`` as a new row and return that row's index.

        The id, the row vector and the new matrix are all computed before any
        internal state changes. If ``getattr`` (or the append) fails, the
        dataset is left untouched.
        """
        item_id = self._get_item_id(item)
        row = np.array([getattr(item, name) for name in self.fields])
        # np.append copies the whole matrix, so each add costs O(len(self)).
        new_data = np.append(self._data, [row], axis=0)
        index = len(self._items)
        self._items.append(item)
        self._mapping[item_id] = index
        self._data = new_data
        return index

    def remove(self, item: T) -> None:
        """Remove the row belonging to ``item``, matched by its id.

        Raises:
            KeyError: if no item with that id is present.
        """
        item_id = self._get_item_id(item)
        index = self._mapping[item_id]
        self._data = np.delete(self._data, index, axis=0)
        # Keep ``_items`` aligned with the rows of ``_data``.
        del self._items[index]
        del self._mapping[item_id]
        # Every row after ``index`` moved up by one; shift the stored indices
        # so later ``remove`` calls delete the right row. Only values are
        # reassigned here, so iterating while updating is safe.
        for other_id, other_index in self._mapping.items():
            if other_index > index:
                self._mapping[other_id] = other_index - 1

    def _get_item_id(self, item: T) -> Text:
        """Return ``item``'s id, read from the ``item_id_column`` attribute."""
        return getattr(item, self.item_id_column)

    def __getitem__(self, key):
        """Return the column for field ``key`` as a 1-D view of the data.

        Raises:
            KeyError: if ``key`` is not one of ``fields``.
        """
        return self._data[:, self._field_to_index[key]]

    def __len__(self):
        """Return the number of rows currently stored."""
        return self._data.shape[0]

    def __iter__(self):
        """Iterate over the stored rows, each a 1-D array ordered by ``fields``."""
        return iter(self._data)
def create_from_dataframe(df: pd.DataFrame,
                          columns=None) -> DataSet[ItemPotential]:
    """Build a ``DataSet`` with one row per row of ``df``.

    Args:
        df: Source frame. Each row is passed to ``DataSet.add`` as a pandas
            ``Series``. ``df`` must therefore contain every column in
            ``columns``, plus the ``item_id`` column the dataset uses to
            identify rows.
        columns: Names of the columns to store, in order. Defaults to
            ``["gmv_eur", "estimated_transactions"]``. Any iterable of names
            is accepted.

    Returns:
        The populated dataset.
    """
    if columns is None:
        columns = ["gmv_eur", "estimated_transactions"]
    # The dataset keeps this list as its ``fields``. Copy it so the dataset
    # never aliases a caller-owned list, and so no mutable default is shared
    # between datasets.
    dataset = DataSet[ItemPotential](list(columns))
    for _, row in df.iterrows():
        dataset.add(row)
    return dataset
if __name__ == "__main__":
    # Example usage, guarded so importing this module has no side effects.
    # ``_df`` is not defined in this file. It is assumed to be a DataFrame from
    # the surrounding session (e.g. a notebook) with "item_id", "gmv_eur" and
    # "estimated_transactions" columns -- TODO confirm.
    dataset = create_from_dataframe(_df)
    # Remove an explicit source row. A bare ``row`` is not defined at module
    # scope, because the loop variable in create_from_dataframe is local.
    dataset.remove(_df.iloc[0])
    print(dataset['gmv_eur'].mean())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment