-
Notifications
You must be signed in to change notification settings - Fork 9
/
batchizer.py
88 lines (66 loc) · 2.64 KB
/
batchizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from random import shuffle
import cv2
import numpy as np
from utils import change_channel, gray_normalizer
class Batchizer(object):
    """
    Load image records from a CSV file and serve them in shuffled batches.

    Every non-empty line of the CSV file must hold at least six
    comma-separated fields: ``img_path, x, y, w, h, a``. The fields are kept
    as strings in ``self.data_list``; the label part of each record is
    converted to ``float32`` when batches are produced.
    """

    def __init__(self, data_path, batch_size):
        """
        Read every record of the CSV file into memory.

        Args:
            data_path: path of the CSV file holding the records.
            batch_size: number of samples in each yielded batch.

        Raises:
            FileNotFoundError: if `data_path` is not an existing file.
            IndexError: if a non-empty line has fewer than six fields.
        """
        self.batch_size = batch_size

        # check that the CSV file exists
        if not os.path.isfile(data_path):
            raise FileNotFoundError(
                "CSV file not found: {}".format(data_path))

        # load the records into memory
        self.data_list = []
        with open(data_path, "r") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                # blank lines (e.g. a trailing newline at the end of the
                # file) carry no record; skip them instead of crashing
                if not line:
                    continue

                # values: [ img_path, x, y, w, h , a]
                values = line.split(",")
                if len(values) < 6:
                    raise IndexError(
                        "{}:{}: expected 6 comma-separated fields, "
                        "got {}".format(data_path, line_no, len(values)))

                # keep only the six known fields, drop any extras
                self.data_list.append(values[:6])

        self.n_batches = int(np.ceil(len(self.data_list) / self.batch_size))

    def __len__(self):
        """Return the number of records loaded from the CSV file."""
        return len(self.data_list)

    def batches(self, ag, lbl_len=4, num_c=1,
                zero_mean=False):
        """
        Yield batches forever, reshuffling the records before each epoch.

        Args:
            ag: augmenter object exposing ``addNoise(image, label)`` that
                returns the modified ``(image, label)`` pair, or None to
                disable augmentation.
            lbl_len: number of leading label values kept for each sample.
            num_c: forwarded to ``change_channel`` as the desired channel
                count.
            zero_mean: when true, each image is passed through
                ``gray_normalizer`` before the channel conversion.

        Yields:
            ``(images, labels, img_names)``: three parallel lists holding at
            most ``batch_size`` items. The last batch of an epoch may be
            smaller.

        Raises:
            IOError: if an image file cannot be read.
        """
        # with no records the epoch loop below would spin forever without
        # ever yielding, so end the generator right away
        if not self.data_list:
            return

        # infinitely do ....
        while True:
            # before each epoch, shuffle data (in place)
            shuffle(self.data_list)

            images = []
            labels = []
            img_names = []

            # for all records in data list
            for row in self.data_list:
                img_path = row[0]

                # read the image and ground truth
                image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                # cv2.imread signals failure by returning None instead of
                # raising; fail here so the offending path is reported
                if image is None:
                    raise IOError(
                        "Could not read image: {}".format(img_path))
                label = np.asarray(row[1:], dtype=np.float32)

                # add noise to images and corresponding label
                if ag is not None:
                    image, label = ag.addNoise(image, label)

                # discard unused labels
                label = label[0:lbl_len]
                labels.append(label)

                # zero mean the image
                if zero_mean:
                    image = gray_normalizer(image)

                # change to desired num_channel
                image = change_channel(image, num_c)

                images.append(image)
                img_names.append(img_path)

                if len(images) == self.batch_size:
                    yield images, labels, img_names

                    # start fresh lists for the next batch
                    images = []
                    labels = []
                    img_names = []

            # yield the remaining samples of this epoch as a short batch
            if images:
                yield images, labels, img_names