-
-
Save Seanny123/1d63e493686dac41f4bcfe8f9f0aae81 to your computer and use it in GitHub Desktop.
Keras masking example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Masking, Dense | |
from keras.layers.recurrent import LSTM | |
from keras.models import Sequential | |
import numpy as np | |
# Seed NumPy's global random generator so repeated runs start from the same state.
np.random.seed(seed=1)
# Print floating-point arrays with 4 digits of precision to keep the tables compact.
np.set_printoptions(precision=4)
def print_output_val(out_val, inputs=None):
    """Print each sample's input next to the corresponding model output.

    Arrays are printed transposed, so time runs left to right and each
    feature/output unit occupies one row.

    Args:
        out_val: 3-D array indexed as ``[sample, timestep, unit]``.
        inputs: 3-D array indexed as ``[sample, timestep, feature]`` that
            produced ``out_val``. Defaults to the module-level ``data`` array.
    """
    if inputs is None:
        # Backward-compatible default: the module-level input array.
        inputs = data

    sample_count, step_count = out_val.shape[0], out_val.shape[1]

    print("\n'--> time'")
    # Header row with the timestep indices; the original computed this array
    # but discarded it instead of printing it.
    print(np.linspace(0, step_count - 1, step_count))

    for sample in range(sample_count):
        print('# sample = ', sample)
        print('input:')
        print(inputs[sample, :, :].T)
        print('output_val:')
        print(out_val[sample, :, :].T)
def sequential_non_temporal_example():
    """Run the module-level ``data`` through Masking -> Dense and print the result.

    The model has no recurrent layer: a Masking layer configured with the
    module-level ``mask_value`` feeds a single linear Dense unit whose kernel
    uses the "one" initializer. The model is only used for prediction; it is
    never compiled or trained.
    """
    masking = Masking(mask_value=mask_value, input_shape=(n_timesteps, n_features))
    readout = Dense(1, activation='linear', kernel_initializer="one")

    net = Sequential()
    for layer in (masking, readout):
        net.add(layer)

    print_output_val(net.predict(data))
def sequential_temporal_example():
    """Run the module-level ``data`` through Masking -> LSTM -> Dense and print the result.

    Same pipeline as the non-temporal example, with a 2-unit LSTM that returns
    the full sequence inserted between the Masking layer and the linear Dense
    read-out. Both the LSTM and the Dense kernels use the "one" initializer.
    The model is only used for prediction; it is never compiled or trained.
    """
    masking = Masking(mask_value=mask_value, input_shape=(n_timesteps, n_features))
    recurrent = LSTM(2, return_sequences=True, kernel_initializer="one")
    readout = Dense(1, activation='linear', kernel_initializer="one")

    net = Sequential()
    for layer in (masking, recurrent, readout):
        net.add(layer)

    print_output_val(net.predict(data))
# Problem size; these module-level names are read by the example functions.
n_samples = 3
n_timesteps = 7
n_features = 2

# Sentinel written into the input and handed to the Masking layer.
# Spelled np.nan because the np.NaN alias was removed in NumPy 2.0.
# NOTE(review): NaN never compares equal to itself, so a NaN mask value
# presumably never matches inside Masking; the recorded output at the end of
# the file (NaN propagating to the outputs) is consistent with that - confirm.
mask_value = np.nan  # -999999999.0# -1.0 # -1 # 0.0

# Every feature of every sample counts 1..n_timesteps along the time axis.
data = np.ones((n_samples, n_timesteps, n_features))
data[:] = np.linspace(1, n_timesteps, n_timesteps)[np.newaxis, :, np.newaxis]

# mask a feature value of one sample and timestep (no effect)
data[1, 0, 0] = mask_value
# mask all feature values of one sample and timestep (propagates 0.*mask_value at layer of step/sample?)
data[2, 3, :] = mask_value

print('####################### sequential_non_temporal_example #######################:')
sequential_non_temporal_example()

print('####################### sequential_temporal_example #######################:')
# As non-temporal but masked timestep state does not propagate through time:
sequential_temporal_example()
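# Recorded output from an earlier run (it appears to also include model.summary()
# tables, which the script as shown does not print):
# ####################### sequential_non_temporal_example #######################: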
# _________________________________________________________________ | |
# Layer (type) Output Shape Param # | |
# ================================================================= | |
# masking_1 (Masking) (None, 7, 2) 0 | |
# _________________________________________________________________ | |
# dense_1 (Dense) (None, 7, 1) 3 | |
# ================================================================= | |
# Total params: 3.0 | |
# Trainable params: 3.0 | |
# Non-trainable params: 0.0 | |
# _________________________________________________________________ | |
# --> time | |
# [ 0. 1. 2. 3. 4. 5. 6.] | |
# # sample = 0 | |
# input: | |
# [[ 1. 2. 3. 4. 5. 6. 7.] | |
# [ 1. 2. 3. 4. 5. 6. 7.]] | |
# output_val: | |
# [[ 2. 4. 6. 8. 10. 12. 14.]] | |
# # sample = 1 | |
# input: | |
# [[ nan 2. 3. 4. 5. 6. 7.] | |
# [ 1. 2. 3. 4. 5. 6. 7.]] | |
# output_val: | |
# [[ nan 4. 6. 8. 10. 12. 14.]] | |
# # sample = 2 | |
# input: | |
# [[ 1. 2. 3. nan 5. 6. 7.] | |
# [ 1. 2. 3. nan 5. 6. 7.]] | |
# output_val: | |
# [[ 2. 4. 6. nan 10. 12. 14.]] | |
# ####################### sequential_temporal_example #######################: | |
# _________________________________________________________________ | |
# Layer (type) Output Shape Param # | |
# ================================================================= | |
# masking_2 (Masking) (None, 7, 2) 0 | |
# _________________________________________________________________ | |
# lstm_1 (LSTM) (None, 7, 2) 40 | |
# _________________________________________________________________ | |
# dense_2 (Dense) (None, 7, 1) 3 | |
# ================================================================= | |
# Total params: 43.0 | |
# Trainable params: 43.0 | |
# Non-trainable params: 0.0 | |
# _________________________________________________________________ | |
# --> time | |
# [ 0. 1. 2. 3. 4. 5. 6.] | |
# # sample = 0 | |
# input: | |
# [[ 1. 2. 3. 4. 5. 6. 7.] | |
# [ 1. 2. 3. 4. 5. 6. 7.]] | |
# output_val: | |
# [[ 1.2603 1.9066 1.9871 1.9982 1.9998 2. 2. ]] | |
# # sample = 1 | |
# input: | |
# [[ nan 2. 3. 4. 5. 6. 7.] | |
# [ 1. 2. 3. 4. 5. 6. 7.]] | |
# output_val: | |
# [[ nan nan nan nan nan nan nan]] | |
# # sample = 2 | |
# input: | |
# [[ 1. 2. 3. nan 5. 6. 7.] | |
# [ 1. 2. 3. nan 5. 6. 7.]] | |
# output_val: | |
# [[ 1.2603 1.9066 1.9871 nan nan nan nan]] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment