forked from PaddlePaddle/PaddleOCR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
distillation_loss.py
456 lines (401 loc) · 15.8 KB
/
distillation_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
else:
loss_dict["loss"] = 0.
for k, value in loss_dict.items():
if k == "loss":
continue
else:
loss_dict["loss"] += value
return loss_dict
class DistillationDMLLoss(DMLLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
act=None,
use_log=False,
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="dml"):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
loss = super().forward(out1[self.dis_head],
out2[self.dis_head])
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationCTCLoss(CTCLoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'ctc' in out, 'multi head has multi out'
loss = super().forward(out['ctc'], batch[:2] + batch[3:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationSARLoss(SARLoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_sar",
**kwargs):
ignore_index = kwargs.get('ignore_index', 92)
super().__init__(ignore_index=ignore_index)
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'sar' in out, 'multi head has multi out'
loss = super().forward(out['sar'], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationDBLoss(DBLoss):
def __init__(self,
model_name_list=[],
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="db",
**kwargs):
super().__init__()
self.model_name_list = model_name_list
self.name = name
self.key = None
def forward(self, predicts, batch):
loss_dict = {}
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss.keys():
if key == "loss":
continue
name = "{}_{}_{}".format(self.name, model_name, key)
loss_dict[name] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDilaDBLoss(DBLoss):
def __init__(self,
model_name_pairs=[],
key=None,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="dila_dbloss"):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
stu_outs = predicts[pair[0]]
tch_outs = predicts[pair[1]]
if self.key is not None:
stu_preds = stu_outs[self.key]
tch_preds = tch_outs[self.key]
stu_shrink_maps = stu_preds[:, 0, :, :]
stu_binary_maps = stu_preds[:, 2, :, :]
# dilation to teacher prediction
dilation_w = np.array([[1, 1], [1, 1]])
th_shrink_maps = tch_preds[:, 0, :, :]
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]):
dilate_maps[i] = cv2.dilate(
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
th_shrink_maps = paddle.to_tensor(dilate_maps)
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
1:]
# calculate the shrink map loss
bce_loss = self.alpha * self.bce_loss(
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
label_shrink_mask)
# k = f"{self.name}_{pair[0]}_{pair[1]}"
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
else:
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
def __init__(self,
num_classes,
model_name_list=[],
key=None,
name="loss_ser"):
super().__init__(num_classes=num_classes)
self.model_name_list = model_name_list
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
return loss_dict
class DistillationLossFromOutput(LossFromOutput):
def __init__(self,
reduction="none",
model_name_list=[],
dist_key=None,
key="loss",
name="loss_re"):
super().__init__(key=key, reduction=reduction)
self.model_name_list = model_name_list
self.name = name
self.dist_key = dist_key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.dist_key is not None:
out = out[self.dist_key]
loss = super().forward(out, batch)
loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
return loss_dict
class DistillationSERDMLLoss(DMLLoss):
"""
"""
def __init__(self,
act="softmax",
use_log=True,
num_classes=7,
model_name_pairs=[],
key=None,
name="loss_dml_ser"):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.name = name
self.num_classes = num_classes
self.model_name_pairs = model_name_pairs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
out1 = out1.reshape([-1, out1.shape[-1]])
out2 = out2.reshape([-1, out2.shape[-1]])
attention_mask = batch[2]
if attention_mask is not None:
active_output = attention_mask.reshape([-1, ]) == 1
out1 = out1[active_output]
out2 = out2[active_output]
loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
out2)
return loss_dict
class DistillationVQADistanceLoss(DistanceLoss):
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
attention_mask = batch[2]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if attention_mask is not None:
max_len = attention_mask.shape[-1]
out1 = out1[:, :max_len]
out2 = out2[:, :max_len]
out1 = out1.reshape([-1, out1.shape[-1]])
out2 = out2.reshape([-1, out2.shape[-1]])
if attention_mask is not None:
active_output = attention_mask.reshape([-1, ]) == 1
out1 = out1[active_output]
out2 = out2[active_output]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}nohu_{}".format(self.name, key,
idx)] = loss[key]
else:
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict