Skip to content

Commit

Permalink
Merge pull request #14 from avik-pal/accuracy
Browse files Browse the repository at this point in the history
[WIP] Improve Accuracy of Models
  • Loading branch information
MikeInnes authored Jul 5, 2018
2 parents fb15742 + f5b5c18 commit cd3d9ea
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 71 deletions.
10 changes: 5 additions & 5 deletions src/densenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ function densenet_layers()
weights[string(ele)] = convert(Array{Float64, N} where N ,weight[ele])
end
ls = _densenet()
ls[1].weight.data .= weights["conv1_w_0"]
ls[1].weight.data .= weights["conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
ls[2].β.data .= weights["conv1/bn_b_0"]
ls[2].γ.data .= weights["conv1/bn_w_0"]
l = 4
for (c, n) in enumerate([6, 12, 24, 16])
for i in 1:n
for j in [2, 4]
ls[l][i].layer[j].weight.data .= weights["conv$(c+1)_$i/x$(j÷2)_w_0"]
ls[l][i].layer[j-1].β.data .= weights["conv$(c+1)_$i/x$(j÷2)/bn_w_0"]
ls[l][i].layer[j].weight.data .= weights["conv$(c+1)_$i/x$(j÷2)_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
ls[l][i].layer[j-1].β.data .= weights["conv$(c+1)_$i/x$(j÷2)/bn_b_0"]
ls[l][i].layer[j-1].γ.data .= weights["conv$(c+1)_$i/x$(j÷2)/bn_w_0"]
end
end
l += 2
end
for i in [5, 7, 9] # Transition Block Conv Layers
ls[i][2].weight.data .= weights["conv$(i÷2)_blk_w_0"]
ls[i][1].β.data .= weights["conv$(i÷2)_blk/bn_w_0"]
ls[i][2].weight.data .= weights["conv$(i÷2)_blk_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
ls[i][1].β.data .= weights["conv$(i÷2)_blk/bn_b_0"]
ls[i][1].γ.data .= weights["conv$(i÷2)_blk/bn_w_0"]
end
ls[end-1].W.data .= transpose(squeeze(weights["fc6_w_0"], (1, 2))) # Dense Layers
Expand Down
18 changes: 9 additions & 9 deletions src/googlenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ function googlenet_layers()
weights[string(ele)] = convert(Array{Float64, N} where N, weight[ele])
end
ls = _googlenet()
ls[1].weight.data .= weights["conv1/7x7_s2_w_0"]; ls[1].bias.data .= weights["conv1/7x7_s2_b_0"]
ls[3].weight.data .= weights["conv2/3x3_reduce_w_0"]; ls[3].bias.data .= weights["conv2/3x3_reduce_b_0"]
ls[4].weight.data .= weights["conv2/3x3_w_0"]; ls[4].bias.data .= weights["conv2/3x3_b_0"]
ls[1].weight.data .= weights["conv1/7x7_s2_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[1].bias.data .= weights["conv1/7x7_s2_b_0"]
ls[3].weight.data .= weights["conv2/3x3_reduce_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[3].bias.data .= weights["conv2/3x3_reduce_b_0"]
ls[4].weight.data .= weights["conv2/3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[4].bias.data .= weights["conv2/3x3_b_0"]
for (a, b) in [(6, "3a"), (7, "3b"), (9, "4a"), (10, "4b"), (11, "4c"), (12, "4d"), (13, "4e"), (15, "5a"), (16, "5b")]
ls[a].path_1.weight.data .= weights["inception_$b/1x1_w_0"]; ls[a].path_1.bias.data .= weights["inception_$b/1x1_b_0"]
ls[a].path_2[1].weight.data .= weights["inception_$b/3x3_reduce_w_0"]; ls[a].path_2[1].bias.data .= weights["inception_$b/3x3_reduce_b_0"]
ls[a].path_2[2].weight.data .= weights["inception_$b/3x3_w_0"]; ls[a].path_2[2].bias.data .= weights["inception_$b/3x3_b_0"]
ls[a].path_3[1].weight.data .= weights["inception_$b/5x5_reduce_w_0"]; ls[a].path_3[1].bias.data .= weights["inception_$b/5x5_reduce_b_0"]
ls[a].path_3[2].weight.data .= weights["inception_$b/5x5_w_0"]; ls[a].path_3[2].bias.data .= weights["inception_$b/5x5_b_0"]
ls[a].path_4[2].weight.data .= weights["inception_$b/pool_proj_w_0"]; ls[a].path_4[2].bias.data .= weights["inception_$b/pool_proj_b_0"]
ls[a].path_1.weight.data .= weights["inception_$b/1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_1.bias.data .= weights["inception_$b/1x1_b_0"]
ls[a].path_2[1].weight.data .= weights["inception_$b/3x3_reduce_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_2[1].bias.data .= weights["inception_$b/3x3_reduce_b_0"]
ls[a].path_2[2].weight.data .= weights["inception_$b/3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_2[2].bias.data .= weights["inception_$b/3x3_b_0"]
ls[a].path_3[1].weight.data .= weights["inception_$b/5x5_reduce_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_3[1].bias.data .= weights["inception_$b/5x5_reduce_b_0"]
ls[a].path_3[2].weight.data .= weights["inception_$b/5x5_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_3[2].bias.data .= weights["inception_$b/5x5_b_0"]
ls[a].path_4[2].weight.data .= weights["inception_$b/pool_proj_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]; ls[a].path_4[2].bias.data .= weights["inception_$b/pool_proj_b_0"]
end
ls[20].W.data .= transpose(weights["loss3/classifier_w_0"]); ls[20].b.data .= weights["loss3/classifier_b_0"]
Flux.testmode!(ls)
Expand Down
8 changes: 4 additions & 4 deletions src/resnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ function resnet_layers()
weights[string(ele)] = convert(Array{Float64, N} where N, weight[ele])
end
ls = resnet50()
ls[1].weight.data .= weights["gpu_0/conv1_w_0"]
ls[1].weight.data .= weights["gpu_0/conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
count = 2
for j in [3:5, 6:9, 10:15, 16:18]
for p in j
ls[p].conv_layers[1].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2a_w_0"]
ls[p].conv_layers[2].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2b_w_0"]
ls[p].conv_layers[3].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2c_w_0"]
ls[p].conv_layers[1].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2a_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
ls[p].conv_layers[2].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2b_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
ls[p].conv_layers[3].weight.data .= weights["gpu_0/res$(count)_$(p-j[1])_branch2c_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:]
end
count += 1
end
Expand Down
74 changes: 37 additions & 37 deletions src/squeezenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,56 @@ function squeezenet_layers()
for ele in keys(weight)
weights[string(ele)] = weight[ele]
end
c_1 = Conv(weights["conv10_w_0"], weights["conv10_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_1 = Conv(weights["conv10_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["conv10_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_2 = Dropout(0.5f0)
c_3 = Conv(weights["fire9/expand1x1_w_0"], weights["fire9/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_4 = Conv(weights["fire9/squeeze1x1_w_0"], weights["fire9/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_5 = Conv(weights["fire8/expand1x1_w_0"], weights["fire8/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_6 = Conv(weights["fire8/squeeze1x1_w_0"], weights["fire8/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_7 = Conv(weights["fire7/expand1x1_w_0"], weights["fire7/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_8 = Conv(weights["fire7/squeeze1x1_w_0"], weights["fire7/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_9 = Conv(weights["fire6/expand1x1_w_0"], weights["fire6/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_10 = Conv(weights["fire6/squeeze1x1_w_0"], weights["fire6/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_11 = Conv(weights["fire5/expand1x1_w_0"], weights["fire5/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_12 = Conv(weights["fire5/squeeze1x1_w_0"], weights["fire5/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_13 = Conv(weights["fire4/expand1x1_w_0"], weights["fire4/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_14 = Conv(weights["fire4/squeeze1x1_w_0"], weights["fire4/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_15 = Conv(weights["fire3/expand1x1_w_0"], weights["fire3/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_16 = Conv(weights["fire3/squeeze1x1_w_0"], weights["fire3/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_17 = Conv(weights["fire2/expand1x1_w_0"], weights["fire2/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_18 = Conv(weights["fire2/squeeze1x1_w_0"], weights["fire2/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_19 = Conv(weights["conv1_w_0"], weights["conv1_b_0"], stride=(2, 2), pad=(0, 0), dilation = (1, 1))
c_20 = Conv(weights["fire2/expand3x3_w_0"], weights["fire2/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_21 = Conv(weights["fire3/expand3x3_w_0"], weights["fire3/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_22 = Conv(weights["fire4/expand3x3_w_0"], weights["fire4/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_23 = Conv(weights["fire5/expand3x3_w_0"], weights["fire5/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_24 = Conv(weights["fire6/expand3x3_w_0"], weights["fire6/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_25 = Conv(weights["fire7/expand3x3_w_0"], weights["fire7/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_26 = Conv(weights["fire8/expand3x3_w_0"], weights["fire8/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_27 = Conv(weights["fire9/expand3x3_w_0"], weights["fire9/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
ls = Chain(Conv(weights["conv1_w_0"], weights["conv1_b_0"], stride=(2, 2), pad=(0, 0), dilation = (1, 1)),
c_3 = Conv(weights["fire9/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire9/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_4 = Conv(weights["fire9/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire9/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_5 = Conv(weights["fire8/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire8/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_6 = Conv(weights["fire8/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire8/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_7 = Conv(weights["fire7/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire7/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_8 = Conv(weights["fire7/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire7/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_9 = Conv(weights["fire6/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire6/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_10 = Conv(weights["fire6/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire6/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_11 = Conv(weights["fire5/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire5/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_12 = Conv(weights["fire5/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire5/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_13 = Conv(weights["fire4/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire4/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_14 = Conv(weights["fire4/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire4/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_15 = Conv(weights["fire3/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire3/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_16 = Conv(weights["fire3/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire3/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_17 = Conv(weights["fire2/expand1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire2/expand1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_18 = Conv(weights["fire2/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire2/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1))
c_19 = Conv(weights["conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["conv1_b_0"], stride=(2, 2), pad=(0, 0), dilation = (1, 1))
c_20 = Conv(weights["fire2/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire2/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_21 = Conv(weights["fire3/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire3/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_22 = Conv(weights["fire4/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire4/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_23 = Conv(weights["fire5/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire5/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_24 = Conv(weights["fire6/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire6/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_25 = Conv(weights["fire7/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire7/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_26 = Conv(weights["fire8/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire8/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))
c_27 = Conv(weights["fire9/expand3x3_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire9/expand3x3_b_0"], stride=(1, 1), pad=(1, 1), dilation = (1, 1))

ls = Chain(Conv(weights["conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["conv1_b_0"], stride=(2, 2), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->maxpool(x, (3,3), pad=(0,0), stride=(2,2)),
Conv(weights["fire2/squeeze1x1_w_0"], weights["fire2/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire2/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire2/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_17(x)), relu.(c_20(x))),
Conv(weights["fire3/squeeze1x1_w_0"], weights["fire3/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire3/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire3/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_15(x)), relu.(c_21(x))),
x->maxpool(x, (3, 3), pad=(0, 0), stride=(2, 2)),
Conv(weights["fire4/squeeze1x1_w_0"], weights["fire4/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire4/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire4/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_13(x)), relu.(c_22(x))),
Conv(weights["fire5/squeeze1x1_w_0"], weights["fire5/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire5/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire5/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_11(x)), relu.(c_23(x))),
x->maxpool(x, (3, 3), pad=(0, 0), stride=(2, 2)),
Conv(weights["fire6/squeeze1x1_w_0"], weights["fire6/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire6/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire6/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_9(x)), relu.(c_24(x))),
Conv(weights["fire7/squeeze1x1_w_0"], weights["fire7/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire7/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire7/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_7(x)), relu.(c_25(x))),
Conv(weights["fire8/squeeze1x1_w_0"], weights["fire8/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire8/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire8/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_5(x)), relu.(c_26(x))),
Conv(weights["fire9/squeeze1x1_w_0"], weights["fire9/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["fire9/squeeze1x1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["fire9/squeeze1x1_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->cat(3, relu.(c_3(x)), relu.(c_27(x))),
Dropout(0.5f0),
Conv(weights["conv10_w_0"], weights["conv10_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
Conv(weights["conv10_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:], weights["conv10_b_0"], stride=(1, 1), pad=(0, 0), dilation = (1, 1)),
x -> relu.(x), x->mean(x, (1,2)),
vec, softmax
)
Expand Down
Loading

0 comments on commit cd3d9ea

Please sign in to comment.