Created
September 2, 2016 19:32
-
-
Save culurciello/1955aa1d6ea8381a9da908f9f3d13228 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Eugenio Culurciello
-- August 2016
-- a test to learn to code PredNet-like nets in Torch7
require 'nn'
require 'nngraph'

-- Use float tensors by default and switch on nngraph's debug mode.
torch.setdefaulttensortype('torch.FloatTensor')
nngraph.setDebug(true)

-- Number of stacked layers in the test graph.
local nlayers = 4

-- Graph input nodes: slot 1 is the input image x, slots 2..nlayers+1 carry
-- the previous output D of each layer (nlayers + 1 nodes in total).
local inputs = {}
-- Graph output nodes; the layer-construction loop appends D then Df per layer.
local outputs = {}

inputs[1] = nn.Identity()() -- input image x
for _ = 1, nlayers do
   inputs[#inputs + 1] = nn.Identity()() -- previous output D
end
-- Build the nodes of every layer. Layer L reads a "below" node (the input
-- image for layer 1, otherwise the D node of layer L-1) together with the
-- fed-back input inputs[L+1], and appends two nodes to `outputs`:
--   D  = 2 * below                     -> outputs[2*L-1]
--   Df = below - 0.5 * inputs[L+1]     -> outputs[2*L]
for L = 1, nlayers do
   print('Creating layer-test:', L)

   -- Both D and Df consume the same source node, so pick it once.
   -- outputs[2*L-3] is the D node appended by the previous layer.
   local below
   if L == 1 then
      below = inputs[1]
   else
      below = outputs[2*L-3]
   end

   -- D: the layer output (drawn in green).
   local D = nn.MulConstant(2)(below)
   D:annotate{graphAttributes = {color = 'green', fontcolor = 'green'}}

   -- G: scaled copy of the fed-back input for this layer.
   local G = nn.MulConstant(0.5)(inputs[L+1])

   -- Df: output difference, below - G (drawn in blue).
   local Df = nn.CSubTable()({below, G})
   Df:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

   outputs[#outputs + 1] = D
   outputs[#outputs + 1] = Df
end
-- create graph
print('Creating model-test:')
nngraph.annotateNodes()
-- Wrap the node tables into one graph module: it takes nlayers + 1 inputs
-- (image plus one fed-back tensor per layer) and yields 2 * nlayers outputs
-- (D and Df for each layer, interleaved).
local model = nn.gModule(inputs, outputs)
-- test: run the model over a short sequence, feeding each layer's D output
-- from one step back in as the "previous output" input of the next step.
print('Testing model-test:')
local c = require 'trepl.colorize'
local nT = 3 -- time sequence length

-- Concatenate a list of tensors along dimension 2 so that they print
-- side by side as one wide tensor.
local function sideBySide(tensors)
   local joined = tensors[1]
   for j = 2, #tensors do
      joined = torch.cat(joined, tensors[j], 2)
   end
   return joined
end

-- Initial "previous outputs": all zeros, one 2x2 tensor per graph output.
local outTable = {}
for _ = 1, nlayers * 2 do
   outTable[#outTable + 1] = torch.zeros(2, 2)
end

-- Input sequence: nT frames of ones.
local x = {} -- size nT
for _ = 1, nT do
   x[#x + 1] = torch.ones(2, 2)
end

for i = 1, nT do
   -- `local` here: the original assigned a global by accident.
   local inTable = {x[i]} -- size (nlayers + 1)
   for j = 1, nlayers do
      -- outTable[2*j-1] is layer j's D from the previous step. After the
      -- first iteration these tensors come straight from model:forward and
      -- appear to be the modules' own output buffers, which the next forward
      -- rewrites; clone so the fed-back values stay those of the previous step.
      inTable[#inTable + 1] = outTable[2*j - 1]:clone()
   end

   print(c.red('Input of iteration '.. i ..' is:'))
   print(sideBySide(inTable))

   outTable = model:forward(inTable) -- size 2*nlayers

   print(c.cyan('Output of iteration '.. i ..' is: '))
   print(sideBySide(outTable))
end

graph.dot(model.fg, 'test','Model-test') -- graph the model!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment