-
Notifications
You must be signed in to change notification settings - Fork 0
/
scratch.py
86 lines (70 loc) · 2.57 KB
/
scratch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import deq_module.deq as deq
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as io
import mon
import NODEN
import numpy
import torch
import train
import splitting as sp
import matplotlib
import lben
# Prefer the interactive Tk backend, but do not abort the whole experiment if
# it cannot be set up (e.g. a headless GPU server with no display or no Tk).
# pyplot is already imported above, so matplotlib.use() switches backends
# immediately; with the default force=True a failed switch raises ImportError.
# force=False ignores that failure and keeps the current backend instead.
matplotlib.use("TkAgg", force=False)
if __name__ == "__main__":
    # Local import keeps this script body self-contained; used for the output dir.
    import os

    # ---- Dataset selection: loaders plus the matching input geometry ----------
    dataset = "mnist"

    if dataset == "mnist":
        trainLoader, testLoader = train.mnist_loaders(train_batch_size=32,
                                                      test_batch_size=32)
        in_dim = 28
        in_channels = 1
    elif dataset == "cifar":
        trainLoader, testLoader = train.cifar_loaders(train_batch_size=250,
                                                      test_batch_size=250)
        in_dim = 32
        in_channels = 3
    elif dataset == "svhn":
        trainLoader, testLoader = train.svhn_loaders(train_batch_size=250,
                                                     test_batch_size=250)
        in_dim = 32
        in_channels = 3
    else:
        # Fail fast here instead of with a NameError on trainLoader further down.
        raise ValueError("unknown dataset: {!r}".format(dataset))

    # NOTE(review): in_channels is set for every dataset but is not passed to
    # train.SingleConvNet below; confirm whether the model needs it for the
    # 3-channel datasets (cifar / svhn).

    # ---- Hyper-parameters -------------------------------------------------------
    epochs = 5
    seed = 4
    tol = 1E-2            # solver tolerance used while training
    width = 20            # passed to SingleConvNet as out_channels
    lr_decay_steps = 10   # forwarded to train.train as `step` (lr_mode='step')
    max_iter = 1500       # solver iteration cap passed to SingleConvNet
    m = 0.1               # passed to SingleConvNet as `m`

    path = './models/conv_experiment/'
    # torch.save / io.savemat do not create missing directories themselves.
    os.makedirs(path, exist_ok=True)

    torch.manual_seed(seed)
    numpy.random.seed(seed)

    # ---- Lipschitz network ------------------------------------------------------
    for gamma in [1.0]:
        # Re-seed so every gamma value starts from the same initialisation.
        torch.manual_seed(seed)
        numpy.random.seed(seed)

        LipConvNet = train.SingleConvNet(sp.FISTA,
                                         in_dim=in_dim,
                                         out_channels=width,
                                         max_iter=max_iter,
                                         tol=tol,
                                         m=m)

        # LipConvNet.mon.load_state_dict(torch.load('./FISTA_Test_model.params'))
        Lip_train, Lip_val = train.train(trainLoader, testLoader,
                                         LipConvNet,
                                         max_lr=1e-3,
                                         lr_mode='step',
                                         step=lr_decay_steps,
                                         change_mo=False,
                                         epochs=epochs,
                                         print_freq=100,
                                         tune_alpha=False)

        name = 'Lip_conv_w{:d}_L{:1.1f}'.format(width, gamma)
        torch.save(LipConvNet.state_dict(), os.path.join(path, name + '.params'))

        # Tighten the solver tolerance before the robustness evaluation.
        LipConvNet.mon.tol = 1E-4
        res = train.test_robustness(LipConvNet, testLoader)
        io.savemat(os.path.join(path, name + ".mat"), res)

    print("fin")