-
Notifications
You must be signed in to change notification settings - Fork 68
/
vild.py
424 lines (345 loc) · 14.2 KB
/
vild.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
# Adapted from ViLD Colab notebook: https://colab.sandbox.google.com/github/tensorflow/tpu/blob/master/models/official/detection/projects/vild/ViLD_demo.ipynb
import time
import subprocess
from pathlib import Path
from easydict import EasyDict
import numpy as np
import torch
import clip
from tqdm import tqdm
import collections
import numpy as np
from PIL import Image
from scipy.special import softmax
import tensorflow.compat.v1 as tf
import cv2
# Let TensorFlow grow its GPU memory allocation on demand instead of reserving
# every visible GPU up front (presumably so the PyTorch CLIP model loaded
# below can share the same device).
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

# Global options controlling prompt construction and region scoring.
FLAGS = EasyDict({
    'prompt_engineering': True,  # use the full multiple_templates ensemble
    'this_is': True,             # prefix 'a...'/'the...' prompts with "This is "
    'temperature': 100.0,        # softmax temperature, only used with use_softmax
    'use_softmax': False,        # softmax over categories vs. raw dot products
})
def article(name):
    """Return the English indefinite article ('a' or 'an') for *name*.

    The choice is made only from the first character: a vowel letter
    (case-insensitive) gives 'an', anything else gives 'a'. Pronunciation
    exceptions such as "hour" or "unicorn" are not handled. An empty name
    yields 'a' instead of raising.
    """
    # Lower-case so raw, capitalised category names ("Orange") are handled;
    # slicing instead of indexing avoids IndexError on an empty name.
    first = name[:1].lower()
    return 'an' if first and first in 'aeiou' else 'a'
def processed_name(name, rm_dot=False):
    """Normalise a dataset category name into plain lower-case text.

    Underscores become spaces (LVIS-style names) and '/' becomes ' or '
    (Objects365-style names). When *rm_dot* is true, trailing dots are
    stripped as well.
    """
    cleaned = name.replace('_', ' ')
    cleaned = cleaned.replace('/', ' or ')
    cleaned = cleaned.lower()
    return cleaned.rstrip('.') if rm_dot else cleaned
# Prompt template used when FLAGS.prompt_engineering is off: '{}' receives the
# processed category name and '{article}' receives 'a' or 'an'.
single_template = ['a photo of {article} {}.']
# Prompt-ensemble templates used when FLAGS.prompt_engineering is set.
# '{}' receives the processed category name and '{article}' receives 'a'/'an'
# (see build_text_embedding). The order of entries is preserved as-is.
multiple_templates = [
    # Scene-level phrasings.
    'There is {article} {} in the scene.',
    'There is the {} in the scene.',
    'a photo of {article} {} in the scene.',
    'a photo of the {} in the scene.',
    'a photo of one {} in the scene.',
    # "itap" = "I took a picture of".
    'itap of {article} {}.',
    'itap of my {}.',
    'itap of the {}.',
    # Generic object-level phrasings.
    'a photo of {article} {}.',
    'a photo of my {}.',
    'a photo of the {}.',
    'a photo of one {}.',
    'a photo of many {}.',
    'a good photo of {article} {}.',
    'a good photo of the {}.',
    'a bad photo of {article} {}.',
    'a bad photo of the {}.',
    'a photo of a nice {}.',
    'a photo of the nice {}.',
    'a photo of a cool {}.',
    'a photo of the cool {}.',
    'a photo of a weird {}.',
    'a photo of the weird {}.',
    'a photo of a small {}.',
    'a photo of the small {}.',
    'a photo of a large {}.',
    'a photo of the large {}.',
    'a photo of a clean {}.',
    'a photo of the clean {}.',
    'a photo of a dirty {}.',
    'a photo of the dirty {}.',
    # Image-quality phrasings.
    'a bright photo of {article} {}.',
    'a bright photo of the {}.',
    'a dark photo of {article} {}.',
    'a dark photo of the {}.',
    'a photo of a hard to see {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of {article} {}.',
    'a low resolution photo of the {}.',
    'a cropped photo of {article} {}.',
    'a cropped photo of the {}.',
    'a close-up photo of {article} {}.',
    'a close-up photo of the {}.',
    'a jpeg corrupted photo of {article} {}.',
    'a jpeg corrupted photo of the {}.',
    'a blurry photo of {article} {}.',
    'a blurry photo of the {}.',
    'a pixelated photo of {article} {}.',
    'a pixelated photo of the {}.',
    'a black and white photo of the {}.',
    'a black and white photo of {article} {}.',
    # Rendition / medium phrasings.
    'a plastic {}.',
    'the plastic {}.',
    'a toy {}.',
    'the toy {}.',
    'a plushie {}.',
    'the plushie {}.',
    'a cartoon {}.',
    'the cartoon {}.',
    'an embroidered {}.',
    'the embroidered {}.',
    'a painting of the {}.',
    'a painting of a {}.',
]
# CLIP ViT-B/32; only its text encoder (model.encode_text) is used in this
# file, so the preprocessing transform returned by clip.load is discarded.
# NOTE(review): build_text_embedding moves tokens to CUDA whenever
# torch.cuda.is_available(), which assumes clip.load places the model on CUDA
# in that case (its default device selection) -- confirm if a device is ever
# passed explicitly here.
model, _ = clip.load("ViT-B/32")
def build_text_embedding(categories):
    """Encode every category into one L2-normalised CLIP text embedding.

    For each category, the name is rendered into every prompt template (the
    full ensemble when ``FLAGS.prompt_engineering`` is set, otherwise the
    single template). When ``FLAGS.this_is`` is set, prompts that start with
    'a' or 'the' are prefixed with "This is ". The prompts are encoded with
    the CLIP text encoder and normalised individually, then averaged and
    normalised again.

    Args:
        categories: sequence of dicts, each holding at least a ``'name'`` key.

    Returns:
        numpy array with one row per category (presumably
        ``[len(categories), D]`` with D the CLIP text-embedding width).
    """
    templates = multiple_templates if FLAGS.prompt_engineering else single_template
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        per_category = []
        for category in tqdm(categories):
            raw_name = category['name']
            prompts = [
                template.format(processed_name(raw_name, rm_dot=True),
                                article=article(raw_name))
                for template in templates
            ]
            if FLAGS.this_is:
                prompts = [
                    'This is ' + prompt if prompt.startswith(('a', 'the')) else prompt
                    for prompt in prompts
                ]
            tokens = clip.tokenize(prompts)
            if use_cuda:
                tokens = tokens.cuda()
            # One embedding per prompt, each scaled to unit length.
            prompt_embeddings = model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
            # Average over the prompt ensemble, then re-normalise.
            mean_embedding = prompt_embeddings.mean(dim=0)
            mean_embedding /= mean_embedding.norm()
            per_category.append(mean_embedding)
        # Columns are categories here; the final .T makes them rows.
        stacked = torch.stack(per_category, dim=1)
        if use_cuda:
            stacked = stacked.cuda()
    return stacked.cpu().numpy().T
# Directory holding the exported ViLD image-path SavedModel. The existence
# check below is derived from it so the two paths cannot drift apart.
saved_model_dir = './image_path_v2'
# TF1-style session with its own graph; main() runs the ViLD graph through it.
session = tf.Session(graph=tf.Graph())
# Fetch the SavedModel from GCS on first use (requires the gsutil CLI).
# gsutil copies the remote image_path_v2 directory into the current directory.
if not (Path(saved_model_dir) / 'saved_model.pb').exists():
    subprocess.run(
        ['gsutil', 'cp', '-r',
         'gs://cloud-tpu-checkpoints/detection/projects/vild/colab/image_path_v2',
         './'],
        check=True)
_ = tf.saved_model.loader.load(session, ['serve'], saved_model_dir)
def nms(dets, scores, thresh, max_dets=1000):
    """Greedy non-maximum suppression.

    Args:
        dets: [N, 4] array of boxes in [y1, x1, y2, x2] order.
        scores: [N,] array of box scores; higher-scoring boxes are kept first.
        thresh: IoU threshold (float). A box is suppressed when its IoU with an
            already-kept box is strictly greater than this value.
        max_dets: maximum number of indices to return (int).

    Returns:
        List of kept indices into ``dets``, in decreasing score order.
    """
    y_lo, x_lo, y_hi, x_hi = (dets[:, col] for col in range(4))
    box_areas = (x_hi - x_lo) * (y_hi - y_lo)
    remaining = scores.argsort()[::-1]
    kept = []
    while remaining.size and len(kept) < max_dets:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        inter_w = np.maximum(0.0, np.minimum(x_hi[best], x_hi[rest]) - np.maximum(x_lo[best], x_lo[rest]))
        inter_h = np.maximum(0.0, np.minimum(y_hi[best], y_hi[rest]) - np.maximum(y_lo[best], y_lo[rest]))
        inter = inter_w * inter_h
        # The small epsilon guards against division by zero for degenerate boxes.
        iou = inter / (box_areas[best] + box_areas[rest] - inter + 1e-12)
        remaining = rest[np.flatnonzero(iou <= thresh)]
    return kept
def paste_instance_masks(masks,
                         detected_boxes,
                         image_height,
                         image_width):
    """Paste per-box instance masks onto full-image canvases.

    Each mask is zero-padded by one pixel and resized to its (correspondingly
    enlarged) reference box. It is then binarised at 0.5 and written into an
    all-zero canvas of the image size, clipped to the image bounds.

    Args:
        masks: numpy array of shape [N, mask_height, mask_width] holding the
            instance masks relative to ``detected_boxes``.
        detected_boxes: numpy array of shape [N, 4] holding the reference
            boxes in [x, y, width, height] form.
        image_height: height of the output canvas in pixels.
        image_width: width of the output canvas in pixels.

    Returns:
        numpy array of shape [N, image_height, image_width] (uint8) with each
        mask pasted onto its image canvas.
    """

    def expand_boxes(boxes, scale):
        """Scale [x, y, w, h] boxes about their centres; return [x1, y1, x2, y2]."""
        # Adapted from Detectron's utils/boxes.py, whose input boxes are
        # [x1, y1, x2, y2]; the input here is [x1, y1, w, h].
        half_w = boxes[:, 2] * .5
        half_h = boxes[:, 3] * .5
        centre_x = boxes[:, 0] + half_w
        centre_y = boxes[:, 1] + half_h
        half_w *= scale
        half_h *= scale
        expanded = np.zeros(boxes.shape)
        expanded[:, 0] = centre_x - half_w
        expanded[:, 1] = centre_y - half_h
        expanded[:, 2] = centre_x + half_w
        expanded[:, 3] = centre_y + half_h
        return expanded

    # Adapted from Detectron's core/test.py. cv2.resize appears to pad with
    # repeated border values, which produces "top hat" artefacts. The masks are
    # therefore zero-padded by one pixel before resizing, and the reference
    # boxes are enlarged by the matching factor.
    _, mask_height, mask_width = masks.shape
    scale = max((mask_width + 2.0) / mask_width,
                (mask_height + 2.0) / mask_height)
    ref_boxes = expand_boxes(detected_boxes, scale).astype(np.int32)
    # Reused buffer: the interior is overwritten per mask, the 1-pixel border stays 0.
    padded_mask = np.zeros((mask_height + 2, mask_width + 2), dtype=np.float32)

    pasted = []
    for idx, mask in enumerate(masks):
        canvas = np.zeros((image_height, image_width), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask[:, :]
        x0_ref, y0_ref, x1_ref, y1_ref = ref_boxes[idx, :]
        box_w = np.maximum(x1_ref - x0_ref + 1, 1)
        box_h = np.maximum(y1_ref - y0_ref + 1, 1)
        resized = cv2.resize(padded_mask, (box_w, box_h))
        binary = np.array(resized > 0.5, dtype=np.uint8)
        # Clip the box to the canvas and copy the matching part of the mask.
        left = min(max(x0_ref, 0), image_width)
        right = min(max(x1_ref + 1, 0), image_width)
        top = min(max(y0_ref, 0), image_height)
        bottom = min(max(y1_ref + 1, 0), image_height)
        canvas[top:bottom, left:right] = binary[
            (top - y0_ref):(bottom - y0_ref),
            (left - x0_ref):(right - x0_ref)
        ]
        pasted.append(canvas)
    segms = np.array(pasted)
    assert masks.shape[0] == segms.shape[0]
    return segms
def main(image_path, text_features, params):
    """Detect regions in one image with ViLD and score them against text embeddings.

    Args:
        image_path: path of the image file. It is fed to the loaded ViLD graph
            through its ``Placeholder:0`` input.
        text_features: numpy array with one text embedding per row (one row per
            category). Row 0 is treated as background: regions whose
            best-scoring row is 0 are dropped.
        params: tuple ``(max_boxes_to_draw, nms_threshold, min_rpn_score_thresh,
            min_box_area, max_box_area)``.

    Returns:
        dict with
            'boxes': [K, 4] boxes in [ymin, xmin, ymax, xmax] order, divided by
                the image scale reported in ``ImageInfo`` (presumably original
                image pixel coordinates -- confirm against the exported model),
            'masks': [K, image_height, image_width] uint8 binary masks,
            'scores': [K, num_categories] region-vs-category scores.
        The K foreground regions are ranked by their best score. Each value is
        an empty list when no foreground region survives.
    """
    max_boxes_to_draw, nms_threshold, min_rpn_score_thresh, min_box_area, max_box_area = params

    # Run the ViLD graph and strip the leading batch axis from every output.
    (roi_boxes, roi_scores, detection_boxes, detection_masks,
     visual_features, image_info) = session.run(
        ['RoiBoxes:0', 'RoiScores:0', '2ndStageBoxes:0', 'MaskOutputs:0',
         'VisualFeatOutputs:0', 'ImageInfo:0'],
        feed_dict={'Placeholder:0': [image_path]})
    roi_boxes = np.squeeze(roi_boxes, axis=0)  # per the original notes, already clipped
    roi_scores = np.squeeze(roi_scores, axis=0)
    detection_boxes = np.squeeze(detection_boxes, axis=(0, 2))
    detection_masks = np.squeeze(detection_masks, axis=0)
    visual_features = np.squeeze(visual_features, axis=0)
    image_info = np.squeeze(image_info, axis=0)

    # Row 0 of image_info holds (height, width); row 2 holds the (y, x) scale,
    # which is tiled to line up with [ymin, xmin, ymax, xmax] boxes.
    image_scale = np.tile(image_info[2:3, :], (1, 2))
    image_height = int(image_info[0, 0])
    image_width = int(image_info[0, 1])
    rescaled_detection_boxes = detection_boxes / image_scale

    # Keep RoIs that survive NMS, are not all-zero, pass the RPN score
    # threshold and have an area strictly inside (min_box_area, max_box_area).
    nmsed_indices = nms(detection_boxes, roi_scores, thresh=nms_threshold)
    box_sizes = ((rescaled_detection_boxes[:, 2] - rescaled_detection_boxes[:, 0]) *
                 (rescaled_detection_boxes[:, 3] - rescaled_detection_boxes[:, 1]))
    # NOTE: the builtin int dtype is used; np.int was removed in NumPy 1.24.
    valid_mask = (
        np.isin(np.arange(len(roi_scores)), nmsed_indices)
        & ~np.all(roi_boxes == 0., axis=-1)
        & (roi_scores >= min_rpn_score_thresh)
        & (box_sizes > min_box_area)
        & (box_sizes < max_box_area)
    )
    valid_indices = np.flatnonzero(valid_mask)[:max_boxes_to_draw]
    detection_masks = detection_masks[valid_indices]
    detection_visual_feat = visual_features[valid_indices]
    rescaled_detection_boxes = rescaled_detection_boxes[valid_indices]

    # Score every region against every category embedding.
    raw_scores = detection_visual_feat.dot(text_features.T)
    if FLAGS.use_softmax:
        scores_all = softmax(FLAGS.temperature * raw_scores, axis=-1)
    else:
        scores_all = raw_scores

    # Rank regions by their best score and drop those whose best match is
    # row 0 (background).
    indices = np.argsort(-np.max(scores_all, axis=1))
    indices_fg = indices[np.argmax(scores_all[indices], axis=1) != 0]
    if len(indices_fg) == 0:
        return {'boxes': [], 'masks': [], 'scores': []}

    # Convert the surviving boxes to [x, y, w, h] and paste only their masks.
    # Each mask is pasted independently, so this equals pasting all masks and
    # then selecting indices_fg.
    fg_boxes = rescaled_detection_boxes[indices_fg]
    ymin, xmin, ymax, xmax = np.split(fg_boxes, 4, axis=-1)
    processed_boxes = np.concatenate([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1)
    segmentations = paste_instance_masks(
        detection_masks[indices_fg], processed_boxes, image_height, image_width)
    return {
        'boxes': fg_boxes,
        'masks': segmentations,
        'scores': scores_all[indices_fg],
    }
class VildDetector:
    """Open-vocabulary detector built on the module-level ViLD ``main`` pipeline.

    Text embeddings are computed with ``build_text_embedding`` once per
    ordered list of category names and cached on the instance.
    """

    def __init__(self, max_boxes_to_draw=20, nms_threshold=0.6, min_rpn_score_thresh=0.9):
        """Store the detection thresholds passed to ``main``.

        Args:
            max_boxes_to_draw: maximum number of regions kept after filtering.
            nms_threshold: IoU threshold for non-maximum suppression.
            min_rpn_score_thresh: minimum RPN score for a region to be kept.
        """
        self.params = {
            'max_boxes_to_draw': max_boxes_to_draw,
            'nms_threshold': nms_threshold,
            'min_rpn_score_thresh': min_rpn_score_thresh,
        }
        # Maps a tuple of category names (first entry 'background') to the
        # array returned by build_text_embedding. Unbounded.
        self.text_features_cache = {}

    def get_text_features(self, categories):
        """Return the (cached) text embeddings for *categories*.

        Args:
            categories: sequence of category names whose first entry must be
                'background'.

        Returns:
            The array produced by ``build_text_embedding``, one row per category.
        """
        key = tuple(categories)
        assert key[0] == 'background'
        if key not in self.text_features_cache:
            self.text_features_cache[key] = build_text_embedding(
                [{'name': name, 'id': idx + 1} for idx, name in enumerate(key)])
        return self.text_features_cache[key]

    def forward(self, image_path, categories, min_box_area=220, max_box_area=float('inf')):
        """Detect *categories* in the image at *image_path*.

        Args:
            image_path: path of the image file handed to ``main``.
            categories: sequence (list, tuple, ...) of category names, without
                'background' (it is prepended here).
            min_box_area: regions whose box area is not strictly greater than
                this are discarded.
            max_box_area: regions whose box area is not strictly smaller than
                this are discarded.

        Returns:
            dict with 'boxes' ([K, 4] in [xmin, ymin, xmax, xmax]-style
            [x1, y1, x2, y2] order), 'masks', 'scores' (best score per region)
            and 'categories' (best category name per region). 'boxes',
            'masks' and 'scores' are empty lists when nothing is detected.

        Raises:
            TypeError: if *categories* is a single ``str`` rather than a
                sequence of names.
        """
        if isinstance(categories, str):
            # A bare string would otherwise be split into single characters.
            raise TypeError('categories must be a sequence of names, not a str')
        categories = ['background', *categories]
        text_features = self.get_text_features(categories)
        params = (self.params['max_boxes_to_draw'], self.params['nms_threshold'],
                  self.params['min_rpn_score_thresh'], min_box_area, max_box_area)
        output = main(image_path, text_features, params)
        if len(output['boxes']) > 0:
            # main returns [ymin, xmin, ymax, xmax]; reorder to [x1, y1, x2, y2].
            output['boxes'] = output['boxes'][:, [1, 0, 3, 2]]
            best = output['scores'].argmax(axis=1)
            output['categories'] = [categories[i] for i in best]
            output['scores'] = output['scores'].max(axis=1)
        else:
            output['categories'] = []
        return output