Custom TFT scale_by_min_max_per_key: per-key min-max scaling for SparseTensor features, built on tft.ptransform_analyzer with a Beam CombineFn that computes per-key min, max, mean, std, and count.
import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft

def transform_sparse_values(sp_tensor, trans_fun):
    """Applies trans_fun to a SparseTensor's values, preserving its sparsity pattern."""
    return tf.sparse.SparseTensor(
        indices=sp_tensor.indices,
        values=trans_fun(sp_tensor.values),
        dense_shape=sp_tensor.dense_shape
    )

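# For instance (an illustrative sketch, not part of the original gist), the
# values of a SparseTensor can be log-transformed while its indices and dense
# shape stay untouched:
#
#   sp = tf.sparse.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], [2, 3])
#   log_sp = transform_sparse_values(sp, tf.math.log)
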
class CollectAll(beam.PTransform):
    """Collects an entire PCollection into a single element via a dummy key."""
    def expand(self, pcoll):
        pcoll = pcoll | "AddDummyKey" >> beam.Map(lambda v: (None, v))
        pcoll = pcoll | "GroupByDummyKey" >> beam.GroupByKey()
        pcoll = pcoll | "RemoveDummyKey" >> beam.Map(lambda v: v[1])
        return pcoll

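# A sketch of CollectAll's behavior (pipeline and values are illustrative
# only): every element is routed to the same dummy key, so the output holds
# exactly one element, an iterable over all inputs in no guaranteed order.
#
#   with beam.Pipeline() as p:
#       all_in_one = p | beam.Create(['a', 'b', 'c']) | CollectAll()
#       # all_in_one contains a single iterable over 'a', 'b', 'c'.
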
class BasicStatsCombiner(beam.CombineFn):
    """Computes min, max, mean, (population) std, and count of numeric values."""

    def create_accumulator(self):
        # (min, max, sum, sum of squares, count)
        return (np.inf, -np.inf, 0.0, 0.0, 0)

    def add_input(self, acc, v):
        (curr_min, curr_max, curr_sum, curr_sum_sqr, curr_count) = acc
        new_min = min(curr_min, v)
        new_max = max(curr_max, v)
        new_sum = curr_sum + v
        new_sum_sqr = curr_sum_sqr + v ** 2
        new_count = curr_count + 1
        return (new_min, new_max, new_sum, new_sum_sqr, new_count)

    def merge_accumulators(self, accs):
        min_accs, max_accs, sum_accs, sum_sqr_accs, count_accs = zip(*accs)
        return (
            min(min_accs),
            max(max_accs),
            sum(sum_accs),
            sum(sum_sqr_accs),
            sum(count_accs)
        )

    def extract_output(self, acc):
        (final_min, final_max, final_sum, final_sum_sqr, final_count) = acc
        if final_count:
            mean = final_sum / final_count
            # Variance as E[x^2] - E[x]^2, clamped away from zero for
            # numerical safety.
            std = np.maximum(
                np.sqrt(final_sum_sqr / final_count - mean ** 2),
                np.finfo(np.float64).eps
            )
            return {
                'min': final_min,
                'max': final_max,
                'mean': mean,
                'std': std,
                'count': final_count
            }
        else:
            return {
                'min': np.nan,
                'max': np.nan,
                'mean': np.nan,
                'std': np.nan,
                'count': 0
            }

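# The combiner can be exercised on its own, outside of TFT (an illustrative
# sketch; the values are made up):
#
#   with beam.Pipeline() as p:
#       stats = (
#           p
#           | beam.Create([1.0, 2.0, 3.0])
#           | beam.CombineGlobally(BasicStatsCombiner())
#       )
#       # -> one dict: {'min': 1.0, 'max': 3.0, 'mean': 2.0,
#       #               'std': 0.816..., 'count': 3}
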
class ComputeBasicStatsPerKey(beam.PTransform):
    """Computes BasicStatsCombiner stats per key.

    The input PCollection holds (keys_batch, values_batch) tuples, one per
    batch of the two analyzer inputs. The output is a tuple of six
    PCollections, each holding a single ndarray: keys, min, max, mean, std,
    and count, all aligned by key.
    """

    _DTYPE_BY_STAT_NAME = {
        'min': np.float32,
        'max': np.float32,
        'mean': np.float32,
        'std': np.float32,
        'count': np.int64
    }

    def expand(self, keys_and_values):
        # Flatten each (keys_batch, values_batch) pair into (key, value) tuples.
        flattened_keys_and_values = (
            keys_and_values |
            "ZipKeysAndValues" >> beam.FlatMap(lambda v: list(zip(v[0], v[1])))
        )
        stats_per_key = (
            flattened_keys_and_values |
            "ComputeStatsPerKey" >> beam.CombinePerKey(BasicStatsCombiner()) |
            "CollectStatsPerKey" >> CollectAll()
        )
        rets = []
        rets.append(
            stats_per_key |
            "ExtractKeys" >> beam.Map(lambda s: np.array([v[0] for v in s]))
        )

        def extract_stat(stats_per_key, stat_name, stat_dtype):
            return (
                stats_per_key |
                "Extract{}Stat".format(stat_name.capitalize()) >> beam.Map(
                    lambda s: np.array([v[1][stat_name] for v in s], dtype=stat_dtype)
                )
            )

        for stat_name, stat_dtype in self._DTYPE_BY_STAT_NAME.items():
            rets.append(extract_stat(stats_per_key, stat_name, stat_dtype))
        return tuple(rets)

def get_basic_stats_per_key(dense_keys, dense_values):
    """Runs ComputeBasicStatsPerKey as a TFT analyzer and returns the per-key
    stats as a dict of rank-1 tensors aligned with 'keys'."""
    stats = tft.ptransform_analyzer(
        [dense_keys, dense_values],
        # keys, min, max, mean, std, count
        [dense_keys.dtype, tf.float32, tf.float32, tf.float32, tf.float32, tf.int64],
        [[None], [None], [None], [None], [None], [None]],
        ptransform=ComputeBasicStatsPerKey(),
        name='ComputeBasicStatsPerKey'
    )
    return {
        'keys': stats[0],
        'min': stats[1],
        'max': stats[2],
        'mean': stats[3],
        'std': stats[4],
        'count': stats[5]
    }

def scale_by_min_max_per_key(
    sp_keys,
    sp_values,
    output_min=0.0,
    output_max=1.0,
    stats_per_key=None,
    name=None
):
    """Scales the values of sp_values to [output_min, output_max] per key.

    sp_keys and sp_values must be SparseTensors with the same sparsity
    pattern. Degenerate keys (min == max) and keys not seen during analysis
    are mapped to the midpoint of the output range.
    """
    with tf.compat.v1.name_scope(name, 'scale_by_min_max_per_key'):
        stats_per_key = stats_per_key or get_basic_stats_per_key(
            sp_keys.values, sp_values.values
        )
        min_lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                stats_per_key['keys'],
                stats_per_key['min']
            ),
            default_value=np.nan,
            name='min_lookup'
        )
        max_lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                stats_per_key['keys'],
                stats_per_key['max']
            ),
            default_value=np.nan,
            name='max_lookup'
        )

        def scale(v):
            key_min = min_lookup.lookup(sp_keys.values)
            key_max = max_lookup.lookup(sp_keys.values)
            # When key_min < key_max is False (degenerate or unseen key, where
            # the NaN default makes the comparison False), fall back to the
            # middle of the output range.
            return tf.where(
                key_min < key_max,
                (v - key_min) / (key_max - key_min) * (output_max - output_min) + output_min,
                tf.fill(tf.shape(v), (output_min + output_max) / 2.0)
            )

        return transform_sparse_values(sp_values, scale)
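
For reference, a minimal sketch of how this might be wired into a TFT preprocessing_fn. The feature names ('seller_id', 'price') and the assumption that both arrive as SparseTensors with matching indices are illustrative, not part of the gist:

    def preprocessing_fn(inputs):
        # Scale each price into [0, 1] relative to its seller's min and max.
        return {
            'scaled_price': scale_by_min_max_per_key(
                sp_keys=inputs['seller_id'],
                sp_values=inputs['price']
            )
        }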