Skip to content

Instantly share code, notes, and snippets.

@heolin
Created February 18, 2022 21:21
Show Gist options
  • Save heolin/cda0d8b0b4d1e227dce04cc02baba82c to your computer and use it in GitHub Desktop.
Save heolin/cda0d8b0b4d1e227dce04cc02baba82c to your computer and use it in GitHub Desktop.
Test typed datasets
from dataclasses import dataclass, field
from typing import Dict, Generic, List, Protocol, Sequence, Text, TypeVar

import numpy as np
import pandas as pd
T = TypeVar('T')
@dataclass
class ItemPotential(Protocol):
    """Structural type for an item that carries an id and two numeric metrics.

    NOTE(review): ``@dataclass`` on a ``Protocol`` is unusual; the decorator
    is kept so the class behaves exactly as before — confirm it is intended.
    """

    item_id: Text
    gmv_eur: float  # presumably gross merchandise value in EUR — TODO confirm
    estimated_transactions: float
@dataclass
class DataSet(Generic[T]):
    """Collection of items mirrored into a 2-D NumPy array of selected fields.

    Every added item contributes one row to the internal array. The row holds
    the item's values for the attributes named in ``fields``, in that order.
    Items are keyed by the attribute named by ``item_id_column``.

    Attributes:
        fields: Attribute names copied from each item into the array columns.
        item_id_column: Name of the attribute holding an item's identifier.
    """

    fields: List[Text]
    item_id_column: Text = "item_id"
    # Items in row order: _items[i] is the item that produced row i of _data.
    _items: List[T] = field(default_factory=list)
    # Item id -> row index shared by _items and _data.
    _mapping: Dict[Text, int] = field(default_factory=dict)
    # Replaced by an empty (0, len(fields)) array in __post_init__.
    _data: np.ndarray = field(default=None)

    def __post_init__(self):
        """Create the empty value array and index the field names by column."""
        # Always reset: a `_data` value passed to the constructor is discarded.
        self._data = np.empty((0, len(self.fields)))
        self._field_to_index = {
            name: column for column, name in enumerate(self.fields)
        }

    def add(self, item: T) -> int:
        """Append ``item`` and its field values as a new last row.

        Args:
            item: Object exposing ``item_id_column`` and every name in
                ``fields`` as attributes.

        Returns:
            The row index assigned to the item.

        Raises:
            AttributeError: If the item lacks the id attribute or a field.
        """
        # Resolve everything that can fail before touching any state, so a
        # rejected item leaves the dataset unchanged.
        item_id = self._get_item_id(item)
        vector = np.array([getattr(item, name) for name in self.fields])
        index = len(self._items)
        # np.append returns a new array, so each add copies the existing rows.
        self._data = np.append(self._data, [vector], axis=0)
        self._items.append(item)
        # NOTE: re-adding an existing id repoints the mapping at the new row;
        # the earlier row stays in the array.
        self._mapping[item_id] = index
        return index

    def remove(self, item: T) -> None:
        """Remove the row registered under ``item``'s identifier.

        Args:
            item: Object whose ``item_id_column`` attribute names the entry.

        Raises:
            KeyError: If no entry is registered under the item's identifier.
        """
        item_id = self._get_item_id(item)
        index = self._mapping.pop(item_id)
        # Keep the item list aligned with the array rows.
        del self._items[index]
        self._data = np.delete(self._data, index, axis=0)
        # Rows after the removed one moved up by one; shift their indices so
        # later lookups and removals still hit the right row.
        for other_id, other_index in self._mapping.items():
            if other_index > index:
                self._mapping[other_id] = other_index - 1

    def _get_item_id(self, item: T) -> Text:
        """Return the identifier read from ``item``'s id attribute."""
        return getattr(item, self.item_id_column)

    def __getitem__(self, key):
        """Return the column of values stored for the field named ``key``.

        Raises:
            KeyError: If ``key`` is not one of ``fields``.
        """
        return self._data[:, self._field_to_index[key]]

    def __len__(self):
        """Return the number of rows currently stored."""
        return self._data.shape[0]

    def __iter__(self):
        """Iterate over the stored rows of the value array."""
        return iter(self._data)
def create_from_dataframe(
    df: pd.DataFrame,
    columns: Sequence[Text] = ("gmv_eur", "estimated_transactions"),
) -> DataSet[ItemPotential]:
    """Build a ``DataSet`` holding one entry per row of ``df``.

    Args:
        df: Frame whose rows expose the ``columns`` values and the dataset's
            id column (``"item_id"`` by default) by attribute access.
        columns: Names of the row values to store as dataset fields.

    Returns:
        A dataset with every row of ``df`` added in iteration order.
    """
    # Copy into a fresh list: DataSet keeps this object as `fields`, and a
    # mutable default list would be shared by every call.
    dataset = DataSet[ItemPotential](list(columns))
    for _, row in df.iterrows():
        dataset.add(row)
    return dataset
# Example usage. NOTE(review): `_df` and `row` are not defined anywhere in the
# visible file — presumably supplied by an interactive session; confirm.
dataset = create_from_dataframe(_df)
# Drop the entry registered under `row`'s item id.
dataset.remove(row)
# Mean of the "gmv_eur" column over the stored rows (the result is discarded).
dataset['gmv_eur'].mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment