Created July 3, 2018 14:14
-
-
Save tanzhenyu/22fadcfda66704199a5c5d4edf10c17e to your computer and use it in GitHub Desktop.
Test model cloning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A gist of Sequential model cloning (cloning a functional model should work the same way).
def to_list(x, *, allow_tuple=False):
    """Normalize `x` to a list.

    Args:
        x: a list (returned as-is, not copied) or any other object.
        allow_tuple: if True, a tuple is converted element-wise to a list
            instead of being wrapped as a single element. Defaults to False,
            which keeps the original behaviour of wrapping tuples.

    Returns:
        `x` itself when it is a list; `list(x)` when it is a tuple and
        `allow_tuple` is True; otherwise the one-element list `[x]`.
    """
    if isinstance(x, list):
        return x
    if allow_tuple and isinstance(x, tuple):
        return list(x)
    return [x]
def is_keras_tensor(x):
    """Return True if `x` carries a `_keras_history` attribute, else False.

    The presence of `_keras_history` is used here as the marker that a tensor
    was produced by a Keras layer.
    """
    try:
        x._keras_history
    except AttributeError:
        return False
    return True
def clone_sequential_model(model, input_tensors=None):
    """Clone a `Sequential` model, optionally building it on a given input.

    Every layer of `model` is re-created from its config
    (`layer.__class__.from_config(layer.get_config())`). The clone therefore
    has the same architecture, but its weights are newly created rather than
    copied.

    Args:
        model: the `tf.keras.Sequential` model to clone.
        input_tensors: optional single tensor, or a list holding exactly one
            tensor, to build the clone on.
            - If it is a Keras tensor, it must come from an `InputLayer`,
              which is reused as the clone's first layer.
            - Any other tensor is wrapped in a new `InputLayer` via
              `tf.keras.Input(tensor=...)`.

    Returns:
        A new `tf.keras.Sequential` with the same name as `model`.

    Raises:
        ValueError: if `input_tensors` does not contain exactly one tensor, or
            if it is a Keras tensor produced by a layer other than an
            `InputLayer`.
    """
    def clone(layer):
        # Re-instantiate from config: same hyper-parameters, fresh weights.
        return layer.__class__.from_config(layer.get_config())

    layers = [clone(layer) for layer in model.layers]
    if input_tensors is None:
        return tf.keras.Sequential(layers=layers, name=model.name)

    input_list = to_list(input_tensors)
    if len(input_list) != 1:
        raise ValueError('To clone a `Sequential` model, we expect '
                         'exactly one tensor '
                         'as part of `input_tensors`.')
    x = input_list[0]

    if is_keras_tensor(x):
        origin_layer = x._keras_history[0]
        # Use the fully qualified name: a bare `InputLayer` is not defined
        # anywhere in this gist, so this branch previously raised NameError.
        if isinstance(origin_layer, tf.keras.layers.InputLayer):
            return tf.keras.Sequential(layers=[origin_layer] + layers,
                                       name=model.name)
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')

    # A plain (non-Keras) tensor such as a placeholder: wrap it in a new
    # InputLayer so that Sequential can use it as its input.
    # NOTE(review): the wrapped input is backed by `x` rather than fed, which
    # appears to change which layer updates (e.g. BatchNormalization moving
    # statistics) the compiled clone picks up -- confirm.
    input_tensor = tf.keras.Input(tensor=x,
                                  name='input_wrapper_for_' + str(x.name))
    input_layer = input_tensor._keras_history[0]
    return tf.keras.Sequential(layers=[input_layer] + layers, name=model.name)
import tensorflow as tf | |
# Compare the original model with a clone built on a placeholder: the clone's
# train function ends up without the BatchNormalization moving-statistics
# update ops that the original has.
tf.reset_default_graph()

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_dim=1))
model.add(tf.keras.layers.BatchNormalization())
# NOTE(review): softmax over a single unit always outputs 1.0; this does not
# matter for the update-op comparison below.
model.add(tf.keras.layers.Activation('softmax'))

x = tf.placeholder(tf.float32, shape=(None, 1))
clone = clone_sequential_model(model, input_tensors=x)

model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
clone.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

# Bare expressions only echo in an interactive session; print them so the
# comparison is also visible when this file runs as a script.
# Original: dense_input, with updates of batch_norm moving_mean and moving_var.
print(model._feed_inputs)
print(model.get_updates_for(model._feed_inputs))
# Clone: empty feed inputs, with empty updates.
print(clone._feed_inputs)
print(clone.get_updates_for(clone._feed_inputs))

# Train function update ops include the batch norm assign moving mean/var ops.
model._make_train_function()
print(model.train_function.updates_op)
# Train function update ops exclude the batch norm assign moving mean/var ops.
clone._make_train_function()
print(clone.train_function.updates_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.