forked from facebookarchive/fb.resnet.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar100-gen.lua
67 lines (53 loc) · 2.03 KB
/
cifar100-gen.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
------------
-- This file automatically downloads the CIFAR-100 dataset from
-- http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz
-- It is based on cifar10-gen.lua
-- Ludovic Trottier
------------
-- Address of the CIFAR-100 binary archive that M.exec downloads.
local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz'
-- Module table; functions attached to it are exposed via the final `return M`.
local M = {}
--- Parse one CIFAR-100 binary file into Torch tensors.
-- Every record is 3074 bytes: 1 coarse-label byte, 1 fine-label byte and
-- 3072 pixel bytes, which are copied into a 3x32x32 slot per sample.
-- Only the fine labels are returned; the coarse label is read and discarded.
-- @param inputFname path of the .bin file to read
-- @return table with `data` (ByteTensor, nSamples x 3 x 32 x 32) and
--         `labels` (ByteTensor of fine labels shifted up by one)
local function convertCifar100BinToTorchTensor(inputFname)
   local file = torch.DiskFile(inputFname, 'r'):binary()

   -- Measure the file: jump to the end and read back the position.
   file:seekEnd()
   local length = file:position() - 1
   local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes
   if nSamples ~= math.floor(nSamples) then
      -- Release the handle before reporting a truncated/corrupt file.
      file:close()
      error('expecting numSamples to be an exact integer', 0)
   end
   file:seek(1)

   local fine = torch.ByteTensor(nSamples)
   local data = torch.ByteTensor(nSamples, 3, 32, 32)
   for i = 1, nSamples do
      file:readByte() -- coarse label: consumed to stay aligned, not kept
      fine[i] = file:readByte()
      local store = file:readByte(3072)
      data[i]:copy(torch.ByteTensor(store))
   end

   -- The original never closed the file; do it as soon as reading is done.
   file:close()

   local out = {}
   out.data = data
   -- This is *very* important. The downloaded files use 0-based labels
   -- (0-99 for the fine classes), which do not work with
   -- CrossEntropyCriterion, so shift them to start at 1.
   out.labels = fine + 1
   return out
end
--- Download the CIFAR-100 archive, convert both splits and cache the result.
-- Shells out to curl piped into tar (unpacking under gen/), converts
-- train.bin and test.bin, then saves one table with `train` and `val` entries.
-- @param opt options table; not read by this function
-- @param cacheFile path passed to torch.save for the combined dataset
M.exec = function(opt, cacheFile)
   print("=> Downloading CIFAR-100 dataset from " .. URL)

   local command = string.format('curl %s | tar xz -C gen/', URL)
   local status = os.execute(command)
   -- Either `true` or exit status 0 counts as success.
   if status ~= true and status ~= 0 then
      error('error downloading CIFAR-100', 0)
   end

   print(" | combining dataset into a single file")
   local binDir = 'gen/cifar-100-binary/'
   local trainSplit = convertCifar100BinToTorchTensor(binDir .. 'train.bin')
   local valSplit = convertCifar100BinToTorchTensor(binDir .. 'test.bin')

   print(" | saving CIFAR-100 dataset to " .. cacheFile)
   local dataset = {
      train = trainSplit,
      val = valSplit,
   }
   torch.save(cacheFile, dataset)
end
-- Hand the module table (with its exec entry point) back to whoever loads this file.
return M