-
Notifications
You must be signed in to change notification settings - Fork 9
/
trainer.py
208 lines (165 loc) · 8.08 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import argparse
import tensorflow as tf
from tqdm import tqdm
from augmentor import Augmentor
from batchizer import Batchizer
from logger import Logger
from models import Simple, NASNET, Inception, GAP, YOLO
from utils import *
def create_model(session, m_type, m_name, logger):
    """
    create a model of the requested type, then reload its last saved
    checkpoint or, when none exists, initialize fresh parameters
    :param session: tf.Session used to restore / initialize the variables
    :param m_type: model type, one of "simple", "YOLO", "GAP", "NAS", "INC"
    :param m_name: model name (equal to folder name)
    :param logger: logger
    :return: the created model
    :raises ValueError: if m_type is not one of the supported model types
    """
    # An if/elif chain (rather than a dict of classes) keeps the lookup lazy:
    # only the class of the requested type is touched.
    if m_type == "simple":
        model = Simple(m_name, config, logger)
    elif m_type == "YOLO":
        model = YOLO(m_name, config, logger)
    elif m_type == 'GAP':
        model = GAP(m_name, config, logger)
    elif m_type == 'NAS':
        model = NASNET(m_name, config, logger)
    elif m_type == 'INC':
        model = Inception(m_name, config, logger)
    else:
        raise ValueError(
            "unknown model type {0!r}; expected one of: "
            "simple, YOLO, GAP, NAS, INC".format(m_type))

    # Resume from the latest checkpoint in the model directory when there is one
    ckpt = tf.train.get_checkpoint_state(model.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        logger.log('Reloading model parameters..')
        model.restore(session, ckpt.model_checkpoint_path)
    else:
        logger.log('Created new model parameters..')
        session.run(tf.global_variables_initializer())

    return model
def print_predictions(result, logger):
    """
    log the ground truth next to the prediction for every collected sample
    :param result: sequence of entries indexed as [truth, prediction, image path]
    :param logger: logger
    :return: None
    """
    emit = logger.log

    # header lines are written once, before the per-sample rows
    emit("########### Print Predictions ################")
    emit("label: [\tx\t y\t w\t h\t a]")

    for entry in result:
        truth = entry[0]
        predicted = entry[1]
        path = entry[2]

        emit("Path: " + path)

        # only the first three components of each vector are printed
        truth_line = "truth: {0:2.2f} {1:2.2f} {2:2.2f}".format(
            truth[0], truth[1], truth[2])
        emit(truth_line)

        pred_line = "pred : {0:2.2f} {1:2.2f} {2:2.2f}\n".format(
            predicted[0], predicted[1], predicted[2])
        emit(pred_line)
def _train_round(sess, model, batches, lr, log_writer):
    """
    train on batches until the global step reaches a validation boundary
    must be called while sess is the default session (global_step.eval())
    :param sess: tf.Session
    :param model: model exposing train() and global_step
    :param batches: iterable yielding (x, y, _) tuples; x may be None
    :param lr: learning rate used for every batch of this round
    :param log_writer: summary writer receiving the per-batch summaries
    :return: list with the loss of every batch trained in this round
    """
    losses = []
    with tqdm(total=config["validate_every"], unit="batch") as t:
        for x, y, _ in batches:
            # skip batches for which no input was produced
            if x is None:
                continue

            batch_loss, summary = model.train(sess, x, y, config["keep_prob"], lr)
            losses.append(batch_loss)
            t.set_description_str("batch_loss:{0:2.8f}, ".format(batch_loss))

            # read the step once; it is used as summary index and stop condition
            step = model.global_step.eval()
            log_writer.add_summary(summary, step)
            t.update(1)

            if step % config["validate_every"] == 0:
                break
    return losses


def _validation_round(sess, model, batches):
    """
    evaluate config["validate_for"] batches and sample a few predictions
    :param sess: tf.Session
    :param model: model exposing eval()
    :param batches: iterable yielding (x, y, img) tuples; x may be None
    :return: (list of batch losses, list of [truth, prediction, image] samples)
    """
    losses = []
    samples = []
    evaluated = 0
    with tqdm(total=config["validate_for"], unit="batch") as t:
        for x, y, img in batches:
            # skipped batches do not count towards validate_for
            if x is None:
                continue

            batch_loss, _, pred = model.eval(sess, x, y)
            losses.append(batch_loss)
            t.set_description_str("batch_loss:{0:2.8f}".format(batch_loss))
            evaluated += 1

            # select a random image from current batch and add it for visualization
            # do it with a little chance! to reduce the size of output
            if np.random.rand() > 0.95:
                r = np.random.randint(0, high=len(x))
                samples.append([y[r], pred[r], img[r]])

            t.update(1)
            if evaluated == config["validate_for"]:
                break
    return losses, samples


def main(model_type, model_name, logger):
    """
    train model until the maximum number of steps reached
    alternates a training round and a validation round, logs the mean losses,
    and saves periodic and best-validation-loss checkpoints
    :param model_type: model type
    :param model_name: model name
    :param logger: logger
    :return: None
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Create a new model or reload existing checkpoint
            model = create_model(sess, model_type, model_name, logger)

            # Create a log writer object
            log_writer = tf.summary.FileWriter(model.model_dir, graph=sess.graph)
            try:
                # two savers:
                # 1. periodic checkpoints (keeps the 3 most recent)
                saver = tf.train.Saver(max_to_keep=3)
                # 2. the best loss (keeps only the most recent one)
                best_saver = tf.train.Saver(max_to_keep=1)

                # CSV files for train and validation set
                root_path = "data/"
                train_csv = "train_data.csv"
                valid_csv = "valid_data.csv"
                train_path = os.path.join(root_path, train_csv)
                valid_path = os.path.join(root_path, valid_csv)

                # initial batchizer
                train_batchizer = Batchizer(train_path, config["batch_size"])
                valid_batchizer = Batchizer(valid_path, config["batch_size"])

                # init augmentor only once for both train and validation set
                ag = Augmentor('data/noisy_videos/', config)

                train_batches = train_batchizer.batches(ag, config["output_dim"], num_c=config["input_channel"], zero_mean=True)
                valid_batches = valid_batchizer.batches(ag, config["output_dim"], num_c=config["input_channel"], zero_mean=True)

                while model.global_step.eval() < config["total_steps"]:
                    # pick the learning rate for the current step from the config
                    # list, clamping to its last entry once the list is exhausted
                    lr_idx = int(model.global_step.eval() / config["decay_step"])
                    lr_idx = min(lr_idx, len(config["learning_rate"]) - 1)
                    lr = config["learning_rate"][lr_idx]

                    # train phase
                    train_loss = _train_round(sess, model, train_batches, lr, log_writer)

                    # validation phase
                    valid_loss, pred_result = _validation_round(sess, model, valid_batches)

                    # print the results of validation dataset
                    print_predictions(pred_result, logger)

                    train_mean_loss = np.mean(train_loss)
                    valid_mean_loss = np.mean(valid_loss)

                    # no training happens until the next round, so one read of
                    # the step serves the log line, the save check and the summary
                    step = model.global_step.eval()
                    logger.log(
                        'Step:{0:6}: avg train loss:{1:2.8f}, avg validation loss:{2:2.8f}'.format(step, train_mean_loss, valid_mean_loss))

                    # save a checkpoint with the best loss value
                    if valid_mean_loss < logger.best_loss:
                        logger.save_best_loss(valid_mean_loss)
                        best_path = os.path.join(model.model_dir, "best_loss/")
                        check_dir(best_path)
                        save_path = best_saver.save(sess, best_path, global_step=model.global_step)
                        logger.log("model saved with best loss {0} at {1}".format(valid_mean_loss, save_path))

                    # save_every should be a multiple of validate_every, otherwise
                    # this check can be skipped over
                    if step % config["save_every"] == 0:
                        save_path = saver.save(sess, model.model_dir, global_step=model.global_step)
                        logger.log("model saved at {}".format(save_path))

                    summary = tf.Summary()
                    summary.value.add(tag="train_loss", simple_value=train_mean_loss)
                    summary.value.add(tag="valid_loss", simple_value=valid_mean_loss)
                    log_writer.add_summary(summary, step)
            finally:
                # flush pending events and release the event file even when
                # training is interrupted by an exception
                log_writer.close()

            logger.log('Training is done.')
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, create the logger and train.
    class_ = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=class_)

    parser.add_argument('model_name', help="name of saved model (3A4Bh-Ref25)")
    # Keep the choices in sync with the types handled by create_model(), so an
    # unsupported type is rejected here, before the Logger is created.
    parser.add_argument('--model_type', help="type of the model to train", default="INC",
                        choices=["simple", "YOLO", "GAP", "NAS", "INC"])
    parser.add_argument('--model_message', help="briefly explain your model", default="none")

    args = parser.parse_args()

    model_type = args.model_type
    model_name = args.model_name
    model_msg = args.model_message

    logger = Logger(model_type, model_name, model_msg, config, dir="models/")
    logger.log("Start training model...")
    main(model_type, model_name, logger)