Skip to content

Instantly share code, notes, and snippets.

@Warvito
Last active March 7, 2020 16:39
Show Gist options
  • Save Warvito/ba067d3d6eefed3ae3259cafa7ce689c to your computer and use it in GitHub Desktop.
Save Warvito/ba067d3d6eefed3ae3259cafa7ce689c to your computer and use it in GitHub Desktop.
PixelCNN Architecture
class ResidualBlock(keras.Model):
    """Residual block that is stacked to build the PixelCNN.

    Follows Figure 5 of [1], where ``h`` is the bottleneck width. The block
    applies three (ReLU -> convolution) stages:

      1. a 1x1 convolution down to ``h`` filters,
      2. a 3x3 masked convolution (mask type 'B') with ``h`` filters,
      3. a 1x1 convolution back up to ``2 * h`` filters,

    and then adds the original input to the result (residual connection).

    Args:
        h: Number of filters in the bottleneck. The block output has
            ``2 * h`` channels. The input must therefore also have ``2 * h``
            channels for the residual addition to be valid.

    Refs:
        [1] - Oord, A. V. D., Kalchbrenner, N., & Kavukcuoglu, K. (2016). Pixel
        recurrent neural networks. arXiv preprint arXiv:1601.06759.
    """

    def __init__(self, h):
        super().__init__(name='')
        # NOTE: attribute names are kept as-is on purpose; renaming them may
        # affect compatibility with previously saved weights.
        self.conv2a = keras.layers.Conv2D(filters=h, kernel_size=1, strides=1)
        # MaskedConv2D is defined elsewhere; presumably mask 'B' lets the
        # kernel see the centre pixel (as in [1]) -- TODO confirm.
        self.conv2b = MaskedConv2D(mask_type='B', filters=h, kernel_size=3, strides=1)
        self.conv2c = keras.layers.Conv2D(filters=2 * h, kernel_size=1, strides=1)

    def call(self, input_tensor):
        out = input_tensor
        # Pre-activation ordering: every convolution consumes the ReLU of the
        # previous stage's output (the first stage uses the block input).
        for conv in (self.conv2a, self.conv2b, self.conv2c):
            out = conv(nn.relu(out))
        # Skip connection. Shapes only match if MaskedConv2D preserves the
        # spatial size (presumably 'same' padding -- TODO confirm).
        return out + input_tensor
# Create PixelCNN model.
#
# Structure: a 7x7 masked convolution (mask 'A'), a stack of 15 residual
# blocks, then (ReLU -> 1x1 conv) output layers. The final layer is linear
# with q_levels filters; presumably a softmax / cross-entropy over those
# q_levels values is applied elsewhere -- TODO confirm.
# height, width, n_channel and q_levels are not defined in this snippet; they
# must be defined before this point.
# NOTE(review): the output has q_levels channels in total, not q_levels per
# input channel, so this looks intended for n_channel == 1 -- confirm.
inputs = keras.layers.Input(shape=(height, width, n_channel))
# 128 filters must equal 2 * h of the residual blocks below (h=64): each
# block adds its input to an output with 2 * h channels.
x = MaskedConv2D(mask_type='A', filters=128, kernel_size=7, strides=1)(inputs)
for _ in range(15):
    x = ResidualBlock(h=64)(x)
x = keras.layers.Activation(activation='relu')(x)
x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)
x = keras.layers.Activation(activation='relu')(x)
x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)
# Nonlinearity between the last two 1x1 convs. Without it they compose into
# a single linear map and the previous conv adds parameters but no capacity.
x = keras.layers.Activation(activation='relu')(x)
x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)
pixelcnn = keras.Model(inputs=inputs, outputs=x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment