forked from CoinCheung/pytorch-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
focal_loss.py
259 lines (224 loc) · 8.16 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
##
# version 1: use torch.autograd
class FocalLossV1(nn.Module):
    '''
    Sigmoid focal loss built from plain torch ops; gradients come from autograd.

    Args:
        alpha: weight of the positive term; ``1 - alpha`` weights the negative term.
        gamma: focusing exponent applied to ``|label - sigmoid(logits)|``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            element-wise loss unreduced.
    '''
    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean',):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        # Not used by forward(); kept so code that accesses `.crit` keeps working.
        self.crit = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, label):
        '''
        Usage is same as nn.BCEWithLogits:
            >>> criteria = FocalLossV1()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)

        Returns the element-wise loss
            -|label - p|^gamma * (alpha * label * log(p)
                                  + (1 - alpha) * (1 - label) * log(1 - p))
        with p = sigmoid(logits), reduced according to ``self.reduction``.
        '''
        probs = torch.sigmoid(logits)
        # Negative modulating factor: the log terms below are <= 0, so the
        # product is a non-negative loss.
        coeff = torch.abs(label - probs).pow(self.gamma).neg()
        # log(p) and log(1 - p) computed stably for any logit magnitude;
        # log(1 - sigmoid(x)) == logsigmoid(-x).
        log_probs = F.logsigmoid(logits)
        log_1_probs = F.logsigmoid(-logits)
        loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs
        loss = loss * coeff
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
##
# version 2: manually derived gradient computation
class FocalSigmoidLossFuncV2(torch.autograd.Function):
    '''
    compute backward directly for better numeric stability

    forward returns the element-wise loss ``ce * coeff`` where, with
    p = sigmoid(logits):
        coeff = -|label - p|^gamma
        ce    = alpha * label * log(p) + (1 - alpha) * (1 - label) * log(1 - p)
    backward returns the analytic gradient w.r.t. ``logits`` only.
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, label, alpha, gamma):
        probs = torch.sigmoid(logits)
        coeff = (label - probs).abs_().pow_(gamma).neg_()
        # log(p) and log(1 - p) computed stably; log(1 - sigmoid(x)) == logsigmoid(-x)
        log_probs = F.logsigmoid(logits)
        log_1_probs = F.logsigmoid(-logits)
        ce_term1 = log_probs.mul_(label).mul_(alpha)
        ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)
        ce = ce_term1.add_(ce_term2)
        loss = ce * coeff
        # save_for_backward (rather than raw ctx attributes) lets autograd detect
        # in-place modification of these tensors before backward runs.
        ctx.save_for_backward(coeff, probs, ce, label)
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss

        d(loss)/d(logits) = d(coeff)/d(logits) * ce + d(ce)/d(logits) * coeff

        All ops are out-of-place so the saved tensors stay intact and
        backward can be run more than once (e.g. with retain_graph=True).
        '''
        coeff, probs, ce, label = ctx.saved_tensors
        alpha, gamma = ctx.alpha, ctx.gamma
        # |d(coeff)/d(logits)| = gamma * |label - p|^(gamma - 1) * p * (1 - p);
        # its sign is sign(label - p).
        # NOTE: for gamma < 1 the pow term is infinite where label == p.
        d_coeff = (label - probs).abs().pow(gamma - 1.) * gamma
        d_coeff = d_coeff * probs * (1. - probs)
        d_coeff = torch.where(label < probs, -d_coeff, d_coeff)
        term1 = d_coeff * ce
        # d(ce)/d(logits) = alpha * label - p * (2 * alpha * label + 1 - label - alpha)
        d_ce = label * alpha - probs * (2. * alpha * label + 1. - label - alpha)
        term2 = d_ce * coeff
        grads = (term1 + term2) * grad_output
        return grads, None, None, None
class FocalLossV2(nn.Module):
    '''
    Sigmoid focal loss whose gradient comes from the hand-written backward
    of FocalSigmoidLossFuncV2.

    Args:
        alpha: weight of the positive term (``1 - alpha`` weights the negative term).
        gamma: focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            element-wise loss unreduced.
    '''
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLossV2, self).__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, label):
        '''
        Same call convention as nn.BCEWithLogits:
            >>> criteria = FocalLossV2()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        elementwise = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            return elementwise.mean()
        if self.reduction == 'sum':
            return elementwise.sum()
        return elementwise
##
# version 3: implemented with cpp/cuda to save memory and accelerate
import focal_cpp # import torch before import cpp extension
class FocalSigmoidLossFuncV3(torch.autograd.Function):
    '''
    use cpp/cuda to accelerate and shrink memory usage

    Both passes are delegated to the ``focal_cpp`` extension; note that its
    functions take ``gamma`` before ``alpha``.
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, labels, alpha, gamma):
        loss = focal_cpp.focalloss_forward(logits, labels, gamma, alpha)
        # save_for_backward (rather than raw ctx attributes) makes autograd raise
        # if logits/labels are modified in place before backward, instead of
        # silently producing wrong gradients.
        ctx.save_for_backward(logits, labels)
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss w.r.t. logits (labels, alpha and gamma
        receive no gradient)
        '''
        logits, labels = ctx.saved_tensors
        # NOTE(review): grad_output can be non-contiguous (e.g. the expanded grad
        # of .mean()); confirm the extension handles arbitrary strides.
        grads = focal_cpp.focalloss_backward(grad_output, logits, labels, ctx.gamma, ctx.alpha)
        return grads, None, None, None
class FocalLossV3(nn.Module):
    '''
    Focal loss backed by the focal_cpp extension (FocalSigmoidLossFuncV3).

    This uses a better formula to compute the gradient, which has better
    numeric stability. It also uses cuda to shrink memory usage and accelerate.

    Args:
        alpha: weight of the positive term (``1 - alpha`` weights the negative term).
        gamma: focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            element-wise loss unreduced.
    '''
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLossV3, self).__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, label):
        '''
        Same call convention as nn.BCEWithLogits:
            >>> criteria = FocalLossV3()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        elementwise = FocalSigmoidLossFuncV3.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            return elementwise.mean()
        if self.reduction == 'sum':
            return elementwise.sum()
        return elementwise
if __name__ == '__main__':
    # Comparison script: train two identically initialised networks side by
    # side, one with FocalLossV2 and one with FocalLossV3, and periodically
    # print how far their weights and losses drift apart.
    import torchvision
    import numpy as np
    import random

    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        # ResNet-18 trunk followed by a 512 -> 3 channel 3x3 conv head; the
        # head output is bilinearly resized back to the input's spatial size.
        def __init__(self):
            super(Model, self).__init__()
            backbone = torchvision.models.resnet18(pretrained=False)
            self.conv1 = backbone.conv1
            self.bn1 = backbone.bn1
            self.maxpool = backbone.maxpool
            self.relu = backbone.relu
            self.layer1 = backbone.layer1
            self.layer2 = backbone.layer2
            self.layer3 = backbone.layer3
            self.layer4 = backbone.layer4
            self.out = nn.Conv2d(512, 3, 3, 1, 1)

        def forward(self, x):
            pipeline = (self.conv1, self.bn1, self.relu, self.maxpool,
                        self.layer1, self.layer2, self.layer3, self.layer4,
                        self.out)
            feat = x
            for stage in pipeline:
                feat = stage(feat)
            return F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)

    def mean_abs_diff(a, b):
        # Mean absolute element-wise difference between two tensors, as a float.
        return (a - b).abs().mean().item()

    def train_step(net, criterion, optimizer, inputs, targets):
        # One SGD step; returns the loss tensor for reporting.
        loss = criterion(net(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss

    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())
    criteria1 = FocalLossV2()
    criteria2 = FocalLossV3()
    for net in (net1, net2):
        net.cuda()
        net.train()
        net.double()
    criteria1.cuda()
    criteria2.cuda()
    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 16
    for it in range(300000):
        inten = torch.randn(bs, 3, 224, 244).cuda().double()
        # soft labels in (0, 1); hard labels from torch.randint(0, 2, ...) can be swapped in
        lbs = torch.randn(bs, 3, 224, 244).sigmoid().cuda().double()
        loss1 = train_step(net1, criteria1, optim1, inten, lbs)
        loss2 = train_step(net2, criteria2, optim2, inten, lbs)
        if (it + 1) % 50 == 0:
            with torch.no_grad():
                print('iter: {}, ================='.format(it + 1))
                print('out.weight: ', mean_abs_diff(net1.out.weight, net2.out.weight))
                print('conv1.weight: ', mean_abs_diff(net1.conv1.weight, net2.conv1.weight))
                print('loss: ', loss1.item() - loss2.item())