Created May 18, 2016 14:24
-
-
Save igul222/765668b05b6cf20e0ebb522959d52a99 to your computer and use it in GitHub Desktop.
1D masked convolutions, à la PixelRNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def Conv1D(name, input_dim, output_dim, filter_size, inputs, mask_type=None, he_init=False):
    """Optionally masked 1D convolution, a la PixelRNN.

    Implemented as a 2D convolution over inputs whose height axis is 1.

    Args:
        name: prefix for the parameters registered via ``lib.param``
            (``name + '.Filters'`` and ``name + '.Biases'``).
        input_dim: number of input channels.
        output_dim: number of output channels.
        filter_size: kernel width. Should be odd. With ``border_mode='half'``,
            an even width pads ``filter_size // 2`` on both sides. The output
            is then one step wider than the input, and the mask centre no
            longer lines up with the current position.
        inputs: symbolic tensor of shape (batch size, input_dim, 1, width).
        mask_type: None (no mask), 'a' (each output sees only earlier
            positions), or 'b' (earlier positions and the current one).
        he_init: if True, the init stdev is sqrt(2 / fan_in) (He init);
            otherwise sqrt(1 / fan_in).

    Returns:
        Symbolic tensor of shape (batch size, output_dim, 1, width) for odd
        ``filter_size``.

    Raises:
        ValueError: if ``mask_type`` is not None, 'a' or 'b', or if the
            computed fan-in is zero (e.g. mask 'a' with ``filter_size == 1``,
            which masks every tap).
    """
    if mask_type not in (None, 'a', 'b'):
        raise ValueError(
            "mask_type must be None, 'a' or 'b', got {!r}".format(mask_type)
        )

    if mask_type is not None:
        mask = numpy.ones(
            (output_dim, input_dim, 1, filter_size),
            dtype=theano.config.floatX
        )
        center = filter_size // 2
        # With filter_flip=False the conv is a cross-correlation, so tap k
        # reads position i - center + k. Zeroing taps > center hides the
        # positions to the right of the current one.
        mask[:, :, 0, center+1:] = 0.
        if mask_type == 'a':
            # Type 'a' additionally hides the current position itself.
            mask[:, :, 0, center] = 0.

    def uniform(stdev, size):
        """uniform distribution with the given stdev and size"""
        return numpy.random.uniform(
            low=-stdev * numpy.sqrt(3),
            high=stdev * numpy.sqrt(3),
            size=size
        ).astype(theano.config.floatX)

    if mask_type is not None:
        # Fan-in is counted per output unit. Every output channel shares the
        # same mask, so count the unmasked taps of a single output slice.
        # Summing the whole mask would overcount by a factor of output_dim
        # and shrink the init stdev by sqrt(output_dim).
        n_in = numpy.sum(mask[0])
    else:
        n_in = input_dim * filter_size

    if n_in == 0:
        # Otherwise the stdev below is infinite (or a ZeroDivisionError).
        raise ValueError(
            "Conv1D {!r} has zero fan-in (input_dim={}, filter_size={}, "
            "mask_type={!r})".format(name, input_dim, filter_size, mask_type)
        )

    if he_init:
        init_stdev = numpy.sqrt(2. / n_in)
    else:
        init_stdev = numpy.sqrt(1. / n_in)

    filters = lib.param(
        name + '.Filters',
        uniform(
            init_stdev,
            (output_dim, input_dim, 1, filter_size)
        )
    )

    if mask_type is not None:
        # Masking the symbolic expression (not just the initial value) keeps
        # the masked taps at zero in the forward pass throughout training.
        filters = filters * mask

    # TODO benchmark against the lasagne 'conv1d' implementations
    result = T.nnet.conv2d(inputs, filters, border_mode='half', filter_flip=False)

    biases = lib.param(
        name + '.Biases',
        numpy.zeros(output_dim, dtype=theano.config.floatX)
    )
    # Broadcast the per-channel bias over batch, height and width.
    result += biases[None, :, None, None]

    return result
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.