-
Notifications
You must be signed in to change notification settings - Fork 6
/
model_trainer.py
252 lines (193 loc) · 9.36 KB
/
model_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import os

import tensorflow as tf
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from data_loader import Div2kLoader
from models import edsr, srgan_discriminator
class SrganTrainer:
    """
    SRGAN trainer class.

    Attributes
    ----------
    generator: Model
        The model used as a generator in SRGAN.
    discriminator: Model
        The model used as a discriminator in SRGAN.
    data_path: str
        The location of the training data. The given path is expected to contain an 'HR' folder that holds
        high-resolution images and an 'LR' folder that holds low-resolution images. The same image must have
        the same name in both folders so that its LR and HR versions are paired correctly.
    lrw: int
        Low-resolution image width.
        Default=64
    lrh: int
        Low-resolution image height.
        Default=64
    load_all_data: bool
        Whether to load the whole dataset in memory. Unless you have enough RAM in your machine, this is not
        recommended.
        Default=False
    learning_rate: float
        The learning rate used during training.
        Default=1e-4

    Methods
    -------
    train_generator(epochs: int = 150, starting_weights: str = None, batch_size: int = 32, loss: str = "mae")
        Trains the generator model on its own.
    train_gan(weights_path: str, steps: int = 200000, batch_size: int = 16)
        Starts the training of SRGAN.
    """

    def __init__(self, generator: Model, discriminator: Model, data_path: str, lrw: int = 64, lrh: int = 64,
                 load_all_data: bool = False, learning_rate: float = 1e-4):
        # Input (low-resolution) shape; images are treated as 3-channel.
        self.channels = 3
        self.lr_height = lrh
        self.lr_width = lrw
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        # Output (high-resolution) shape: the trainer assumes a fixed x4 upscaling factor.
        self.hr_height = self.lr_height * 4
        self.hr_width = self.lr_width * 4
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        self.generator = generator
        self.gen_optimizer = Adam(learning_rate)
        self.discriminator = discriminator
        self.disc_optimizer = Adam(learning_rate)

        self.data = Div2kLoader(data_path, load_all_data=load_all_data)
        self.learning_rate = learning_rate

        # VGG feature extractor used by the content loss (must be built after hr_shape is set).
        self.vgg = self._vgg_54_model()
        self._mean_squared_error = MeanSquaredError()
        # from_logits=False: assumes the discriminator already outputs probabilities (e.g. a final
        # sigmoid) -- confirm against models.srgan_discriminator.
        self._binary_cross_entropy = BinaryCrossentropy(from_logits=False)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will contain ``path`` if it does not exist yet."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def train_generator(self, epochs: int = 150, starting_weights: str = None, batch_size: int = 32,
                        loss: str = "mae"):
        """Trains the generator model on its own. This is important before training SRGAN because the resulting
        weights are going to be used as an initialization for the generator in SRGAN.

        Parameters
        ----------
        epochs: int
            Number of training epochs.
            Default=150
        starting_weights: str
            Path to initialization weights. If not specified, the generator keeps its current (e.g. random)
            weights.
            Default=None
        batch_size: int
            Number of images per batch.
            Default=32
        loss: str
            Training loss function, passed to Keras `compile`: 'mae' for mean absolute error, or 'mse' for
            mean squared error.
            Default='mae'

        Returns
        -------
        weights_path: str
            Path of the checkpoint written after the final epoch. It can be passed directly to `train_gan`.
        """
        if starting_weights:
            print(f"Initializing '{self.generator.name}' with {starting_weights}")
            self.generator.load_weights(starting_weights)

        optimizer = Adam(self.learning_rate, 0.9)
        self.generator.compile(loss=[loss], optimizer=optimizer)
        self.data.batch_size = batch_size

        # Checkpoint filename template: ModelCheckpoint fills in the (1-based) epoch number after every epoch.
        weights_template = f"model_weights/generator_mse/{self.generator.name}_X4_MSE-{{epoch:02d}}.h5"
        self._ensure_parent_dir(weights_template)
        checkpoint = ModelCheckpoint(weights_template, save_best_only=False)

        print("Training the Generator on its own:")
        self.generator.fit(self.data, epochs=epochs, callbacks=[checkpoint])

        # Serialize the model architecture to JSON.
        json_path = f"model_json/{self.generator.name}_X4_MSE.json"
        self._ensure_parent_dir(json_path)
        with open(json_path, "w") as json_file:
            json_file.write(self.generator.to_json())

        print(f"training '{self.generator.name}' model completed Successfully!")
        # Return the concrete file saved after the last epoch rather than the template: a path that still
        # contains '{epoch:02d}' cannot be loaded with load_weights.
        return weights_template.format(epoch=epochs)

    def train_gan(self, weights_path: str, steps: int = 200_000, batch_size: int = 16):
        """Start the training of SRGAN.

        Per-step losses are appended to 'training_history/losses.csv'. Every 200 steps, generator weights are
        saved under 'model_weights/gen/' with the step number in the filename. Discriminator weights are saved
        under 'model_weights/disc/', overwriting the same file each time.

        Parameters
        ----------
        weights_path: str
            Path to initialization weights for the generator.
        steps: int
            The number of training steps. At each step, the model is trained on a single batch of data.
            Float values such as 2e5 are accepted and truncated to int.
            Default=200,000
        batch_size: int
            Number of images per batch.
            Default=16
        """
        # range() requires an int; tolerate float step counts like 2e5.
        steps = int(steps)

        # Initialize the generator from the pre-trained weights.
        self.generator.load_weights(weights_path)
        self.data.batch_size = batch_size

        log_path = 'training_history/losses.csv'
        disc_dir_probe = f"model_weights/disc/{self.discriminator.name}_X4_SRGAN.h5"
        gen_dir_probe = f"model_weights/gen/{self.generator.name}_X4_SRGAN-0.h5"
        for path in (log_path, disc_dir_probe, gen_dir_probe):
            self._ensure_parent_dir(path)

        # Open the log once; buffering=1 (line-buffered text mode) flushes every row to disk as it is written.
        with open(log_path, 'w', buffering=1) as log_file:
            log_file.write("step, perc_loss, disc_loss\n")
            for step in range(1, steps + 1):
                lr, hr = self.data.load_batch()
                pl, dl = self._train_step(lr, hr)
                # Scalar tensors -> Python floats, so the log contains numbers instead of tensor reprs.
                pl, dl = float(pl), float(dl)
                print(f"Step #{step}:\n Generator loss = {pl}\n Discriminator loss = {dl}\n")
                log_file.write(f"{step}, {pl}, {dl}\n")

                # Save weights every 200 steps.
                if step % 200 == 0:
                    self.discriminator.save_weights(f"model_weights/disc/{self.discriminator.name}_X4_SRGAN.h5")
                    self.generator.save_weights(f"model_weights/gen/{self.generator.name}_X4_SRGAN-{step}.h5")
                    print("#############\nWeights Saved\n#############\n")

    @tf.function
    def _train_step(self, lr, hr):
        """SRGAN training step.

        Takes an LR and an HR image batch and updates both networks once. It returns the computed perceptual
        loss (generator objective) and discriminator loss as scalar tensors.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)

            # Forward pass
            sr = self.generator(lr, training=True)
            hr_output = self.discriminator(hr, training=True)
            sr_output = self.discriminator(sr, training=True)

            # Perceptual loss = VGG content loss + 1e-3 * adversarial loss.
            con_loss = self._content_loss(hr, sr)
            gen_loss = self._adversarial_loss(sr_output)
            perc_loss = con_loss + 0.001 * gen_loss
            disc_loss = self._discriminator_loss(hr_output, sr_output)

        # Each tape only differentiates w.r.t. its own network's variables.
        gradients_of_generator = gen_tape.gradient(perc_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.gen_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.disc_optimizer.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables))
        return perc_loss, disc_loss

    # Loss functions

    def _vgg_54_model(self):
        """Creates a VGG19 model used in calculating the content loss.

        Its output is the 4th convolution before the 5th pooling layer ('block5_conv4').
        """
        vgg = VGG19(input_shape=self.hr_shape, include_top=False)
        # Select the layer by name instead of the positional index 20 (the same layer in Keras' VGG19).
        vgg = Model(vgg.input, vgg.get_layer('block5_conv4').output)
        return vgg

    @tf.function
    def _content_loss(self, original, generated):
        """MSE between VGG19 'block5_conv4' features of the HR (original) and SR (generated) images.

        NOTE(review): VGG19's preprocess_input expects RGB values in [0, 255]; confirm the loader and
        generator produce that range.
        """
        generated_preprocessed = preprocess_input(generated)
        original_preprocessed = preprocess_input(original)
        # Dividing by 12.75 rescales the VGG features so this loss is on a scale comparable to a pixel MSE.
        sr_features = self.vgg(generated_preprocessed) / 12.75
        hr_features = self.vgg(original_preprocessed) / 12.75
        return self._mean_squared_error(hr_features, sr_features)

    def _adversarial_loss(self, sr_out):
        """Generator adversarial loss: BCE pushing the discriminator's SR predictions toward 'real' (1)."""
        return self._binary_cross_entropy(tf.ones_like(sr_out), sr_out)

    def _discriminator_loss(self, hr_out, sr_out):
        """Discriminator loss: BCE with HR images labelled real (1) and SR images labelled fake (0)."""
        hr_loss = self._binary_cross_entropy(tf.ones_like(hr_out), hr_out)
        sr_loss = self._binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
        return hr_loss + sr_loss
if __name__ == '__main__':
    data_path = r"datasets/preprocessed_data/"
    generator = edsr()
    discriminator = srgan_discriminator()
    gan = SrganTrainer(generator,
                       discriminator,
                       data_path=data_path,
                       load_all_data=False)
    # Stage 1: pre-train the generator on its own; train_generator returns the weights path to start from.
    # (The original called the non-existent `trainGenerator`.)
    weights_path = gan.train_generator(epochs=150,
                                       batch_size=32)
    # Stage 2: adversarial SRGAN training. steps must be an int for range(), so use 200_000 rather than 2e5.
    gan.train_gan(weights_path,
                  steps=200_000,
                  batch_size=16)