-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_hard_neg_mining.py
141 lines (117 loc) · 3.99 KB
/
run_hard_neg_mining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os

import numpy as np
from skimage.io import imshow, imread, imsave
import pandas as pd
import keras.callbacks
import keras.utils

from model import get_model
from common import (
    bound_gpu_usage,
    get_rotated_subregion,
    get_cropped_image,
    augment_images
)
from predict_fcn import predict_with_model
# Run identifier: used for the TensorBoard log directory ('logs/<NAME>')
# and the checkpoint file ('models/<NAME>.h5') in fit().
NAME = '20_hard_negative_mining'
# Learning rate handed to get_model().
LR = 0.00002
# Number of crops taken per generated batch (also passed to TensorBoard).
BATCH_SIZE = 16
# Number of passes of the training loop in fit().
EPOCHS = 128
# Class count handed to get_model(); values above 2 switch the generator
# to one-hot (to_categorical) labels.
NUM_CLASSES = 2
# Size passed to get_rotated_subregion() and used to reshape the labels.
IMAGE_SHAPE = (256, 256)
# When True, image batches go through augment_images() before being yielded.
AUGMENT = False
# Generator batches per training epoch (steps_per_epoch).
STEPS = 192
# Generator batches per validation run (validation_steps).
STEPS_VAL = 64
# Network being trained: assigned by fit(), read by recalc_true_mask().
model = None

# Tab-separated sample tables; the code in this file reads their 'path'
# and 'mask' columns.
train_df = pd.read_csv('datasets/train.tsv', sep='\t')
val_df = pd.read_csv('datasets/val.tsv', sep='\t')

# One boolean map per training image, shaped like the first two axes of the
# cropped image and initially all False.  image_generator() draws crop
# positions from the False pixels of these maps, and recalc_true_mask()
# replaces the maps after every epoch.
train_true_masks = {
    image_path: np.zeros(
        get_cropped_image(image_path).shape[:2],
        dtype=bool,
    )
    for image_path in train_df['path']
}
def write_mask(image):
    """Save *image* as the first unused ``tmp/<n>.png`` file.

    Counts ``n`` up from 0 until ``tmp/<n>.png`` does not exist yet and
    writes the image there with ``skimage.io.imsave``, so earlier dumps are
    never overwritten.

    Args:
        image: array-like image accepted by ``imsave``.
    """
    # The dump directory may not exist on a fresh checkout; imsave would
    # fail on the missing parent directory otherwise.
    os.makedirs('tmp', exist_ok=True)

    pattern = 'tmp/{}.png'
    index = 0
    while os.path.exists(pattern.format(index)):
        index += 1
    imsave(pattern.format(index), image)
def recalc_true_mask():
    """Refresh ``train_true_masks`` from the current model's predictions.

    For every path currently held in ``train_true_masks`` the image and its
    grey-scale ground-truth mask are loaded, channel 0 of the model output
    is thresholded at 0.5, and the stored map is replaced by a boolean
    array that is True where the thresholded prediction agrees with the
    ground truth.  ``image_generator`` samples crop positions from the
    False pixels of these maps, i.e. from the pixels the model still gets
    wrong.

    The first recomputed map is also written to disk through
    ``write_mask`` for visual inspection.

    Raises:
        RuntimeError: if the global ``model`` has not been created yet.
    """
    global train_true_masks

    if model is None:
        raise RuntimeError('recalc_true_mask() requires fit() to create the model first')

    threshold = 0.5

    # Resolve each image's mask path once instead of filtering the whole
    # DataFrame per image; setdefault keeps the first matching row, as the
    # previous ``.iloc[0]`` lookup did.
    mask_path_by_image = {}
    for image_path, mask_path in zip(train_df['path'], train_df['mask']):
        mask_path_by_image.setdefault(image_path, mask_path)

    refreshed = {}
    for path in train_true_masks:
        image = get_cropped_image(path)
        ground_truth = get_cropped_image(
            mask_path_by_image[path], as_grey=True
        ) > 0
        predicted = predict_with_model(model, image)[..., 0] > threshold

        # XOR marks the mismatching pixels; the stored map must be its
        # complement because the generator picks positions where the map
        # is False.
        correct = ~(predicted ^ ground_truth)

        if not refreshed:
            # Dump only the first map of each refresh as a debugging aid.
            write_mask(correct)
        refreshed[path] = correct

    train_true_masks = refreshed
def image_generator(df):
    """Yield an endless stream of ``(images, labels)`` batches.

    Every iteration picks one random row of *df*, loads its image and its
    grey-scale mask, chooses ``BATCH_SIZE`` positions plus random angles in
    ``[0, 2*pi)`` and passes them to ``get_rotated_subregion`` for both the
    image and the mask (presumably producing rotated crops of size
    ``IMAGE_SHAPE`` - confirm against ``common``).

    Position choice:
      * if the image path has an entry in ``train_true_masks``, positions
        are drawn (with replacement) from the pixels where that map is
        False;
      * otherwise, or when the map has no False pixel left, positions are
        drawn uniformly over the mask extent.

    Args:
        df: table with ``'path'`` and ``'mask'`` columns.

    Yields:
        tuple: ``(images, labels)``.  ``labels`` has shape
        ``(BATCH_SIZE, H, W)`` when ``NUM_CLASSES <= 2`` and
        ``(BATCH_SIZE, H, W, NUM_CLASSES)`` (one-hot) otherwise, with
        ``(H, W) == IMAGE_SHAPE``.
    """
    while True:
        # Positional access so a non-default DataFrame index cannot break
        # the lookup.
        row_index = np.random.randint(0, len(df))
        path = df['path'].iloc[row_index]
        big_image = get_cropped_image(path)
        mask = get_cropped_image(df['mask'].iloc[row_index], as_grey=True)
        assert mask.ndim == 2

        candidates = None
        if path in train_true_masks:
            # argwhere lists the (row, col) of every selected pixel in
            # row-major order without materialising all pixel indices as
            # Python tuples.
            candidates = np.argwhere(~train_true_masks[path])

        if candidates is not None and len(candidates) > 0:
            chosen = candidates[np.random.choice(len(candidates), BATCH_SIZE)]
            shifts_x, shifts_y = chosen[:, 0], chosen[:, 1]
        else:
            # No map for this image (e.g. validation data) or nothing left
            # to mine: sample anywhere in the image.
            shifts_x = np.random.uniform(0, mask.shape[0], (BATCH_SIZE,))
            shifts_y = np.random.uniform(0, mask.shape[1], (BATCH_SIZE,))

        angles = np.random.uniform(0, np.pi * 2, (BATCH_SIZE,))
        shifts = list(zip(shifts_x, shifts_y))

        images = np.array([
            get_rotated_subregion(big_image, IMAGE_SHAPE, angle, shift)
            for angle, shift in zip(angles, shifts)
        ])
        if AUGMENT:
            images = augment_images(images)

        # Same angles and shifts as for the images, so labels stay aligned.
        Ys = np.array([
            get_rotated_subregion(mask, IMAGE_SHAPE, angle, shift)
            for angle, shift in zip(angles, shifts)
        ])

        if NUM_CLASSES > 2:
            labels = keras.utils.to_categorical(
                Ys.flatten(),
                NUM_CLASSES
            ).reshape(BATCH_SIZE, IMAGE_SHAPE[0], IMAGE_SHAPE[1], NUM_CLASSES)
        else:
            labels = Ys.reshape(BATCH_SIZE, IMAGE_SHAPE[0], IMAGE_SHAPE[1])

        yield images, labels
def fit():
    """Build the model and train it for ``EPOCHS`` epochs.

    The global ``model`` is created with ``get_model(NUM_CLASSES, LR)``.
    Training runs one epoch at a time so that ``recalc_true_mask()`` can
    refresh ``train_true_masks`` between epochs; the training generator
    reads that global on every batch.

    Progress is logged to ``logs/<NAME>`` via TensorBoard, and the model
    with the lowest ``val_loss`` is saved to ``models/<NAME>.h5``.
    """
    global model
    model = get_model(NUM_CLASSES, LR)

    # Created once: ModelCheckpoint keeps its best ``val_loss`` on the
    # instance, so a fresh instance per epoch would overwrite the
    # checkpoint every time regardless of ``save_best_only``.
    callbacks = [
        keras.callbacks.TensorBoard(
            'logs/{}'.format(NAME),
            write_images=False,
            batch_size=BATCH_SIZE
        ),
        keras.callbacks.ModelCheckpoint(
            'models/{}.h5'.format(NAME), verbose=False,
            save_best_only=True, monitor='val_loss'
        )
    ]

    for epoch in range(EPOCHS):
        model.fit_generator(
            image_generator(train_df),
            steps_per_epoch=STEPS,
            validation_data=image_generator(val_df),
            validation_steps=STEPS_VAL,
            # Keras treats ``epochs`` as the index of the final epoch, not
            # a count: training runs from ``initial_epoch`` up to
            # ``epochs``, so it must advance with the loop or nothing is
            # trained after the first pass.
            epochs=epoch + 1,
            initial_epoch=epoch,
            callbacks=callbacks
        )
        recalc_true_mask()
if __name__ == '__main__':
    # Script entry point: call the project's GPU helper from `common`
    # (presumably limits GPU memory use - confirm there) before the model
    # is built, then start training.
    bound_gpu_usage()
    fit()