forked from jcjohnson/neural-style
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loadcaffe_wrapper.lua
80 lines (72 loc) · 2.32 KB
/
loadcaffe_wrapper.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
local ffi = require 'ffi'
require 'loadcaffe'
local C = loadcaffe.C
--[[
Most of this function is copied from
https://github.com/szagoruyko/loadcaffe/blob/master/loadcaffe.lua
with some horrible horrible hacks added by Justin Johnson to
make it possible to load VGG-19 without any CUDA dependency.
--]]

-- Copy the text file `src` to `dst` line by line, terminating every line
-- with '\n'. Raises (via assert) with io.open's own message if either file
-- cannot be opened.
local function copy_lines(src, dst)
  local fin = assert(io.open(src, 'r'))
  local fout = assert(io.open(dst, 'w'))
  for line in fin:lines() do
    fout:write(line, '\n')
  end
  fin:close()
  fout:close()
end

-- Convert the prototxt held by `handle` into a Lua model-description script
-- and evaluate it. Returns the list that script builds; each entry is used
-- below as {layer_name, layer_module}.
--
-- For backend "nn-cpu" the script is generated with the plain 'nn' backend
-- and evaluated from a separate "<prototxt>.cpu.lua" copy.
-- NOTE(review): this hack once deleted the lines requiring cunn and inn
-- (lines 2 and 4 of the generated file). The copy is now verbatim,
-- presumably because loadcaffe no longer emits those requires for 'nn'.
-- Confirm against the installed loadcaffe.
local function build_layer_list(handle, prototxt_name, backend)
  local lua_name = prototxt_name .. '.lua'
  if backend == 'nn-cpu' then
    C.convertProtoToLua(handle, lua_name, 'nn')
    local lua_name_cpu = prototxt_name .. '.cpu.lua'
    copy_lines(lua_name, lua_name_cpu)
    return dofile(lua_name_cpu)
  end
  C.convertProtoToLua(handle, lua_name, backend)
  return dofile(lua_name)
end

-- Build an nn.Sequential from `layers`. Each module is named after its
-- layer. Any module that has a `weight` gets its weight (and bias, when it
-- has one) copied from the Caffe blobs held by `handle`.
local function assemble_network(handle, layers, backend)
  local net = nn.Sequential()
  for _, item in ipairs(layers) do
    local layer_name, layer = item[1], item[2]
    layer.name = layer_name
    if layer.weight then
      local w = torch.FloatTensor()
      local bias = torch.FloatTensor()
      C.loadModule(handle, layer_name, w:cdata(), bias:cdata())
      if backend == 'ccn2' then
        -- ccn2 needs a different weight layout: dim 1 moves to the end.
        w = w:permute(2, 3, 4, 1)
      end
      layer.weight:copy(w)
      -- Guard against layers constructed without a bias tensor.
      if layer.bias then
        layer.bias:copy(bias)
      end
    end
    net:add(layer)
  end
  return net
end

--- Load a Caffe model as a Torch network.
-- @param prototxt_name path to the .prototxt network definition
-- @param binary_name path to the .caffemodel weights
-- @param backend 'nn' (default), 'nn-cpu', 'cudnn', 'ccn2', or another
--   backend string understood by loadcaffe's convertProtoToLua
-- @return an nn.Sequential, moved to CUDA for 'cudnn'/'ccn2'; nil if
--   C.loadBinary left the handle unset
local function loadcaffe_load(prototxt_name, binary_name, backend)
  backend = backend or 'nn'
  local handle = ffi.new('void*[1]')

  -- Load the caffe model into memory and keep a handle to it in ffi.
  -- If C.loadBinary did not change the handle slot, loading presumably
  -- failed; return nil as before.
  local old_val = handle[1]
  C.loadBinary(handle, prototxt_name, binary_name)
  if old_val == handle[1] then return end

  -- Always release the loaded binary, even when conversion or weight
  -- copying raises. Then re-raise the original error unchanged.
  local ok, result = pcall(function()
    local layers = build_layer_list(handle, prototxt_name, backend)
    return assemble_network(handle, layers, backend)
  end)
  C.destroyBinary(handle)
  if not ok then error(result, 0) end

  if backend == 'cudnn' or backend == 'ccn2' then
    result:cuda()
  end
  return result
end
-- Public interface of this module: `load(prototxt, caffemodel, backend)`.
local exports = {}
exports.load = loadcaffe_load
return exports