-
Notifications
You must be signed in to change notification settings - Fork 6
/
train_smp
executable file
·93 lines (77 loc) · 3.48 KB
/
train_smp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python
import os
import datasets
from torch.utils.data import DataLoader
import torch
import segmentation_models_pytorch as smp
from argparse import ArgumentParser
def _resolve_attr(namespace, dotted_name):
    """Return the attribute of *namespace* named by *dotted_name*.

    Dotted paths such as ``'sub.Name'`` are followed one attribute at a
    time, so the string is only ever used as a name and is never executed
    as Python code.

    Raises:
        AttributeError: if any component of the path does not exist.
    """
    target = namespace
    for part in dotted_name.split('.'):
        target = getattr(target, part)
    return target


def main():
    """Train a segmentation model, checkpointing on validation IoU improvement.

    Every setting is read from the command line. The dataset class is looked
    up by name in the ``datasets`` module and the architecture by name in
    ``segmentation_models_pytorch``. After each epoch the validation
    ``iou_score`` is compared with the best value so far; on improvement the
    whole model object is written with ``torch.save`` to a ``.pth`` file in
    the current directory whose name encodes the run settings.

    Invalid ``--img_size``, ``--dataset`` or ``--model`` values terminate the
    program through ``ArgumentParser.error`` before any training starts.
    """
    parser = ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--dataset', type=str, default='Rellis3DImages')
    parser.add_argument('--model', type=str, default='Unet')
    parser.add_argument('--encoder', type=str, default='resnet34')
    parser.add_argument('--encoder_weights', type=str, default='imagenet')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--img_size', nargs='+', type=int, default=(1184, 1920))
    parser.add_argument('--n_epochs', type=int, default=100)
    parser.add_argument('--device', type=str, default='cuda')
    # os.cpu_count() may return None, which is not a valid worker count;
    # fall back to 0 (load data in the main process) in that case.
    parser.add_argument('--n_workers', type=int, default=os.cpu_count() or 0)
    parser.add_argument('--num_samples', type=int, default=None)
    args = parser.parse_args()

    # nargs='+' accepts any number of values, but the checkpoint name below
    # indexes img_size[0] and img_size[1]. Reject bad input now instead of
    # crashing after the first epoch has already been trained.
    if len(args.img_size) != 2:
        parser.error('--img_size expects exactly two integers, got %d value(s)'
                     % len(args.img_size))

    # Look the dataset class up by name (no eval of command-line text).
    try:
        Dataset = _resolve_attr(datasets, args.dataset)
    except AttributeError:
        parser.error('unknown dataset %r' % args.dataset)

    train_dataset = Dataset(crop_size=args.img_size, split='train',
                            num_samples=args.num_samples)
    # Validation always uses this fixed crop size and batch size 1,
    # independent of --img_size and --batch_size.
    valid_dataset = Dataset(crop_size=(1184, 1920), split='val',
                            num_samples=args.num_samples)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.n_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=1,
                              shuffle=False, num_workers=args.n_workers)

    # create segmentation model with pretrained encoder
    try:
        architecture = _resolve_attr(smp, args.model)
    except AttributeError:
        parser.error('unknown model %r' % args.model)

    n_classes = len(train_dataset.CLASSES)
    model = architecture(
        encoder_name=args.encoder,
        encoder_weights=args.encoder_weights,
        in_channels=3,
        classes=n_classes,
        activation='sigmoid' if n_classes == 1 else 'softmax2d',
    )
    model = model.train()

    # Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
    # NOTE(review): the model above is already built with an output
    # activation and the loss is configured with activation='softmax2d' too.
    # Confirm against the installed smp version that predictions are not
    # activated twice (and that this is intended for the 1-class case).
    loss_fn = smp.utils.losses.DiceLoss(activation='softmax2d')

    # IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index
    metrics = [smp.utils.metrics.IoU(threshold=0.5)]

    # A single parameter group holds every model parameter.
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=args.lr)])

    # create epoch runners
    # it is a simple loop of iterating over dataloader's samples
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss_fn,
        metrics=metrics,
        optimizer=optimizer,
        device=args.device,
        verbose=True,
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss_fn,
        metrics=metrics,
        device=args.device,
        verbose=True,
    )

    # train model
    max_score = 0
    for i in range(0, args.n_epochs):
        print('\nEpoch: {}'.format(i))
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # Keep a checkpoint every time the validation IoU reaches a new best.
        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            best_model_name = './%s_%s_%dx%d_lr%g_bs%d_epoch%d_%s_iou_%.2f.pth' %\
                              (args.model, args.encoder, args.img_size[0], args.img_size[1],
                               args.lr, args.batch_size, i, args.dataset, max_score)
            torch.save(model, best_model_name)
            print('Model %s saved!' % best_model_name)

        # One-off step decay: from epoch 25 on, train with a tenth of the
        # initial learning rate. Group 0 is the only group, so this applies
        # to the whole model.
        if i == 25:
            optimizer.param_groups[0]['lr'] = args.lr / 10.0
            print('Decrease decoder learning rate!')
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()