forked from NAG-DevOps/openiss-reid-tfk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tripletloss.py
273 lines (205 loc) · 10.9 KB
/
tripletloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
'''
Most of this code is adapted from [1].
Minor modifications were made by Haotao Lai ([email protected]).
To use this code to compute the batch-hard or batch-all triplet loss, the
PK sampling strategy published in [2] must be applied: each batch contains
P identities with K images each, and the input embeddings (aka feature maps)
must be ordered so that all images of the same identity are contiguous.
[1] https://github.com/omoindrot/tensorflow-triplet-loss
[2] "In Defense of the Triplet Loss for Person Re-Identification" https://arxiv.org/abs/1703.07737
'''
import tensorflow as tf
import numpy as np
import keras.backend as K
def triplet_loss(num_pid_per_batch, num_img_per_id, margin, type='hard',
                 squared=True):
    """Return a Keras loss function implementing a batch triplet loss.

    The returned function has the Keras signature ``loss(y_true, embeddings)``
    but ignores ``y_true``: identity labels are implied by the PK batch
    layout. The batch must hold ``num_pid_per_batch * num_img_per_id``
    embeddings, and each identity's ``num_img_per_id`` embeddings must be
    contiguous (labels 1, .., 1, 2, .., 2, ..., P, .., P).

    Args:
        num_pid_per_batch: P, the number of identities in each batch.
        num_img_per_id: K, the number of images per identity in each batch.
        margin: margin added to ``d(a, p) - d(a, n)`` before clamping at 0.
        type: ``'hard'`` for the batch-hard loss, ``'all'`` for the
            batch-all loss.
        squared: if True (the default, matching the previously hard-coded
            behaviour) use squared Euclidean distances; otherwise use plain
            Euclidean distances.

    Returns:
        ``batch_hard_triplet_loss`` or ``batch_all_triplet_loss``. Each maps
        ``(y_true, embeddings)`` to a scalar loss tensor.

    Raises:
        ValueError: if ``type`` is not ``'hard'`` or ``'all'``, or if P or K
            is smaller than 1.
    """
    if type not in ('hard', 'all'):
        # Kept on one logical line: the former backslash continuation inside
        # the literal leaked the source indentation into the message.
        raise ValueError(
            'unsupported triplet type: {}, should be one of (hard, all)'
            .format(type))
    if num_pid_per_batch < 1 or num_img_per_id < 1:
        raise ValueError(
            'num_pid_per_batch and num_img_per_id must be >= 1, '
            'got {} and {}'.format(num_pid_per_batch, num_img_per_id))
    p = num_pid_per_batch
    k = num_img_per_id

    # Labels follow the fixed batch layout [1]*k + [2]*k + ... + [p]*k.
    # They never change, so a constant is used instead of a trainable
    # variable.
    labels = tf.constant(
        np.repeat(np.arange(1, p + 1, dtype=np.int32), k), dtype=tf.int32)

    def _pairwise_distances(embeddings, squared=False):
        """Compute the matrix of distances between all pairs of embeddings.

        Args:
            embeddings: tensor of shape (batch_size, embed_dim).
            squared: if True return squared Euclidean distances, otherwise
                Euclidean distances.

        Returns:
            Tensor of shape (batch_size, batch_size).
        """
        # shape (batch_size, batch_size)
        dot_product = tf.matmul(embeddings, embeddings, transpose_b=True)
        # Squared L2 norms are the diagonal of the Gram matrix; reusing them
        # makes the diagonal of the result exactly 0.
        square_norm = tf.linalg.diag_part(dot_product)
        # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
        distances = (tf.expand_dims(square_norm, 1) - 2.0 * dot_product
                     + tf.expand_dims(square_norm, 0))
        # Rounding can produce tiny negatives; clamp them to 0.
        distances = tf.maximum(distances, 0.0)
        if not squared:
            # The gradient of sqrt is infinite at 0 (e.g. on the diagonal).
            # Add an epsilon there, then zero those entries again afterwards.
            zero_mask = tf.cast(tf.equal(distances, 0.0), distances.dtype)
            distances = tf.sqrt(distances + zero_mask * 1e-16)
            distances = distances * (1.0 - zero_mask)
        return distances

    def _get_anchor_positive_triplet_mask():
        """Return a bool (batch_size, batch_size) mask.

        mask[a, p] is True iff a != p and labels[a] == labels[p].
        """
        indices_not_equal = tf.logical_not(
            tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool))
        # Broadcast (1, batch_size) against (batch_size, 1).
        labels_equal = tf.equal(tf.expand_dims(labels, 0),
                                tf.expand_dims(labels, 1))
        return tf.logical_and(indices_not_equal, labels_equal)

    def _get_anchor_negative_triplet_mask():
        """Return a bool (batch_size, batch_size) mask.

        mask[a, n] is True iff labels[a] != labels[n].
        """
        labels_equal = tf.equal(tf.expand_dims(labels, 0),
                                tf.expand_dims(labels, 1))
        return tf.logical_not(labels_equal)

    def _get_triplet_mask():
        """Return a bool (batch_size, batch_size, batch_size) mask.

        mask[a, p, n] is True iff a, p, n are distinct, labels[a] == labels[p]
        and labels[a] != labels[n].
        """
        # labels[a] != labels[n] already implies a != n, and together with
        # labels[a] == labels[p] it also implies p != n. Combining the two
        # pair masks is therefore sufficient.
        return tf.logical_and(
            tf.expand_dims(_get_anchor_positive_triplet_mask(), 2),
            tf.expand_dims(_get_anchor_negative_triplet_mask(), 1))

    def batch_all_triplet_loss(y_true, embeddings):
        """Batch-all triplet loss.

        Averages the loss over all valid triplets with a positive loss.

        Args:
            y_true: ignored; labels come from the fixed batch layout.
            embeddings: tensor of shape (batch_size, embed_dim).

        Returns:
            Scalar loss tensor.
        """
        pairwise_dist = _pairwise_distances(embeddings, squared=squared)
        # (batch_size, batch_size, 1) - (batch_size, 1, batch_size) broadcasts
        # to loss[a, p, n] = d(a, p) - d(a, n) + margin.
        loss = (tf.expand_dims(pairwise_dist, 2)
                - tf.expand_dims(pairwise_dist, 1) + margin)
        # Zero out invalid triplets, then drop easy (negative-loss) ones.
        mask = tf.cast(_get_triplet_mask(), loss.dtype)
        loss = tf.maximum(mask * loss, 0.0)
        num_positive_triplets = tf.reduce_sum(
            tf.cast(tf.greater(loss, 1e-16), loss.dtype))
        # The epsilon avoids 0 / 0 when no triplet has a positive loss.
        return tf.reduce_sum(loss) / (num_positive_triplets + 1e-16)

    def batch_hard_triplet_loss(y_true, embeddings):
        """Batch-hard triplet loss.

        For each anchor, uses its hardest positive and hardest negative.

        Args:
            y_true: ignored; labels come from the fixed batch layout.
            embeddings: tensor of shape (batch_size, embed_dim).

        Returns:
            Scalar loss tensor (mean over anchors).
        """
        pairwise_dist = _pairwise_distances(embeddings, squared=squared)

        # Hardest positive: largest distance among valid (a, p) pairs.
        # Invalid pairs are zeroed so they never win the max.
        mask_anchor_positive = tf.cast(
            _get_anchor_positive_triplet_mask(), pairwise_dist.dtype)
        # shape (batch_size, 1)
        hardest_positive_dist = tf.reduce_max(
            mask_anchor_positive * pairwise_dist, axis=1, keepdims=True)
        tf.summary.scalar("hardest_positive_dist",
                          tf.reduce_mean(hardest_positive_dist))

        # Hardest negative: smallest distance among valid (a, n) pairs.
        # Invalid pairs get the row maximum added so they never win the min.
        mask_anchor_negative = tf.cast(
            _get_anchor_negative_triplet_mask(), pairwise_dist.dtype)
        max_anchor_negative_dist = tf.reduce_max(
            pairwise_dist, axis=1, keepdims=True)
        anchor_negative_dist = (pairwise_dist + max_anchor_negative_dist
                                * (1.0 - mask_anchor_negative))
        # shape (batch_size, 1)
        hardest_negative_dist = tf.reduce_min(
            anchor_negative_dist, axis=1, keepdims=True)
        tf.summary.scalar("hardest_negative_dist",
                          tf.reduce_mean(hardest_negative_dist))

        loss = tf.maximum(
            hardest_positive_dist - hardest_negative_dist + margin, 0.0)
        return tf.reduce_mean(loss)

    # `type` was validated above, so only two cases remain.
    if type == 'hard':
        return batch_hard_triplet_loss
    return batch_all_triplet_loss