Created
August 10, 2017 10:26
-
-
Save melgor/0e43cadf742fe3336148ab64dd63138f to your computer and use it in GitHub Desktop.
LinkNet implementation in TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.contrib.layers.python.layers import initializers | |
slim = tf.contrib.slim

# ============================================================================
# LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation
# ============================================================================
# Based on the paper: https://arxiv.org/pdf/1707.03718.pdf
#
# (Kept as a comment: a string literal placed after the imports is not the
# module docstring, it is just an unused expression.)

# TODO: net initialization -- no weights_initializer is passed anywhere in this
# file, so every layer falls back to slim's default initializer; the imported
# `initializers` module is currently unused.
@slim.add_arg_scope
def convBnRelu(input, num_channel, kernel_size, stride, is_training, scope, padding='SAME'):
    """Apply a square conv2d, then batch norm, then ReLU.

    Args:
        input: input feature tensor (4D; layout is whatever slim.conv2d
            defaults to -- presumably NHWC, confirm).
        num_channel: number of output channels of the convolution.
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        is_training: forwarded to ``slim.batch_norm``.
        scope: name prefix; sub-layers are ``<scope>_conv1``,
            ``<scope>_batchnorm1`` and ``<scope>_relu1``.
        padding: convolution padding, ``'SAME'`` unless overridden.

    Returns:
        The ReLU-activated output tensor.
    """
    kernel = [kernel_size, kernel_size]
    conv_out = slim.conv2d(input, num_channel, kernel,
                           stride=stride,
                           activation_fn=None,
                           scope=scope + '_conv1',
                           padding=padding)
    normed = slim.batch_norm(conv_out,
                             is_training=is_training,
                             fused=True,
                             scope=scope + '_batchnorm1')
    return tf.nn.relu(normed, name=scope + '_relu1')
@slim.add_arg_scope
def deconvBnRelu(input, num_channel, kernel_size, stride, is_training, scope, padding='VALID'):
    """Apply a square transposed convolution, then batch norm, then ReLU.

    Args:
        input: input feature tensor (4D).
        num_channel: number of output channels of the transposed convolution.
        kernel_size: side length of the square kernel.
        stride: upsampling stride of the transposed convolution.
        is_training: forwarded to ``slim.batch_norm``.
        scope: name prefix; sub-layers are ``<scope>_fullconv1``,
            ``<scope>_batchnorm1`` and ``<scope>_relu1``.
        padding: padding mode, ``'VALID'`` unless overridden.

    Returns:
        The ReLU-activated output tensor.
    """
    kernel = [kernel_size, kernel_size]
    upsampled = slim.conv2d_transpose(input, num_channel, kernel,
                                      stride=stride,
                                      activation_fn=None,
                                      scope=scope + '_fullconv1',
                                      padding=padding)
    normed = slim.batch_norm(upsampled,
                             is_training=is_training,
                             fused=True,
                             scope=scope + '_batchnorm1')
    return tf.nn.relu(normed, name=scope + '_relu1')
@slim.add_arg_scope
def initial_block(inputs, is_training=True, scope='initial_block', use_max_pool=False):
    '''
    Stem of LinkNet: 7x7 stride-2 conv (64 channels) -> batch norm -> ReLU,
    optionally followed by a 3x3 stride-2 max pool.

    The stages are applied sequentially (there are no parallel branches).
    With the default ``use_max_pool=False`` only the conv stage runs, so the
    output is downsampled 2x. ``LinkNet`` in this file is built around that 2x
    stem: its head ends with a stride-2 conv. Set ``use_max_pool=True`` to get
    the 4x-downsampling stem (conv + max pool). The output head must then be
    adjusted to restore the input resolution.

    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]
    - is_training(bool): forwarded to batch norm.
    - scope(str): name prefix for the layers.
    - use_max_pool(bool): also apply the 3x3 stride-2 max pool (default False).
    OUTPUTS:
    - net(Tensor): 4D tensor with 64 channels, spatially downsampled 2x
      (4x when use_max_pool=True).
    '''
    net = slim.conv2d(inputs, 64, [7, 7], stride=2, activation_fn=None,
                      scope=scope + '_conv')
    net = slim.batch_norm(net, is_training=is_training, fused=True,
                          scope=scope + '_batchnorm')
    net = tf.nn.relu(net, name=scope + '_relu')
    # Previously the pooled tensor was computed but never returned (dead op);
    # it is now built only when explicitly requested.
    if use_max_pool:
        net = slim.max_pool2d(net, [3, 3], stride=2, scope=scope + '_max_pool')
    return net
@slim.add_arg_scope
def residualBlock(input, n_filters, is_training, stride=1, downsample=None, scope='residualBlock'):
    '''
    ResNet basic block: conv3x3/BN/ReLU -> conv3x3/BN, add shortcut, ReLU.

    INPUTS:
    - input(Tensor): 4D input tensor.
    - n_filters(int): output channels of both 3x3 convolutions.
    - is_training(bool): forwarded to batch norm.
    - stride(int): stride of the first 3x3 convolution.
    - downsample(Tensor or None): a precomputed shortcut tensor (not a layer).
      When None, ``input`` itself is the shortcut. In that case it must match
      the residual branch's shape, i.e. stride 1 and n_filters channels.
    - scope(str): name prefix for the layers.
    OUTPUTS:
    - x(Tensor): ReLU(residual + shortcut).
    '''
    # Use an identity test: `downsample` is normally a Tensor, and `==` on a
    # Tensor is not a reliable None check across TensorFlow versions.
    shortcut = input if downsample is None else downsample

    # Residual branch.
    x = convBnRelu(input, n_filters, kernel_size=3, stride=stride,
                   is_training=is_training, scope=scope + '/cvbnrelu')
    x = slim.conv2d(x, n_filters, [3, 3], stride=1, activation_fn=None,
                    scope=scope + '_conv2', padding='SAME')
    x = slim.batch_norm(x, is_training=is_training, fused=True,
                        scope=scope + '_batchnorm2')

    # Shortcut connection.
    x = x + shortcut
    return tf.nn.relu(x, name=scope + '_relu2')
@slim.add_arg_scope
def encoder(inputs, inplanes, planes, blocks, stride, is_training=True, scope='encoder'):
    '''
    Encoder stage of LinkNet: a stack of ResNet basic blocks (ResNet18-style).

    The first block applies `stride` and, if needed, a projection shortcut.
    The remaining `blocks - 1` blocks use stride 1 and identity shortcuts.
    At least one block is always built, even if `blocks` < 1.

    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels];
      presumably `channels == inplanes` (only used for the comparison below).
    - inplanes(int): channel count of `inputs`.
    - planes(int): output channel count of every block in the stage.
    - blocks(int): number of residual blocks.
    - stride(int): stride of the first block.
    - is_training(bool): forwarded to batch norm.
    - scope(str): name prefix for the stage's layers.
    OUTPUTS:
    - output(Tensor): output of the last residual block, with `planes` channels.
    '''
    # The identity shortcut only works when the first block keeps both the
    # spatial size and the channel count. Otherwise project with 1x1 conv + BN.
    downsample = None
    if stride != 1 or inplanes != planes:
        downsample = slim.conv2d(inputs, planes, [1, 1], stride=stride, activation_fn=None,
                                 scope=scope + '_conv_downsample')
        downsample = slim.batch_norm(downsample, is_training=is_training, fused=True,
                                     scope=scope + '_batchnorm_downsample')

    # Stack the residual blocks.
    output = residualBlock(inputs, planes, is_training, stride, downsample,
                           scope=scope + '/residualBlock0')
    for i in range(1, blocks):
        output = residualBlock(output, planes, is_training, 1,
                               scope=scope + '/residualBlock{}'.format(i))
    return output
@slim.add_arg_scope
def decoder(inputs, n_filters, planes, is_training=True, scope='decoder'):
    '''
    Decoder stage of LinkNet. It runs three conv/BN/ReLU layers:
    1x1 conv (n_filters // 2 channels) -> 3x3 stride-2 transposed conv
    (n_filters // 2) -> 1x1 conv (planes).

    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels];
      presumably `channels == n_filters` -- confirm against callers.
    - n_filters(int): width used to derive the bottleneck (n_filters // 2).
    - planes(int): output channel count.
    - is_training(bool): forwarded to batch norm.
    - scope(str): name prefix for the layers.
    OUTPUTS:
    - x(Tensor): tensor with `planes` channels, upsampled 2x spatially by the
      stride-2 'SAME' transposed conv.
    '''
    # Integer division: under Python 3, `n_filters / 2` is a float, but TF
    # layers expect integer channel counts. Under Python 2 `//` is the same as
    # the old `/` on ints.
    bottleneck = int(n_filters) // 2
    x = convBnRelu(inputs, bottleneck, kernel_size=1, stride=1, is_training=is_training,
                   padding='SAME', scope=scope + "/c1")
    x = deconvBnRelu(x, bottleneck, kernel_size=3, stride=2, is_training=is_training,
                     padding='SAME', scope=scope + "/dc1")
    x = convBnRelu(x, planes, kernel_size=1, stride=1, is_training=is_training,
                   padding='SAME', scope=scope + "/c2")
    return x
# Now actually start building the network
def LinkNet(inputs,
            num_classes,
            reuse=None,
            is_training=True,
            feature_scale=4,
            scope='LinkNet'):
    '''
    The LinkNet model for real-time semantic segmentation.

    INPUTS:
    - inputs(Tensor): a 4D Tensor of shape [batch_size, image_height, image_width, num_channels]
      that represents one batch of preprocessed images. Height and width should be
      multiples of 16 so that the decoder outputs line up with the encoder skip tensors.
    - num_classes(int): number of classes to predict (output channels).
    - reuse(bool): whether to reuse the variables, e.g. for evaluation.
    - is_training(bool): if True, batch norm uses batch statistics (training mode).
    - feature_scale(int): divisor applied to every channel width (64/128/256/512
      for the encoder and 32 for the head), giving a thinner network.
    - scope(str): a string that represents the scope name for the variables.
    OUTPUTS:
    - logits(Tensor): raw, unnormalized class scores of shape
      [batch_size, image_height, image_width, num_classes] (no softmax applied).
    '''
    layers = [2, 2, 2, 2]
    # Integer arithmetic keeps channel counts ints under Python 3, where `/`
    # would produce floats that TF layers reject.
    filters = [int(width // feature_scale) for width in [64, 128, 256, 512]]
    head_channels = int(32 // feature_scale)

    with tf.variable_scope(scope, reuse=reuse):
        # Set the primary arg scopes. Fused batch_norm is faster than normal batch norm.
        with slim.arg_scope([initial_block, encoder], is_training=is_training), \
                slim.arg_scope([slim.batch_norm], fused=True), \
                slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=None):
            # ================= INITIAL BLOCK =================
            # 64 channels, 2x spatial downsampling (no max pool is applied).
            net = initial_block(inputs, scope='initial_block')

            # ===================== Encoder ===================
            enc1 = encoder(net, 64, filters[0], layers[0], stride=1,
                           is_training=is_training, scope='encoder1')
            enc2 = encoder(enc1, filters[0], filters[1], layers[1], stride=2,
                           is_training=is_training, scope='encoder2')
            enc3 = encoder(enc2, filters[1], filters[2], layers[2], stride=2,
                           is_training=is_training, scope='encoder3')
            enc4 = encoder(enc3, filters[2], filters[3], layers[3], stride=2,
                           is_training=is_training, scope='encoder4')

            # ===================== Decoder ===================
            # Each decoder upsamples 2x; its output is added to the encoder
            # output of matching resolution (LinkNet's additive skip links).
            decoder4 = decoder(enc4, filters[3], filters[2], is_training=is_training, scope='decoder4')
            decoder4 += enc3
            decoder3 = decoder(decoder4, filters[2], filters[1], is_training=is_training, scope='decoder3')
            decoder3 += enc2
            decoder2 = decoder(decoder3, filters[1], filters[0], is_training=is_training, scope='decoder2')
            decoder2 += enc1
            decoder1 = decoder(decoder2, filters[0], filters[0], is_training=is_training, scope='decoder1')

            # =============== Final Classification ============
            # f1 upsamples 2x (to twice the input resolution). The stride-2
            # `logits` conv then brings it back to the input resolution.
            # NOTE(review): the paper's head reportedly ends with a 2x2 stride-2
            # *transposed* conv, and its stem includes a max pool. This code
            # drops both, and the two changes cancel out. Change both together
            # or neither.
            f1 = deconvBnRelu(decoder1, head_channels, 3, stride=2,
                              is_training=is_training, scope='f1', padding='SAME')
            f2 = convBnRelu(f1, head_channels, 3, stride=1,
                            is_training=is_training, padding='SAME', scope='f2')
            logits = slim.conv2d(f2, num_classes, [2, 2], stride=2, activation_fn=None,
                                 padding='SAME', scope='logits')
    return logits
def LinkNet_arg_scope(weight_decay=2e-4,
                      batch_norm_decay=0.1,
                      batch_norm_epsilon=0.001):
    '''
    Build the default slim arg_scope for the LinkNet model.

    INPUTS:
    - weight_decay(float): L2 regularization strength for the weights and biases of
      every `slim.conv2d` layer. `slim.conv2d_transpose` layers are not covered by
      this scope, so they are not regularized.
    - batch_norm_decay(float): passed as `decay` to `slim.batch_norm` (the weight
      kept on the old moving average).
    - batch_norm_epsilon(float): small float added to the variance to avoid dividing by zero.
    OUTPUTS:
    - scope(arg_scope): a tf-slim arg_scope to re-enter with `with slim.arg_scope(scope):`.

    NOTE(review): TF's batch-norm `decay` is the complement of Torch's `momentum`.
    Torch `momentum=0.1` corresponds to TF `decay=0.9`, whereas `decay=0.1` makes the
    moving statistics track almost only the latest batch. The default is kept
    unchanged for backward compatibility -- confirm it is intended.
    '''
    with slim.arg_scope([slim.conv2d],
                        weights_regularizer=slim.l2_regularizer(weight_decay),
                        biases_regularizer=slim.l2_regularizer(weight_decay)), \
            slim.arg_scope([slim.batch_norm],
                           decay=batch_norm_decay,
                           epsilon=batch_norm_epsilon) as scope:
        return scope
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment