import math

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    """Generalized Mean (GeM) pooling with a learnable exponent p."""
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Clamp, raise to the power p, average over the spatial dims, take the p-th root.
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)

    def __repr__(self):
        return '{}(p={:.4f}, eps={})'.format(
            self.__class__.__name__, self.p.data.tolist()[0], self.eps)
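
# A minimal usage sketch (illustrative, not part of the original file): GeM collapses
# a (B, C, H, W) backbone feature map to (B, C, 1, 1), e.g.
#   feats = torch.randn(2, 1280, 7, 7)      # hypothetical feature-map shape
#   GeM()(feats).flatten(1).shape           # -> torch.Size([2, 1280])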


class ArcMarginProduct(nn.Module):
    r"""Implementation of large margin arc distance (ArcFace).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm (scale) of the input feature
        m: additive angular margin, applied as cos(theta + m)
    """
    def __init__(self, in_features, out_features, s=30.0,
                 m=0.50, easy_margin=False, ls_eps=0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold and correction used when theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # ---------------- out_i = phi_i if one_hot_i else cosine_i ----------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output


def criterion(outputs, labels):
    return nn.CrossEntropyLoss()(outputs, labels)
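
# A minimal usage sketch (illustrative, not from the original file): the margin head
# takes embeddings plus integer class labels and returns scaled, margin-adjusted
# logits that feed the cross-entropy criterion above, e.g.
#   head = ArcMarginProduct(512, 23, s=30.0, m=0.50)
#   emb, y = torch.randn(4, 512), torch.randint(0, 23, (4,))
#   loss = criterion(head(emb, y), y)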


class ArcFaceModel(nn.Module):
    """timm backbone + GeM pooling + embedding head + ArcFace margin head."""
    def __init__(self, model_name, pretrained=True, backbone=None):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained,
                                       num_classes=0, in_chans=3)
        if backbone:
            # Optionally warm-start the backbone from a local checkpoint.
            self.model.load_state_dict(torch.load(backbone))
        # Drop the classifier and global pooling so the backbone returns the raw
        # (B, C, H, W) feature map; GeM pooling is applied instead.
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.embedding = nn.Sequential(
            nn.LazyLinear(1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
        )
        self.fc = ArcMarginProduct(
            512,   # embedding size
            23,    # number of output classes
            s=2,
            m=1,
            easy_margin=False,
            ls_eps=0.0,
        )

    def forward(self, images, labels):
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        embedding = self.embedding(pooled_features)
        output = self.fc(embedding, labels)
        return output

    def extract(self, images):
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        embedding = self.embedding(pooled_features)
        return embedding
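

if __name__ == "__main__":
    # A minimal smoke test, assuming an EfficientNet-style timm backbone;
    # "efficientnet_b0" is an illustrative choice, not taken from the original
    # training code. Runs on CPU with random weights and dummy data.
    model = ArcFaceModel("efficientnet_b0", pretrained=False)
    images = torch.randn(2, 3, 224, 224)     # dummy batch of 2 RGB images
    labels = torch.randint(0, 23, (2,))      # dummy labels for 23 classes
    logits = model(images, labels)           # margin-adjusted logits, shape (2, 23)
    loss = criterion(logits, labels)
    embeddings = model.extract(images)       # retrieval embeddings, shape (2, 512)
    print(logits.shape, embeddings.shape, loss.item())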