Skip to content

Instantly share code, notes, and snippets.

@ardamavi
Created July 4, 2018 19:24
Show Gist options
  • Save ardamavi/b597e6a7e616819a1adce7c1138b5f19 to your computer and use it in GitHub Desktop.
Save ardamavi/b597e6a7e616819a1adce7c1138b5f19 to your computer and use it in GitHub Desktop.
3D U-Net Model
def _double_conv3d(x, filters):
    """Apply two 3x3x3 'same'-padded Conv3D layers, each followed by ReLU."""
    for _ in range(2):
        x = Conv3D(filters, (3, 3, 3), strides=(1, 1, 1), padding='same')(x)
        x = Activation('relu')(x)
    return x


def _up_and_merge(x, skip, filters):
    """Upsample ``x`` 2x per spatial axis, convolve it, and concatenate ``[skip, x]``."""
    x = UpSampling3D((2, 2, 2))(x)
    # NOTE(review): this "up-convolution" has no activation (linear), exactly
    # as in the original code -- confirm that is intended.
    x = Conv3D(filters, (3, 3, 3), strides=(1, 1, 1), padding='same')(x)
    return concatenate([skip, x])


def get_3d_u_net(data_shape):
    """Build and compile a 3D U-Net segmentation model.

    The network has a four-level contracting path (two 3x3x3 Conv3D + ReLU
    per level, then 2x2x2 max pooling), a two-conv bottleneck, and a
    four-level expanding path (2x upsampling, a 3x3x3 Conv3D, concatenation
    with the matching encoder features, then two 3x3x3 Conv3D + ReLU). A
    final 1x1x1 Conv3D with sigmoid activation produces the output.

    Args:
        data_shape: Shape of one input sample, excluding the batch axis;
            passed straight to ``Input(shape=...)``. Its last entry is used
            as the number of output channels, so this presumably assumes
            channels-last data such as ``(depth, height, width, channels)``
            -- TODO confirm. Because four 2x2x2 poolings (default 'valid'
            padding) are mirrored by four 2x upsamplings, every spatial
            dimension must be a multiple of 16 for the skip concatenations
            to line up.

    Returns:
        The Keras model compiled with the Adadelta optimizer,
        ``dice_coefficient_loss`` and the ``dice_coefficient`` metric. If
        ``multi_gpu_model`` succeeds, the returned model is its multi-GPU
        wrapper; otherwise it is the plain single-device model.
    """
    import logging

    inputs = Input(shape=data_shape)

    # Contracting path: keep each level's pre-pooling features for the skips.
    x = inputs
    skips = []
    for filters in (32, 64, 128, 256):
        x = _double_conv3d(x, filters)
        skips.append(x)
        x = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))(x)

    # Bottleneck.
    x = _double_conv3d(x, 256)

    # Expanding path, deepest skip first. Each pair is
    # (up-convolution filters, double-conv filters).
    # NOTE(review): these counts do not follow the usual halve-per-level U-Net
    # pattern (a 512-filter up-conv after a 256-filter bottleneck, 64 filters
    # at the last level). They are kept unchanged so the architecture and any
    # saved weights stay compatible -- confirm intent.
    decoder_filters = ((512, 256), (256, 128), (128, 64), (32, 64))
    for skip, (up_filters, conv_filters) in zip(reversed(skips), decoder_filters):
        x = _up_and_merge(x, skip, up_filters)
        x = _double_conv3d(x, conv_filters)

    # 1x1x1 projection back to the input's channel count, squashed to (0, 1).
    outputs = Conv3D(data_shape[-1], (1, 1, 1), strides=(1, 1, 1), padding='same')(x)
    outputs = Activation('sigmoid')(outputs)
    model = Model(inputs=inputs, outputs=outputs)

    # Best effort: use data-parallel multi-GPU training when possible. It is
    # expected to fail on single-GPU/CPU setups or when multi_gpu_model is
    # unavailable. Fall back to the single-device model, but do not swallow
    # KeyboardInterrupt/SystemExit the way a bare ``except:`` would.
    try:
        model = multi_gpu_model(model)
    except Exception as err:
        logging.getLogger(__name__).warning(
            "multi_gpu_model unavailable, using single-device model: %s", err)

    model.compile(optimizer='adadelta', loss=dice_coefficient_loss, metrics=[dice_coefficient])
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment