Last active
December 6, 2017 15:09
-
-
Save phizaz/293bed11db89087bfc5afe08f9ee7ede to your computer and use it in GitHub Desktop.
Keras 2D Conv Layer without Conv2D
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MyConv(Layer):
    """2D convolution implemented with slicing and matrix products (no Conv2D).

    Computes the same result as a bias-free Conv2D with strides=1 and 'valid'
    padding on channels-last input. Every ``k_h x k_w x c`` window of the input
    is flattened and multiplied by the kernel reshaped to
    ``[k_h * k_w * c, filters]``.

    Args:
        filters: number of output channels.
        kernel: ``(k_h, k_w)`` kernel size, or a single int for a square kernel.
        **kwargs: forwarded to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, filters, kernel, **kwargs):
        # Initialise the base Layer before setting attributes on it.
        super(MyConv, self).__init__(**kwargs)
        self.filters = int(filters)
        if isinstance(kernel, int):
            kernel = (kernel, kernel)
        self.k_h, self.k_w = (int(k) for k in kernel)

    def build(self, input_shape):
        """Create the kernel weight and precompute output sizes.

        Expects ``input_shape`` of the form ``[batch, height, width, channels]``
        with height, width and channels known statically (the patch loop in
        ``call`` is unrolled over the spatial output positions).

        Raises:
            ValueError: if the input is not 4D, has unknown spatial/channel
                dimensions, or is smaller than the kernel.
        """
        # TF1 TensorShape yields Dimension objects; unwrap them to int/None.
        dims = [getattr(d, 'value', d) for d in input_shape]
        if len(dims) != 4:
            raise ValueError('MyConv expects a 4D [n, h, w, c] input, got '
                             'input_shape=%s' % (input_shape,))
        _, h, w, c = dims
        if h is None or w is None or c is None:
            raise ValueError('MyConv needs static height, width and channels, '
                             'got input_shape=%s' % (input_shape,))
        self.h, self.w, self.c = int(h), int(w), int(c)

        # 'valid' padding, stride 1.
        self.out_h = self.h - self.k_h + 1
        self.out_w = self.w - self.k_w + 1
        if self.out_h < 1 or self.out_w < 1:
            raise ValueError('kernel %s is larger than the input spatial size %s'
                             % ((self.k_h, self.k_w), (self.h, self.w)))

        # Number of input values covered by one kernel window.
        self.kernel_size = self.k_h * self.k_w * self.c
        self.kernels = self.add_weight(name='kernel',
                                       shape=[self.k_h, self.k_w,
                                              self.c, self.filters],
                                       initializer='glorot_uniform',
                                       trainable=True)
        super(MyConv, self).build(input_shape)

    def call(self, x):
        """Convolve ``x`` (``[n, h, w, c]``) into ``[n, out_h, out_w, filters]``."""
        # [k_h, k_w, c_in, c_out] -> [k_h * k_w * c_in, c_out]. This row-major
        # flattening matches the flattening of each patch below.
        kernel = K.reshape(self.kernels, [self.kernel_size, self.filters])

        responses = []
        for i in range(self.out_h):
            for j in range(self.out_w):
                # One window [n, k_h, k_w, c] -> [n, k_h * k_w * c].
                patch = x[:, i:i + self.k_h, j:j + self.k_w, :]
                patch = K.reshape(patch, [-1, self.kernel_size])
                # [n, kernel_size] x [kernel_size, filters] -> [n, filters]
                responses.append(K.dot(patch, kernel))

        # out_h * out_w tensors of [n, filters] -> [n, out_h * out_w, filters].
        # Positions were appended row by row (i outer, j inner), so this
        # reshape restores the spatial layout.
        stacked = K.stack(responses, axis=1)
        return K.reshape(stacked, [-1, self.out_h, self.out_w, self.filters])

    def compute_output_shape(self, input_shape):
        """Return ``(batch, out_h, out_w, filters)``, keeping the input batch dim."""
        return (input_shape[0], self.out_h, self.out_w, self.filters)

    def get_config(self):
        """Serialise constructor arguments so saved models can be reloaded."""
        config = super(MyConv, self).get_config()
        config.update({'filters': self.filters,
                       'kernel': (self.k_h, self.k_w)})
        return config
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment