Skip to content

Instantly share code, notes, and snippets.

@Warvito
Last active April 19, 2022 10:52
Show Gist options
  • Save Warvito/32e7f7499eebc8624438e1f488ee8c9c to your computer and use it in GitHub Desktop.
Save Warvito/32e7f7499eebc8624438e1f488ee8c9c to your computer and use it in GitHub Desktop.
Convolutional layers with masks: a simple implementation of convolutional layers with type A and type B masks for autoregressive models.
class MaskedConv2D(keras.layers.Layer):
    """2D convolution with an autoregressive (PixelCNN-style) kernel mask.

    The kernel is multiplied element-wise by a fixed binary mask before every
    convolution. The mask zeroes every kernel row below the centre row and,
    within the centre row, every column to the right of the centre. A type
    `"A"` mask additionally zeroes the centre position itself, while a type
    `"B"` mask keeps it.

    Arguments:
        mask_type: one of `"A"` or `"B"`.
        filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the convolution).
        kernel_size: An integer or tuple/list of 2 integers, specifying the
            height and width of the 2D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 2 integers,
            specifying the strides of the convolution along the height and
            width. Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: one of `"valid"` or `"same"` (case-insensitive).
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        **kwargs: Extra keyword arguments forwarded unchanged to the base
            layer constructor.

    Raises:
        ValueError: if `mask_type` is not `"A"` or `"B"`, or (when the layer
            is built or called) if `kernel_size` / `strides` is a sequence
            whose length is not 2.
    """

    def __init__(self,
                 mask_type,
                 filters,
                 kernel_size,
                 strides=1,
                 padding='same',
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        super(MaskedConv2D, self).__init__(**kwargs)

        # Explicit check instead of `assert`, which is stripped under `-O`.
        if mask_type not in {'A', 'B'}:
            raise ValueError(
                "mask_type must be 'A' or 'B', got {!r}".format(mask_type))

        self.mask_type = mask_type
        self.filters = filters
        # Stored exactly as given; normalised to (height, width) where used.
        self.kernel_size = kernel_size
        self.strides = strides
        # The convolution op below is given the upper-case spelling.
        self.padding = padding.upper()
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

    @staticmethod
    def _as_pair(value, name):
        """Return `value` as a `(height, width)` tuple of two ints."""
        try:
            pair = tuple(value)
        except TypeError:
            # Not iterable: one scalar applies to both spatial dimensions.
            size = int(value)
            return size, size
        if len(pair) != 2:
            raise ValueError(
                '{} must be an integer or a tuple/list of 2 integers, '
                'got {!r}'.format(name, value))
        return int(pair[0]), int(pair[1])

    def build(self, input_shape):
        """Create the kernel and bias weights and the constant kernel mask."""
        kernel_h, kernel_w = self._as_pair(self.kernel_size, 'kernel_size')
        kernel_shape = (kernel_h,
                        kernel_w,
                        int(input_shape[-1]),
                        self.filters)

        # `name` is passed by keyword because the position of the first
        # positional parameter of `add_weight` differs between Keras versions.
        self.kernel = self.add_weight(name='kernel',
                                      shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.filters,),
                                    initializer=self.bias_initializer,
                                    trainable=True)

        center_row = kernel_h // 2
        center_col = kernel_w // 2
        # Type 'A' starts masking at the centre column (centre excluded);
        # type 'B' starts one column later (centre kept).
        first_masked_col = center_col + (1 if self.mask_type == 'B' else 0)

        mask = np.ones(kernel_shape, dtype=np.float32)
        mask[center_row, first_masked_col:, :, :] = 0.
        mask[center_row + 1:, :, :, :] = 0.
        self.mask = tf.constant(mask, dtype=tf.float32, name='mask')

    def call(self, input):
        """Convolve `input` with the masked kernel and add the bias."""
        stride_h, stride_w = self._as_pair(self.strides, 'strides')
        # Re-apply the mask on every call so masked weights never contribute,
        # whatever values the trainable kernel currently holds.
        masked_kernel = tf.math.multiply(self.mask, self.kernel)
        x = nn.conv2d(input,
                      masked_kernel,
                      strides=[1, stride_h, stride_w, 1],
                      padding=self.padding)
        x = nn.bias_add(x, self.bias)
        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super(MaskedConv2D, self).get_config()
        config.update({
            'mask_type': self.mask_type,
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
            'bias_initializer':
                initializers.serialize(self.bias_initializer),
        })
        return config
@ccczzzmmm
Copy link

Can you give the complete PixelCNN code, please?

@Warvito
Copy link
Author

Warvito commented Mar 24, 2020

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment