forked from microsoft/robustlearn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
99 lines (76 loc) · 3.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import time
from alg.opt import *
from alg import alg, modelopera
from utils.util import set_random_seed, get_args, print_row, print_args, train_valid_target_eval_names, alg_loss_dict, print_environ
from datautil.getdataloader_single import get_act_dataloader
def main(args):
    """Run the Diversify training loop and report target-domain accuracy.

    Each round performs three phases on the same model:
      1. Feature update            (optimizer over all nets,  ``update_a``)
      2. Latent-domain characterization (adversarial nets,    ``update_d``)
      3. Domain-invariant feature learning (classifier nets,  ``update``)

    Args:
        args: parsed namespace from ``get_args()`` — must provide ``seed``,
            ``latent_domain_num``, ``algorithm``, ``max_epoch``,
            ``local_epoch``; ``batch_size`` is overwritten below.

    Side effects: prints progress tables and the final target accuracy;
    mutates ``args.batch_size``.
    """
    s = print_args(args, [])
    set_random_seed(args.seed)
    print_environ()
    print(s)

    # Batch size scales with the number of latent domains; a smaller
    # multiplier is used past 5 domains (presumably to bound memory use
    # — TODO confirm against upstream).
    if args.latent_domain_num < 6:
        args.batch_size = 32 * args.latent_domain_num
    else:
        args.batch_size = 16 * args.latent_domain_num

    train_loader, train_loader_noshuffle, valid_loader, target_loader, _, _, _ = get_act_dataloader(
        args)

    best_valid_acc, target_acc = 0, 0

    algorithm_class = alg.get_algorithm_class(args.algorithm)
    algorithm = algorithm_class(args).cuda()  # requires a CUDA device
    algorithm.train()

    # Three optimizers over different parameter groups of the same model.
    optd = get_optimizer(algorithm, args, nettype='Diversify-adv')
    opt = get_optimizer(algorithm, args, nettype='Diversify-cls')
    opta = get_optimizer(algorithm, args, nettype='Diversify-all')

    # `round_idx` (was `round`) avoids shadowing the builtin round().
    for round_idx in range(args.max_epoch):
        print(f'\n========ROUND {round_idx}========')

        print('====Feature update====')
        loss_list = ['class']
        print_row(['epoch'] + [item + '_loss' for item in loss_list], colwidth=15)
        for step in range(args.local_epoch):
            # NOTE(review): assumes train_loader yields at least one batch;
            # an empty loader would leave loss_result_dict unbound.
            for data in train_loader:
                loss_result_dict = algorithm.update_a(data, opta)
            # Report only the last batch's losses for this local epoch.
            print_row([step] + [loss_result_dict[item]
                                for item in loss_list], colwidth=15)

        print('====Latent domain characterization====')
        loss_list = ['total', 'dis', 'ent']
        print_row(['epoch'] + [item + '_loss' for item in loss_list], colwidth=15)
        for step in range(args.local_epoch):
            for data in train_loader:
                loss_result_dict = algorithm.update_d(data, optd)
            print_row([step] + [loss_result_dict[item]
                                for item in loss_list], colwidth=15)

        # Re-assign pseudo domain labels from the updated adversarial nets.
        algorithm.set_dlabel(train_loader)

        print('====Domain-invariant feature learning====')
        loss_list = alg_loss_dict(args)
        eval_dict = train_valid_target_eval_names(args)
        print_key = ['epoch']
        print_key.extend([item + '_loss' for item in loss_list])
        print_key.extend([item + '_acc' for item in eval_dict.keys()])
        print_key.append('total_cost_time')
        print_row(print_key, colwidth=15)

        sss = time.time()
        for step in range(args.local_epoch):
            for data in train_loader:
                step_vals = algorithm.update(data, opt)

            results = {
                'epoch': step,
            }
            results['train_acc'] = modelopera.accuracy(
                algorithm, train_loader_noshuffle, None)
            acc = modelopera.accuracy(algorithm, valid_loader, None)
            results['valid_acc'] = acc
            acc = modelopera.accuracy(algorithm, target_loader, None)
            results['target_acc'] = acc
            for key in loss_list:
                results[key + '_loss'] = step_vals[key]

            # Model selection: track target accuracy at the best-validation
            # checkpoint (standard domain-generalization protocol).
            if results['valid_acc'] > best_valid_acc:
                best_valid_acc = results['valid_acc']
                target_acc = results['target_acc']
            results['total_cost_time'] = time.time() - sss
            print_row([results[key] for key in print_key], colwidth=15)

    print(f'Target acc: {target_acc:.4f}')
if __name__ == '__main__':
    # Parse CLI arguments and hand off to the training entry point.
    main(get_args())