Last active
November 21, 2017 23:00
TF Tutorial
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataSet:
    """Base data set class.

    Stores the arrays passed as keyword arguments as instance attributes
    and serves them in mini-batches through ``next_batch``, wrapping around
    (and optionally reshuffling) at the end of every epoch.

    Recognised keyword arguments:
        _data: required; indexed along its first axis, one row per sample.
        _labels: required when ``labeled`` is true; must have the same
            first-axis length as ``_data``.
        _test_data, _test_labels: optional; exposed through the
            ``test_data`` / ``test_labels`` properties.

    Any other keyword argument is stored as an attribute unchanged.
    """

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        """Create the data set.

        Args:
            shuffle: when true, the samples are shuffled once here and again
                every time ``next_batch`` wraps around to a new epoch.
            labeled: when true, ``_labels`` is required and batches are
                returned as ``(data, labels)`` tuples.
            **data_dict: the arrays to store (see the class docstring).

        Raises:
            AssertionError: if ``_data`` is missing, or if ``labeled`` is
                true and ``_labels`` is missing or differs from ``_data``
                in first-axis length.
        """
        assert '_data' in data_dict, "'_data' is required"
        if labeled:
            assert '_labels' in data_dict, (
                "'_labels' is required when labeled=True")
            assert (data_dict['_data'].shape[0]
                    == data_dict['_labels'].shape[0]), (
                "'_data' and '_labels' must have the same number of samples")
        self._labeled = labeled
        self._shuffle = shuffle
        # Entries of data_dict become attributes; a key that matches one of
        # the two attributes set just above overrides it.
        self.__dict__.update(data_dict)
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        self._epochs_trained = 0
        self._batch_number = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        """Return the number of training samples plus test samples, if any."""
        # `_test_data` is optional, so it must not be assumed to exist.
        test_data = getattr(self, '_test_data', None)
        num_test = 0 if test_data is None else len(test_data)
        return len(self._data) + num_test

    @property
    def epochs_trained(self):
        """Number of times ``next_batch`` has wrapped past the last sample."""
        return self._epochs_trained

    @epochs_trained.setter
    def epochs_trained(self, new_epochs_trained):
        self._epochs_trained = new_epochs_trained

    @property
    def batch_number(self):
        """Number of batches returned so far in the current epoch."""
        return self._batch_number

    @property
    def index_in_epoch(self):
        """Index of the first sample the next batch will start from."""
        return self._index_in_epoch

    @property
    def num_samples(self):
        """First-axis length of ``_data``."""
        return self._num_samples

    @property
    def data(self):
        """The training data, in its current (possibly shuffled) order."""
        return self._data

    @property
    def labels(self):
        """The training labels, in the same order as ``data``."""
        return self._labels

    @property
    def labeled(self):
        """Whether batches carry labels."""
        return self._labeled

    @property
    def test_data(self):
        """The ``_test_data`` array supplied at construction."""
        return self._test_data

    @property
    def test_labels(self):
        """The ``_test_labels`` array supplied at construction."""
        return self._test_labels

    @classmethod
    def load(cls, filename):
        """Build a data set from an archive written by ``save``.

        The batch counters are not restored: the returned data set starts
        at the beginning of a fresh epoch.

        Args:
            filename: path or file object of the ``.npz`` archive. A path
                that does not exist is retried with ``.npz`` appended,
                because ``np.savez_compressed`` adds that suffix on save.

        Raises:
            KeyError: if the archive has no ``_labeled`` entry.
        """
        import os

        if isinstance(filename, (str, os.PathLike)):
            path = os.fspath(filename)
            if not os.path.exists(path) and os.path.exists(path + '.npz'):
                filename = path + '.npz'
        # Read every array eagerly so the archive can be closed right away.
        with np.load(filename) as archive:
            data_dict = {key: archive[key] for key in archive.files}
        # Scalars come back as 0-d arrays; turn the flags into plain bools.
        labeled = bool(data_dict.pop('_labeled'))
        shuffle = bool(data_dict.pop('_shuffle', True))
        return cls(shuffle=shuffle, labeled=labeled, **data_dict)

    def save(self, filename):
        """Write every instance attribute to a compressed ``.npz`` archive.

        Args:
            filename: path or file object handed to
                ``np.savez_compressed``.
        """
        np.savez_compressed(filename, **self.__dict__)

    def _shuffle_data(self):
        """Reorder the data (and the labels, if any) with one permutation."""
        shuffled_idx = np.random.permutation(self._num_samples)
        self._data = self._data[shuffled_idx]
        if self._labeled:
            self._labels = self._labels[shuffled_idx]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples.

        When fewer than ``batch_size`` samples remain in the epoch, the
        batch is completed with samples from the start of the next epoch;
        the data is reshuffled in between when shuffling is enabled.

        Args:
            batch_size: number of samples to return; at most
                ``num_samples``.

        Returns:
            ``(data_batch, labels_batch)`` for a labeled data set,
            otherwise ``data_batch`` alone.

        Raises:
            AssertionError: if ``batch_size`` exceeds ``num_samples``.
        """
        assert batch_size <= self._num_samples, (
            'batch_size must not exceed the number of samples')
        start = self._index_in_epoch
        end = start + batch_size
        labels_batch = None
        if end > self._num_samples:
            # The epoch is exhausted part-way through this batch.
            self._epochs_trained += 1
            self._batch_number = 0
            # Take the tail before reshuffling so no sample is lost.
            data_tail = self._data[start:]
            labels_tail = self._labels[start:] if self._labeled else None
            remaining = end - self._num_samples
            if self._shuffle:
                self._shuffle_data()
            data_batch = np.concatenate(
                [data_tail, self._data[:remaining]], axis=0)
            if self._labeled:
                labels_batch = np.concatenate(
                    [labels_tail, self._labels[:remaining]], axis=0)
            self._index_in_epoch = remaining
        else:
            data_batch = self._data[start:end]
            if self._labeled:
                labels_batch = self._labels[start:end]
            self._index_in_epoch = end
        self._batch_number += 1
        if self._labeled:
            return data_batch, labels_batch
        return data_batch
"""
## example
import pandas as pd
from sklearn.model_selection import train_test_split

filename = 'filename.csv'
csv_data = pd.read_csv(filename)
train, test = train_test_split(csv_data, train_size=.9)
# DataSet selects rows with an integer index array, so pass NumPy arrays
# rather than DataFrames (indexing a DataFrame that way selects columns).
data_dict = {'_data': train.values, '_test_data': test.values}
data = DataSet(labeled=False, **data_dict)
num_steps = 1000
for _ in range(num_steps):
    batch = data.next_batch(100)
    # do something with batch
"""
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Important TF Guides:
MNIST Beginners Tutorial