Linear layer for tensors of any shape in PyTorch
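The module below generalizes `nn.Linear`: instead of a matrix, the weight is a tensor whose leading modes match the input sample shape and whose trailing modes match the output sample shape, and the layer contracts the two with `torch.tensordot`. A minimal sketch of that core operation, with illustrative shapes that are not taken from the gist:

import torch

x = torch.randn(8, 2, 3)       # a batch of 8 samples, each of shape (2, 3)
W = torch.randn(2, 3, 4, 5)    # maps samples of shape (2, 3) to samples of shape (4, 5)
b = torch.zeros(4, 5)
# contract the last two modes of x with the first two modes of W, then add the bias
y = torch.tensordot(x, W, dims=([1, 2], [0, 1])) + b
assert y.shape == (8, 4, 5)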
import torch
from torch import nn
from torch.nn.init import xavier_normal_


class ShapyLinear(nn.Module):
    """Can model any affine function from the set of tensors of any (fixed) shape `in_shape` to
    the set of tensors of any (fixed) shape `out_shape`.

    In the forward method the first modes of `inputs` are interpreted as indices of samples,
    then come the modes corresponding to `in_shape`. The affine function is applied to each sample."""

    def __init__(self, in_shape, out_shape):
        """:param in_shape: shape of one input sample
        :param out_shape: shape of one output sample"""
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        self.input_num_modes = len(in_shape)
        # the weight has shape in_shape + out_shape; Xavier initialization is applied in place
        weight_data = torch.zeros(*in_shape, *out_shape, requires_grad=True)
        xavier_normal_(weight_data)
        self.weight = nn.Parameter(weight_data)
        self.bias = nn.Parameter(torch.zeros(*out_shape, requires_grad=True))

    @property
    def weight_contraction_modes(self):
        """Indices of modes of `self.weight` over which tensor contraction with the input is performed."""
        return tuple(range(self.input_num_modes))

    def forward(self, inputs):
        # calculate how many modes of `inputs` represent indices of samples
        num_sample_modes = inputs.ndimension() - self.input_num_modes
        # calculate over which modes of `inputs` we perform tensor contraction
        inputs_contraction_modes = tuple(range(num_sample_modes, inputs.ndimension()))
        # contract the trailing modes of `inputs` with the leading modes of `self.weight`,
        # then add the bias, which broadcasts over the sample modes
        contracted = torch.tensordot(inputs, self.weight,
                                     dims=(inputs_contraction_modes, self.weight_contraction_modes))
        return contracted + self.bias

    def __repr__(self):
        return f"ShapyLinear(input_num_modes={self.input_num_modes}, weight.shape={tuple(self.weight.shape)})"


def test_shapy_linear():
    """Check that it calculates the same thing as `torch.nn.Linear`, except reshaped."""
    shapy = ShapyLinear((2, 3), (4, 5))
    lin = nn.Linear(2 * 3, 4 * 5)
    # nn.Linear stores its weight as (out_features, in_features), so move the output modes
    # in front of the input modes before flattening
    lin.weight.data = shapy.weight.permute(2, 3, 0, 1).reshape(4 * 5, 2 * 3)
    lin.bias.data = shapy.bias.reshape(4 * 5)
    X = torch.randn(6, 7, 2, 3)
    result_lin = lin(X.reshape(6 * 7, 2 * 3)).reshape(6, 7, 4, 5)
    result_shapy = shapy(X)
    # allow a small tolerance for float32 rounding differences between the two code paths
    assert (result_shapy - result_lin).abs().max().item() < 1e-6
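A hedged usage sketch, assuming the class defined above (the batch sizes and variable names here are illustrative, not from the gist): the layer accepts any number of leading sample modes and applies the same affine map to every sample, replacing the usual flatten / `nn.Linear` / reshape pattern.

layer = ShapyLinear(in_shape=(2, 3), out_shape=(4, 5))
X = torch.randn(16, 2, 3)            # a batch of 16 samples, each shaped (2, 3)
Y = layer(X)                         # Y.shape == (16, 4, 5)
Z = layer(torch.randn(6, 7, 2, 3))   # sample modes may be multi-dimensional: Z.shape == (6, 7, 4, 5)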