eeil.py
import torch
import warnings
from copy import deepcopy
from argparse import ArgumentParser
from torch.nn import functional as F
from torch.utils.data import DataLoader
from .incremental_learning import Inc_Learning_Appr
from datasets.exemplars_dataset import ExemplarsDataset


class Appr(Inc_Learning_Appr):
    """Class implementing the End-to-End Incremental Learning (EEIL) approach described in
    http://openaccess.thecvf.com/content_ECCV_2018/papers/Francisco_M._Castro_End-to-End_Incremental_Learning_ECCV_2018_paper.pdf
    Original code available at https://github.com/fmcp/EndToEndIncrementalLearning
    Helpful code from https://github.com/arthurdouillard/incremental_learning.pytorch
    """

    def __init__(self, model, device, nepochs=90, lr=0.1, lr_min=1e-6, lr_factor=10, lr_patience=5, clipgrad=10000,
                 momentum=0.9, wd=0.0001, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
                 eval_on_train=False, logger=None, exemplars_dataset=None, lamb=1.0, T=2, lr_finetuning_factor=0.1,
                 nepochs_finetuning=40, noise_grad=False):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        self.model_old = None
        self.lamb = lamb
        self.T = T
        self.lr_finetuning_factor = lr_finetuning_factor
        self.nepochs_finetuning = nepochs_finetuning
        self.noise_grad = noise_grad

        self._train_epoch = 0
        self._finetuning_balanced = None

        # EEIL is expected to be used with exemplars. To use it without exemplars, overwrite the
        # `_get_optimizer` function here with the one from LwF and update the criterion accordingly.
        have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
        if not have_exemplars:
            warnings.warn("Warning: EEIL is expected to use exemplars. Check documentation.")

    @staticmethod
    def exemplars_dataset_class():
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach-specific parameters"""
        parser = ArgumentParser()
        # Added trade-off between the terms of Eq. 1 -- L = L_C + lamb * L_D
        parser.add_argument('--lamb', default=1.0, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Page 6: "Based on our empirical results, we set T to 2 for all our experiments"
        parser.add_argument('--T', default=2.0, type=float, required=False,
                            help='Temperature scaling (default=%(default)s)')
        # "The same reduction is used in the case of fine-tuning, except that the starting rate is 0.01."
        parser.add_argument('--lr-finetuning-factor', default=0.01, type=float, required=False,
                            help='Finetuning learning rate factor (default=%(default)s)')
        # Number of epochs for balanced training
        parser.add_argument('--nepochs-finetuning', default=40, type=int, required=False,
                            help='Number of epochs for balanced training (default=%(default)s)')
        # Whether to add noise to the gradients
        parser.add_argument('--noise-grad', action='store_true',
                            help='Add noise to gradients (default=%(default)s)')
        return parser.parse_known_args(args)
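
    # Illustrative use of `extra_parser` above (the argument values are made up): parsing only
    # consumes the flags defined there and returns everything else for the main parser.
    #   eeil_args, extra = Appr.extra_parser(['--lamb', '0.5', '--noise-grad', '--unknown-flag'])
    #   eeil_args.lamb -> 0.5, eeil_args.noise_grad -> True, extra -> ['--unknown-flag']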

    def _train_unbalanced(self, t, trn_loader, val_loader):
        """Unbalanced training"""
        self._finetuning_balanced = False
        self._train_epoch = 0
        loader = self._get_train_loader(trn_loader, False)
        super().train_loop(t, loader, val_loader)
        return loader
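
    # Note on the finetuning learning rate (illustrative arithmetic using this class's own
    # defaults): `_train_balanced` below multiplies the current `lr` by `lr_finetuning_factor`,
    # so lr=0.1 with lr_finetuning_factor=0.1 starts the balanced phase at 0.01, the starting
    # rate quoted from the paper in `extra_parser`.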

    def _train_balanced(self, t, trn_loader, val_loader):
        """Balanced finetuning"""
        self._finetuning_balanced = True
        self._train_epoch = 0

        orig_lr = self.lr
        self.lr *= self.lr_finetuning_factor
        orig_nepochs = self.nepochs
        self.nepochs = self.nepochs_finetuning

        loader = self._get_train_loader(trn_loader, True)
        super().train_loop(t, loader, val_loader)

        self.lr = orig_lr
        self.nepochs = orig_nepochs
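
    # Composition of the loaders built by `_get_train_loader` below (a sketch of what the code
    # does, assuming the exemplars dataset supports `len()` and `+` concatenation as used here):
    #   unbalanced (balanced=False): all new-task samples + all stored exemplars
    #   balanced   (balanced=True):  a random subset of the new-task samples of size
    #                                len(exemplars_dataset) + all stored exemplars, so old and
    #                                new data appear in roughly equal numbers during finetuning.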

    def _get_train_loader(self, trn_loader, balanced=False):
        """Modify loader to be balanced or unbalanced"""
        exemplars_ds = self.exemplars_dataset
        trn_dataset = trn_loader.dataset
        if balanced:
            indices = torch.randperm(len(trn_dataset))
            trn_dataset = torch.utils.data.Subset(trn_dataset, indices[:len(exemplars_ds)])
        ds = exemplars_ds + trn_dataset
        return DataLoader(ds, batch_size=trn_loader.batch_size,
                          shuffle=True,
                          num_workers=trn_loader.num_workers,
                          pin_memory=trn_loader.pin_memory)
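
    # The noise added by `_noise_grad` below is annealed with the epoch counter passed as
    # `iteration`. Illustrative values for the defaults eta=0.3, gamma=0.55:
    #   iteration 0  -> 0.3 / 1**0.55  = 0.300
    #   iteration 9  -> 0.3 / 10**0.55 ≈ 0.085
    #   iteration 89 -> 0.3 / 90**0.55 ≈ 0.025
    # Note that this value directly scales unit-normal noise (i.e. it acts as a standard
    # deviation), despite being named `variance` in the code.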

    def _noise_grad(self, parameters, iteration, eta=0.3, gamma=0.55):
        """Add noise to the gradients"""
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        variance = eta / ((1 + iteration) ** gamma)
        for p in parameters:
            p.grad.add_(torch.randn(p.grad.shape, device=p.grad.device) * variance)

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""
        if t == 0:  # First task is simple training
            super().train_loop(t, trn_loader, val_loader)
            loader = trn_loader
        else:
            # Page 4: "4. Incremental Learning" -- Only modification is that instead of preparing exemplars before
            # training, we do it online using the stored old model.
            # Training process (new + old) - unbalanced training
            loader = self._train_unbalanced(t, trn_loader, val_loader)
            # Balanced fine-tuning (new + old)
            self._train_balanced(t, trn_loader, val_loader)

        # After task training: update exemplars
        self.exemplars_dataset.collect_exemplars(self.model, loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""
        # Save old model to extract features later
        self.model_old = deepcopy(self.model)
        self.model_old.eval()
        self.model_old.freeze_all()

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch"""
        self.model.train()
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            images = images.to(self.device)
            # Forward old model
            outputs_old = None
            if t > 0:
                outputs_old = self.model_old(images)
            # Forward current model
            outputs = self.model(images)
            loss = self.criterion(t, outputs, targets.to(self.device), outputs_old)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            # Page 8: "We apply L2-regularization and random noise [21] (with parameters eta = 0.3, gamma = 0.55)
            # on the gradients to minimize overfitting"
            # https://github.com/fmcp/EndToEndIncrementalLearning/blob/master/cnn_train_dag_exemplars.m#L367
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            if self.noise_grad:
                self._noise_grad(self.model.parameters(), self._train_epoch)
            self.optimizer.step()
        self._train_epoch += 1
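
    # The loss computed by `criterion` below corresponds to Eq. 1 of the paper,
    #     L = L_C + lamb * L_D
    # where L_C is the cross-entropy over the concatenated heads and L_D sums the binary
    # cross-entropy between the temperature-softened (T) outputs of the current model and of
    # the frozen old model, taken over the heads up to `last_head_idx`.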

    def criterion(self, t, outputs, targets, outputs_old=None):
        """Returns the loss value"""
        # Classification loss for new classes
        loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        # Distillation loss
        if t > 0 and outputs_old:
            # take into account current head when doing balanced finetuning
            last_head_idx = t if self._finetuning_balanced else (t - 1)
            for i in range(last_head_idx):
                loss += self.lamb * F.binary_cross_entropy(F.softmax(outputs[i] / self.T, dim=1),
                                                           F.softmax(outputs_old[i] / self.T, dim=1))
        return loss
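
# Illustrative per-task call order (a sketch of how the surrounding framework is assumed to
# drive this class; `model`, `device`, `exemplars_dataset` and `loaders` are placeholders):
#   appr = Appr(model, device, exemplars_dataset=exemplars_dataset, noise_grad=True)
#   for t, (trn_loader, val_loader) in enumerate(loaders):
#       appr.train_loop(t, trn_loader, val_loader)    # unbalanced training + balanced finetuning
#       appr.post_train_process(t, trn_loader)        # freeze a copy of the model as the teacher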