Skip to content

Instantly share code, notes, and snippets.

@jrabary
Last active October 27, 2015 09:48
Show Gist options
  • Save jrabary/be32c2acabac880d2921 to your computer and use it in GitHub Desktop.
Save jrabary/be32c2acabac880d2921 to your computer and use it in GitHub Desktop.
Fuel Transformer that samples data per target
class PerClassSampler(Transformer):
    """Fuel transformer that builds batches from a few samples of a few targets.

    Each batch returned by :meth:`get_data` contains ``class_per_batch``
    distinct target values, drawn at random without replacement from the
    targets present in ``data_stream``. For each drawn target, the batch holds
    exactly ``min(sample_per_class, <examples with that target per epoch>)``
    examples. Examples are pulled in order from the wrapped stream's epoch
    iterator. Examples whose target is not wanted, or whose target quota is
    already met, are skipped.

    The wrapped stream must yield batches and must expose a ``'targets'``
    source holding one target value per example (1-D or ``(n, 1)`` shaped).
    Every source is filtered with the same mask, so all sources stay aligned.

    Parameters
    ----------
    data_stream : Fuel data stream
        Batch stream to sample from. One full epoch of it is consumed in
        ``__init__`` to count the examples available for each target.
    class_per_batch : int
        Number of distinct targets in each produced batch.
    sample_per_class : int
        Maximum number of examples per target in each produced batch.
    """

    def __init__(self, data_stream, class_per_batch, sample_per_class,
                 **kwargs):
        super(PerClassSampler, self).__init__(
            data_stream=data_stream, produces_examples=False, **kwargs)
        self.class_per_batch = class_per_batch
        self.sample_per_class = sample_per_class
        # Maps target value -> number of examples wanted per batch for it
        # (its per-epoch count capped at sample_per_class).
        self.total_sample_per_target = self.__total_sample_per_target__(
            data_stream)

    @staticmethod
    def _row_targets(targets):
        """Return the target values of a batch as a 1-D array, one per row.

        Raises ValueError when a row carries more than one target value.
        """
        targets = numpy.asarray(targets)
        flat = targets.reshape(-1)
        if flat.shape[0] != targets.shape[0]:
            raise ValueError(
                'PerClassSampler expects one target per example, got targets '
                'of shape {}'.format(targets.shape))
        return flat

    def get_data(self, request=None):
        """Collect and return one per-class batch as a tuple, one array per source.

        Raises
        ------
        ValueError
            If ``request`` is not None (this transformer chooses its own data).
        StopIteration
            Propagated from the wrapped stream when its epoch ends before the
            batch is complete; the partially collected batch is discarded.
        """
        if request is not None:
            raise ValueError('PerClassSampler does not accept requests')
        # list(): numpy.random.choice needs a 1-D sequence; a Python 3 dict
        # view is not one.
        current_targets = numpy.random.choice(
            list(self.total_sample_per_target.keys()),
            self.class_per_batch, replace=False)
        collected = {c: 0 for c in current_targets}
        target_index = self.sources.index('targets')
        new_data = [[] for _ in self.sources]
        batch_complete = False
        # At least one child batch is always read, as in a do-while loop.
        while not batch_complete:
            data = next(self.child_epoch_iterator)
            row_targets = self._row_targets(data[target_index])
            matches = numpy.isin(row_targets, current_targets)
            for i in numpy.flatnonzero(matches):
                target_value = row_targets[i]
                if (collected[target_value]
                        < self.total_sample_per_target[target_value]):
                    collected[target_value] += 1
                else:
                    # Quota for this target already met: drop the example.
                    matches[i] = False
            # Filter every source with the same mask so they stay aligned.
            for index, source_data in enumerate(data):
                new_data[index].append(numpy.asarray(source_data)[matches])
            batch_complete = all(
                collected[c] == self.total_sample_per_target[c]
                for c in collected)
        return tuple(numpy.concatenate(parts) for parts in new_data)

    def __total_sample_per_target__(self, data_stream):
        """Count examples per target over one epoch, capped at sample_per_class.

        Returns a dict mapping each target value seen in one epoch of
        ``data_stream`` to ``min(count, self.sample_per_class)``. The targets
        source is located by name rather than by position.
        """
        from collections import Counter

        target_index = data_stream.sources.index('targets')
        counts = Counter()
        for batch in data_stream.get_epoch_iterator():
            counts.update(self._row_targets(batch[target_index]))
        return {target: min(count, self.sample_per_class)
                for target, count in counts.items()}
class TestPerClassSampler(TestCase):
    """Unit tests for PerClassSampler."""

    def test_get_epoch_iterator(self):
        """Every batch has exactly ``class_per_batch`` targets, each present
        exactly as many times as the sampler's capped per-target quota."""
        nb_classes = 10
        class_per_batch = 3
        data_per_class = 4
        nb_examples = 200
        features = numpy.random.rand(nb_examples, 5)
        targets = numpy.random.randint(nb_classes, size=nb_examples)
        scheme = ShuffledScheme(examples=nb_examples, batch_size=10)
        stream = DataStream(
            IndexableDataset(OrderedDict([('features', features),
                                          ('targets', targets)])),
            iteration_scheme=scheme)
        sampler = PerClassSampler(stream, class_per_batch, data_per_class)
        for _ in range(10):
            nb_batches = 0
            for x, t in sampler.get_epoch_iterator():
                selected_targets = numpy.unique(t)
                self.assertEqual(selected_targets.shape[0], class_per_batch)
                # Features and targets must be filtered with the same mask.
                self.assertEqual(x.shape[0], t.shape[0])
                counts = numpy.bincount(t)
                for c in selected_targets:
                    self.assertLessEqual(counts[c], data_per_class)
                    self.assertEqual(counts[c],
                                     sampler.total_sample_per_target[c])
                nb_batches += 1
            # Assuming each sampler epoch starts a fresh epoch of the wrapped
            # stream, the first batch can always be completed.
            self.assertGreater(nb_batches, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment