forked from ChunML/ssd-tf2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_utils.py
executable file
·178 lines (146 loc) · 5.42 KB
/
image_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
import numpy as np
import tensorflow as tf
from box_utils import compute_iou
class ImageVisualizer(object):
    """ Class for visualizing image

    Attributes:
        idx_to_name: list to convert integer to string label
        class_colors: colors for drawing boxes and labels, one RGB triple
            per class. Triples with any component above 1 are treated as
            0-255 values (like the default green [0, 255, 0]); otherwise
            they are used as matplotlib 0-1 floats.
        save_dir: directory to store images
    """

    def __init__(self, idx_to_name, class_colors=None, save_dir=None):
        self.idx_to_name = idx_to_name
        # Fall back to one green per class when no palette (or a palette of
        # the wrong length) is given. Build separate lists so mutating one
        # class's color does not change every other class.
        if class_colors is None or len(class_colors) != len(self.idx_to_name):
            self.class_colors = [[0, 255, 0] for _ in self.idx_to_name]
        else:
            self.class_colors = class_colors

        if save_dir is None:
            self.save_dir = './'
        else:
            self.save_dir = save_dir

        os.makedirs(self.save_dir, exist_ok=True)

    @staticmethod
    def _to_mpl_color(color):
        """ Convert an RGB triple to a tuple of floats in [0, 1] for matplotlib.

        Triples with any component above 1 are assumed to be on a 0-255 scale.
        """
        rgb = tuple(float(c) for c in color)
        if max(rgb, default=0.0) > 1.0:
            rgb = tuple(c / 255.0 for c in rgb)
        return rgb

    def save_image(self, img, boxes, labels, name):
        """ Method to draw boxes and labels
            then save to dir

        Args:
            img: numpy array (height, width, 3), shown with plt.imshow
            boxes: numpy array (num_boxes, 4) of (xmin, ymin, xmax, ymax).
                They are drawn directly on the imshow axes, so they must be
                in img's pixel coordinates.
            labels: numpy array (num_boxes). Label k is drawn as
                idx_to_name[k - 1] (presumably 0 is reserved for background).
            name: name of image to be saved, relative to save_dir
        """
        fig, ax = plt.subplots(1)
        ax.imshow(img)
        save_path = os.path.join(self.save_dir, name)

        for i, box in enumerate(boxes):
            idx = labels[i] - 1
            cls_name = self.idx_to_name[idx]
            color = self._to_mpl_color(self.class_colors[idx])
            ax.add_patch(patches.Rectangle(
                (box[0], box[1]),
                box[2] - box[0], box[3] - box[1],
                linewidth=2, edgecolor=color,
                facecolor="none"))
            ax.text(
                box[0],
                box[1],
                s=cls_name,
                color="white",
                verticalalignment="top",
                bbox={"color": color, "pad": 0},
            )

        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0.0)
        # Close only the figure created here, not any figures the caller owns
        plt.close(fig)
def generate_patch(boxes, threshold, max_trials=1000):
    """ Function to generate a random patch within the image
        If the patch overlaps any gt boxes at or above the threshold,
        then the patch is picked, otherwise generate another patch

    Patches are sampled in normalized [0, 1] image coordinates: the width is
    uniform in [0.1, 1] and the height is width * scale with scale uniform in
    [0.5, 2], capped at the full image height.

    Args:
        boxes: box tensor (num_boxes, 4)
        threshold: iou threshold to decide whether to choose the patch
        max_trials: maximum number of patches to sample before giving up.
            Without a bound the loop never ends when no patch can reach the
            threshold (e.g. an image with no gt boxes, or only boxes too
            small to ever reach it). None samples until success.

    Returns:
        patch: the picked patch (xmin, ymin, xmax, ymax). If max_trials is
            exhausted, this is the last sampled patch.
        ious: an array to store IOUs of the patch and all gt boxes. Callers
            can compare it against threshold to detect the exhausted case.
    """
    trial = 0
    while True:
        trial += 1
        patch_w = random.uniform(0.1, 1)
        scale = random.uniform(0.5, 2)
        # A tall patch could exceed the image. Capping at 1 gives the same
        # clipped patch (full height) the uncapped sample would, while
        # keeping the ymin sampling range non-negative.
        patch_h = min(patch_w * scale, 1.0)
        patch_xmin = random.uniform(0, 1 - patch_w)
        patch_ymin = random.uniform(0, 1 - patch_h)
        patch_xmax = patch_xmin + patch_w
        patch_ymax = patch_ymin + patch_h
        patch = np.array(
            [[patch_xmin, patch_ymin, patch_xmax, patch_ymax]],
            dtype=np.float32)
        # Guards against float round-off pushing an edge just past 1
        patch = np.clip(patch, 0.0, 1.0)
        ious = compute_iou(tf.constant(patch), boxes)
        if tf.math.reduce_any(ious >= threshold):
            break
        if max_trials is not None and trial >= max_trials:
            break

    return patch[0], ious[0]
def random_patching(img, boxes, labels):
    """ Function to apply random patching
        Firstly, a patch is randomly picked
        Then only gt boxes whose IOU with the patch is above 0.3
        and whose center point lies within the patch will be selected

    Args:
        img: the original PIL Image
        boxes: gt boxes tensor (num_boxes, 4) as normalized
            (xmin, ymin, xmax, ymax) in [0, 1]
        labels: gt labels tensor (num_boxes,)

    Returns:
        img: the cropped PIL Image (the original one if the patch is unusable
            or no box is kept)
        boxes: selected gt boxes tensor (new_num_boxes, 4), re-normalized
            relative to the patch and clipped to [0, 1]
        labels: selected gt labels tensor (new_num_boxes,)
    """
    threshold = np.random.choice(np.linspace(0.1, 0.7, 4))

    patch, ious = generate_patch(boxes, threshold)

    box_centers = (boxes[:, :2] + boxes[:, 2:]) / 2
    keep_idx = (
        (ious > 0.3) &
        (box_centers[:, 0] > patch[0]) &
        (box_centers[:, 1] > patch[1]) &
        (box_centers[:, 0] < patch[2]) &
        (box_centers[:, 1] < patch[3])
    )

    # Leave the sample untouched if the patch does not reach the IoU
    # threshold (possible when patch sampling is bounded) or no box survives.
    if (not tf.math.reduce_any(ious >= threshold)
            or not tf.math.reduce_any(keep_idx)):
        return img, boxes, labels

    # patch is in normalized coordinates but PIL crops in pixels; cropping
    # with the raw patch would produce a ~1x1 pixel image.
    img_w, img_h = img.size
    crop_box = tuple(
        int(round(float(coord) * size))
        for coord, size in zip(patch, (img_w, img_h, img_w, img_h)))
    img = img.crop(crop_box)

    # tf.boolean_mask works for both tf tensors and numpy arrays, unlike
    # boolean-tensor indexing with [].
    boxes = tf.boolean_mask(boxes, keep_idx)
    labels = tf.boolean_mask(labels, keep_idx)

    # Re-express the kept boxes relative to the patch
    patch_w = patch[2] - patch[0]
    patch_h = patch[3] - patch[1]
    boxes = tf.stack([
        (boxes[:, 0] - patch[0]) / patch_w,
        (boxes[:, 1] - patch[1]) / patch_h,
        (boxes[:, 2] - patch[0]) / patch_w,
        (boxes[:, 3] - patch[1]) / patch_h], axis=1)
    boxes = tf.clip_by_value(boxes, 0.0, 1.0)

    return img, boxes, labels
def horizontal_flip(img, boxes, labels):
    """ Mirror the image left-to-right and move the gt boxes with it

    Box coordinates are expected to be normalized to [0, 1], since a
    mirrored x coordinate is computed as 1 - x.

    Args:
        img: the original PIL Image
        boxes: gt boxes tensor (num_boxes, 4) as (xmin, ymin, xmax, ymax)
        labels: gt labels tensor (num_boxes,), returned unchanged

    Returns:
        img: the mirrored PIL Image
        boxes: mirrored gt boxes tensor (num_boxes, 4)
        labels: gt labels tensor (num_boxes,)
    """
    flipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)

    x_min, y_min = boxes[:, 0], boxes[:, 1]
    x_max, y_max = boxes[:, 2], boxes[:, 3]

    # Mirroring swaps the edges: the old right edge becomes the new left one.
    # Vertical coordinates are unaffected.
    flipped_boxes = tf.stack(
        [1 - x_max, y_min, 1 - x_min, y_max],
        axis=1,
    )
    return flipped_img, flipped_boxes, labels