-
Notifications
You must be signed in to change notification settings - Fork 7
/
training.py
218 lines (179 loc) · 8.37 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import numpy as np
import math
import time
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torch.nn import init
from torch.optim.optimizer import Optimizer
import os
import glob
import random
from tifffile import imread
from matplotlib import pyplot as plt
from tqdm import tqdm
from boilerplate import boilerplate
from models.lvae import LadderVAE
import lib.utils as utils
def train_network(model, lr, max_epochs,steps_per_epoch,train_loader, val_loader, test_loader,
virtual_batch, gaussian_noise_std, model_name,
test_log_every=1000, directory_path="./",
val_loss_patience=100, nrows=4, max_grad_norm=None):
"""Train Hierarchical DivNoising network.
Parameters
----------
model: Ladder VAE object
Hierarchical DivNoising model.
lr: float
Learning rate
max_epochs: int
Number of epochs to train the model for.
train_loader: PyTorch data loader
Data loader for training set.
val_loader: PyTorch data loader
Data loader for validation set.
test_loader: PyTorch data loader
Data loader for test set.
virtual_batch: int
Virtual batch size for training
gaussian_noise_std: float
standard deviation of gaussian noise (required when 'noiseModel' is None).
model_name: String
Name of Hierarchical DivNoising model with which to save weights.
test_log_every: int
Number of training steps after which one test evaluation is performed.
directory_path: String
Path where the DivNoising weights to be saved.
val_loss_patience: int
Number of epoochs after which training should be terminated if validation loss doesn't improve by 1e-6.
max_grad_norm: float
Value to limit/clamp the gradients at.
"""
model_folder = directory_path+"model/"
img_folder = directory_path+"imgs/"
device = model.device
optimizer, scheduler = boilerplate._make_optimizer_and_scheduler(model,lr,0.0)
loss_train_history = []
reconstruction_loss_train_history = []
kl_loss_train_history = []
loss_val_history = []
running_loss = 0.0
step_counter = 0
epoch = 0
patience_ = 0
first_step = True
try:
os.makedirs(model_folder)
except FileExistsError:
# directory already exists
pass
try:
os.makedirs(img_folder)
except FileExistsError:
# directory already exists
pass
seconds_last = time.time()
while step_counter / steps_per_epoch < max_epochs:
epoch = epoch+1
running_training_loss = []
running_reconstruction_loss = []
running_kl_loss = []
for batch_idx, (x, y) in enumerate(train_loader):
step_counter=batch_idx
x = x.unsqueeze(1) # Remove for RGB
x = x.to(device=device, dtype=torch.float)
step = model.global_step
if(test_log_every > 0):
if step % test_log_every == 0:
print("Testing the model at " "step {}". format(step))
with torch.no_grad():
boilerplate._test(epoch, img_folder, device, model,
test_loader, gaussian_noise_std,
model.data_std, nrows)
model.train()
optimizer.zero_grad()
### Make smaller batches
virtual_batches = torch.split(x,virtual_batch,0)
for batch in virtual_batches:
outputs = boilerplate.forward_pass(batch, batch, device, model,
gaussian_noise_std)
recons_loss = outputs['recons_loss']
kl_loss = outputs['kl_loss']
loss = recons_loss + kl_loss
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
# Optimization step
running_training_loss.append(loss.item())
running_reconstruction_loss.append(recons_loss.item())
running_kl_loss.append(kl_loss.item())
optimizer.step()
model.increment_global_step()
first_step = False
if step_counter % steps_per_epoch == steps_per_epoch-1:
### Print training losses
to_print = "Epoch[{}/{}] Training Loss: {:.3f} Reconstruction Loss: {:.3f} KL Loss: {:.3f}"
to_print = to_print.format(epoch,
max_epochs,
np.mean(running_training_loss),
np.mean(running_reconstruction_loss),
np.mean(running_kl_loss))
print(to_print)
print('saving',model_folder+model_name+"_last_vae.net")
torch.save(model, model_folder+model_name+"_last_vae.net")
### Save training losses
loss_train_history.append(np.mean(running_training_loss))
reconstruction_loss_train_history.append(np.mean(running_reconstruction_loss))
kl_loss_train_history.append(np.mean(running_kl_loss))
np.save(model_folder+"train_loss.npy", np.array(loss_train_history))
np.save(model_folder+"train_reco_loss.npy", np.array(reconstruction_loss_train_history))
np.save(model_folder+"train_kl_loss.npy", np.array(kl_loss_train_history))
### Validation step
running_validation_loss = []
model.eval()
with torch.no_grad():
for i, (x, y) in enumerate(val_loader):
x = x.unsqueeze(1) # Remove for RGB
x = x.to(device=device, dtype=torch.float)
val_outputs = boilerplate.forward_pass(x, y, device, model, gaussian_noise_std)
val_recons_loss = val_outputs['recons_loss']
val_kl_loss = val_outputs['kl_loss']
val_loss = val_recons_loss + val_kl_loss
running_validation_loss.append(val_loss)
model.train()
total_epoch_loss_val = torch.mean(torch.stack(running_validation_loss))
scheduler.step(total_epoch_loss_val)
### Save validation losses
loss_val_history.append(total_epoch_loss_val.item())
np.save(model_folder+"val_loss.npy", np.array(loss_val_history))
if total_epoch_loss_val.item() < 1e-6 + np.min(loss_val_history):
patience_ = 0
print('saving',model_folder+model_name+"_best_vae.net")
torch.save(model, model_folder+model_name+"_best_vae.net")
else:
patience_ +=1
print("Patience:", patience_,
"Validation Loss:", total_epoch_loss_val.item(),
"Min validation loss:", np.min(loss_val_history))
seconds=time.time()
secondsElapsed=np.float(seconds-seconds_last)
seconds_last=seconds
remainingEps=(max_epochs+1)-(epoch+1)
estRemainSeconds=(secondsElapsed)*(remainingEps)
estRemainSecondsInt=int(secondsElapsed)*(remainingEps)
print('Time for epoch: '+ str(int(secondsElapsed))+ 'seconds')
print('Est remaining time: '+
str(datetime.timedelta(seconds= estRemainSecondsInt)) +
' or ' +
str(estRemainSecondsInt)+
' seconds')
print("----------------------------------------", flush=True)
if patience_ == val_loss_patience:
# print("Employing early stopping, validation loss did not improve for 100 epochs !"
return
break