Created April 4, 2020 14:34
-
-
Save michelkana/b7033708206fb589c7e094052e237bd8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Conv2D, BatchNormalization, Input, GlobalAveragePooling2D, Dense | |
from keras.models import Model | |
from keras.layers.advanced_activations import LeakyReLU | |
def build_discriminator(start_filters, spatial_dim, filter_size, channels=3):
    """Build a convolutional discriminator that maps an image to one sigmoid score.

    The network stacks four downsampling blocks with ``start_filters``,
    ``2 * start_filters``, ``4 * start_filters`` and ``8 * start_filters``
    filters. Each block ends in a stride-2 convolution, so height and width
    are halved four times (16x overall). The result is then globally
    average-pooled and fed to a single ``Dense`` unit with a sigmoid
    activation.

    Args:
        start_filters: Number of filters in the first block; doubled per block.
        spatial_dim: Height and width of the square input image.
        filter_size: Kernel size passed to every ``Conv2D`` layer.
        channels: Number of input channels. Defaults to 3, which matches the
            previously hard-coded value.

    Returns:
        A Keras ``Model`` whose input has shape
        ``(spatial_dim, spatial_dim, channels)`` and whose output is a single
        sigmoid unit per sample.
    """

    def add_discriminator_block(x, filters, filter_size):
        """Apply Conv2D -> BN -> stride-2 Conv2D -> BN -> LeakyReLU to ``x``."""
        x = Conv2D(filters, filter_size, padding='same')(x)
        x = BatchNormalization()(x)
        # NOTE(review): there is no activation between the two convolutions;
        # the only non-linearity in the block is the LeakyReLU below.
        # Stride 2 halves height and width ('same' padding rounds up).
        x = Conv2D(filters, filter_size, padding='same', strides=2)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.3)(x)
        return x

    inp = Input(shape=(spatial_dim, spatial_dim, channels))

    # Four downsampling blocks, doubling the filter count each time; layers
    # are created in the same order as four explicit calls would create them.
    x = inp
    for multiplier in (1, 2, 4, 8):
        x = add_discriminator_block(x, start_filters * multiplier, filter_size)

    # Collapse the spatial dimensions and emit a single sigmoid output.
    x = GlobalAveragePooling2D()(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model(inputs=inp, outputs=x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.