Created January 6, 2021 14:15
-
-
Save dneprDroid/1ba4ef91ad533e32594f882e639a37f1 to your computer and use it in GitHub Desktop.
Convolution test - torch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn'
-- Make torch.Tensor(...) allocate 32-bit float tensors for the test below.
torch.setdefaulttensortype('torch.FloatTensor')
--- Run a 3x3 SpatialConvolutionMM with constant weights, bias and input, then
-- print the output shape and the first values of the flattened output.
-- Weights are filled with 2.2, the bias with 1.2 and the input with 1.3, so the
-- printed numbers can be compared value by value against another implementation.
-- @tparam[opt=64] number channels input and output plane count
-- @tparam[opt=256] number size height and width of the square input
-- @tparam[opt=30] number num_values how many leading output values to print
function test_conv(channels, size, num_values)
  channels = channels or 64
  size = size or 256
  num_values = num_values or 30

  local kernel_size = 3
  local stride = 1
  local padding = 1
  local layer = nn.SpatialConvolutionMM(channels, channels,
                                        kernel_size, kernel_size,
                                        stride, stride, padding, padding)
  layer.weight:fill(2.2) -- fill weights with 2.2
  layer.bias:fill(1.2)   -- fill bias with 1.2

  local tensor = torch.Tensor(1, channels, size, size)
  tensor:fill(1.3)       -- fill tensor with 1.3
  -- print(tensor)

  local result = layer:forward(tensor)

  -- Passing result:size() (a LongStorage userdata) straight to %s either
  -- errors (Lua 5.1) or expands its multi-line __tostring dump, depending on
  -- the Lua version; build a flat "1, 64, 256, 256" string instead.
  local dims = {}
  for d = 1, result:dim() do
    dims[d] = result:size(d)
  end
  print(string.format("result: shape={ %s }, type='%s'\n",
                      table.concat(dims, ", "), result:type()))

  local result_flatten = result:view(result:nElement())
  -- Clamp to the element count so small outputs are not indexed past the end.
  for i = 1, math.min(num_values, result_flatten:nElement()) do
    print(string.format("[%d] %f", i, result_flatten[i]))
  end
end
-- Run the test with its default sizes (64 channels, 256x256 input, 30 values).
test_conv()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn

# Default to 32-bit floats, matching the Lua script's 'torch.FloatTensor'.
# torch.set_default_tensor_type is deprecated; set_default_dtype is its replacement.
torch.set_default_dtype(torch.float32)
def test_conv(): | |
layer = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True) | |
layer.weight.data.fill_(2.2) | |
layer.bias.data.fill_(1.2) | |
tensor = torch.zeros((1, 64, 256, 256)) | |
tensor.fill_(1.3) | |
# print(tensor) | |
result = layer(tensor) | |
print("[test_conv] result: shape={ %s }, type='%s'\n" % (result.shape, result.type()) ) | |
result_flatten = result.flatten() | |
i = 0 | |
for n in result_flatten: | |
if i >= 30: break | |
number = n.item() | |
print("[%d] %f" % (i+1, number) ) | |
i += 1 | |
# Only run the (fairly heavy) 64x256x256 convolution when executed as a script.
if __name__ == "__main__":
    test_conv()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.