Last active
November 23, 2017 18:02
-
-
Save dela3499/2984d44150a99d647f4fb754b2ccada8 to your computer and use it in GitHub Desktop.
Keras example with helper functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Original: the CNN built with one explicit model.add() call per layer.
# NOTE(review): assumes `layers` and `models` are the Keras modules imported
# elsewhere (no import is visible in this file).
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
# With helpers: each convPool(n) yields a [Conv2D, MaxPooling2D] pair, so this
# list is nested and must be flattened before its layers are added to a model.
myLayers = [
    convPool(32, input_shape=(150, 150, 3)),
    convPool(64),
    convPool(128),
    convPool(128),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(512, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Baseline: four Conv2D/MaxPooling2D stages, then a dense head ending in a
# single sigmoid unit (presumably binary classification — confirm).
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A single 2x2 max-pooling layer instance. Every use of `maxPool` refers to the
# same layer object, so adding it at several positions of a model reuses one
# layer instead of creating independent ones.
maxPool = layers.MaxPooling2D((2, 2))
def conv(nFilters, kernel_size=(3, 3), activation='relu', **kwargs):
    """Return a Conv2D layer with ``nFilters`` filters.

    Args:
        nFilters: number of convolution filters.
        kernel_size: convolution window size; defaults to (3, 3).
        activation: activation passed to Conv2D; defaults to 'relu'.
        **kwargs: forwarded unchanged to ``layers.Conv2D``
            (e.g. ``input_shape`` for the first layer).
    """
    return layers.Conv2D(nFilters, kernel_size, activation=activation, **kwargs)
def convPool(nFilters, **kwargs):
    """Return a ``[Conv2D, MaxPooling2D]`` pair for one conv/pool stage.

    ``kwargs`` are forwarded to the convolution only (via ``conv``).
    A new 2x2 pooling layer is created on every call. Each stage therefore gets
    its own layer object, as in the explicit model.add() version, rather than
    one shared instance appearing at several positions in the model.
    """
    return [conv(nFilters, **kwargs), layers.MaxPooling2D((2, 2))]
def squash(xs):
    """Flatten an irregular list of lists (some items aren't lists).

    Nested lists are flattened recursively, preserving order. Any item that
    is not a ``list`` (including tuples) is kept as a single element.

    >>> squash([1, [2, [3, 4]], 5])
    [1, 2, 3, 4, 5]
    """
    newList = []
    for x in xs:
        if isinstance(x, list):
            newList.extend(squash(x))
        else:
            newList.append(x)
    return newList
def createModel(modelLayers):
    """Build a Sequential model from a (possibly nested) list of layers.

    Nested lists, such as the pairs returned by ``convPool``, are flattened
    with ``squash``. Each resulting layer is then added in order.
    """
    model = models.Sequential()
    for layer in squash(modelLayers):
        model.add(layer)
    return model
# The same architecture as the explicit model.add() version, described with
# the helpers; createModel flattens the nested convPool pairs.
myLayers = [
    convPool(32, input_shape=(150, 150, 3)),
    convPool(64),
    convPool(128),
    convPool(128),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(512, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
]
model = createModel(myLayers)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment