Created
October 25, 2016 10:05
-
-
Save szagoruyko/e5cf5e9b54661a817695c8c7b5c3dfa6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'torch'
require 'paths'
require 'xlua'
require 'sys'

--- Convert the ascii-serialized CIFAR-10 t7 batches into one torch file.
--
-- Usage: th <this script> [batches_folder] [output_file]
-- Both arguments are optional. When omitted, the script reads from
-- '/opt/rocks/cifar.torch/cifar-10-batches-t7' and writes
-- 'cifar10_original.t7'.
--
-- The saved table has the shape
--   { trainData = { data, labels, size }, testData = { data, labels, size } }
-- where `size` is a function returning the number of samples.

local DEFAULT_BATCHES_FOLDER = '/opt/rocks/cifar.torch/cifar-10-batches-t7'
local DEFAULT_OUTPUT_FILE = 'cifar10_original.t7'
local NUM_TRAIN_BATCHES = 5

local script_args = arg or {}
local batches_folder = script_args[1] or DEFAULT_BATCHES_FOLDER
local output_file = script_args[2] or DEFAULT_OUTPUT_FILE

--- Load one ascii-serialized batch file.
-- @tparam string filename file name relative to `batches_folder`
-- @return the deserialized batch (read below via `.data` and `.labels`)
local function load_batch(filename)
  -- `filename` is relative, so it is joined onto the folder exactly once.
  return torch.load(paths.concat(batches_folder, filename), 'ascii')
end

--- View a batch's pixel tensor as a 3x32x32xN tensor.
-- NOTE(review): assumes `flat` is contiguous and holds 3*32*32 values per
-- sample along its leading dimension (presumably 3072 x N) -- confirm
-- against the batch files.
local function to_images(flat)
  return flat:view(3, 32, 32, -1)
end

-- Training set: concatenate all training batches along the sample axis.
local train_images = {}
local train_label_parts = {}
for i = 1, NUM_TRAIN_BATCHES do
  local part = load_batch('data_batch_' .. i .. '.t7')
  train_images[#train_images + 1] = to_images(part.data)
  train_label_parts[#train_label_parts + 1] = part.labels:squeeze()
end
local data = torch.ByteTensor.cat(train_images, 4)
local labels = torch.ByteTensor.cat(train_label_parts)

-- Test set. These must be locals: the `size` closures below are serialized
-- by torch.save together with their upvalues only, so any global they
-- referenced would be missing (nil) when the file is loaded elsewhere.
local test_part = load_batch('test_batch.t7')
local test_labels = test_part.labels
local test_data = test_part.data

-- Capture plain sample counts so the serialized closures carry only a number.
local num_train = labels:numel()
local num_test = test_labels:numel()

local dataset = {
  trainData = {
    -- permute(4,1,2,3) moves the sample axis first; clone() makes it contiguous.
    data = data:permute(4, 1, 2, 3):clone(),
    -- In-place +1 shift, presumably from 0-based to 1-based class indices
    -- (the max() prints below let you confirm the resulting range).
    labels = labels:add(1),
    size = function() return num_train end,
  },
  testData = {
    data = to_images(test_data):permute(4, 1, 2, 3):clone(),
    labels = test_labels:squeeze():add(1),
    size = function() return num_test end,
  },
}

print(dataset)
-- Sanity check on the label shift: print the largest train and test label.
print(dataset.trainData.labels:max())
print(dataset.testData.labels:max())
torch.save(output_file, dataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for sharing your code. How many training samples are used per epoch after the data augmentation described in the Experimental Results section?
https://arxiv.org/pdf/1605.07146.pdf