Created November 1, 2016 16:25
-
-
Save fmassa/e44159d86340a8a4ee7adfbc0fba60ed to your computer and use it in GitHub Desktop.
Script to recover saved optnet models whose stored tensor pointers changed after optimization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Recover a model that was optimized with optnet and then saved.
--
-- `net.__gradParamsInfo` is keyed by tensor pointers (torch.pointer of each
-- gradParameter). The pointers stored in the saved file are not valid in the
-- loading process, so the entries cannot be found by pointer any more.
-- This script re-keys the table with the current gradParameter pointers and
-- then calls optnet.removeOptimization on the network.
--
-- Usage (arguments are optional):
--   th <this_script>.lua [input.t7] [output.t7]
--     input.t7   model to recover (default: celeba_24_G.t7)
--     output.t7  if given, the recovered model is written there
-- When loaded with dofile(), the chunk returns the recovered network.
local optnet = require 'optnet'

local cli_args = arg or {}
local input_path = cli_args[1] or 'celeba_24_G.t7'
local output_path = cli_args[2]

local net = torch.load(input_path)

local info = net.__gradParamsInfo
assert(type(info) == 'table',
  'model has no __gradParamsInfo table; was it optimized with optnet?')

-- The stored pointer keys are stale. Recovery heuristic: assume the i-th
-- gradParameter returned by net:parameters() corresponds to the entry with
-- the i-th smallest offSet.
-- NOTE(review): this ordering is an assumption, not something the saved
-- data guarantees; verify the recovered model's outputs.
local keys = {}
for old_ptr in pairs(info) do
  keys[#keys + 1] = old_ptr
end
table.sort(keys, function(a, b)
  return info[a].offSet < info[b].offSet
end)

local _, grad_params = net:parameters()
grad_params = grad_params or {}
-- The positional mapping is meaningless unless both sides line up 1:1.
assert(#grad_params == #keys, string.format(
  'found %d gradParameters but %d __gradParamsInfo entries; cannot remap',
  #grad_params, #keys))

-- Read every entry before writing any. A new pointer could coincide with a
-- stale key that has not been read yet, and remapping in place would
-- overwrite that entry before it is used.
local entries = {}
for i, old_ptr in ipairs(keys) do
  entries[i] = info[old_ptr]
end
for i, grad in ipairs(grad_params) do
  info[torch.pointer(grad)] = entries[i]
end

-- With valid pointer keys in place, the optimization can be removed.
optnet.removeOptimization(net)

-- Persist the result; locals are otherwise lost when the script exits.
if output_path then
  torch.save(output_path, net)
end

return net
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.