-
Notifications
You must be signed in to change notification settings - Fork 0
/
classifier_data.py
98 lines (74 loc) · 2.77 KB
/
classifier_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import glob
import time
import random
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
class BinaryClassificationImageDataset(Dataset):
    """Two-class image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Every file matching ``*.*`` in folder ``A`` gets label ``0`` and every file
    in folder ``B`` gets label ``1``. Indices ``0 .. len(files_A) - 1`` address
    the ``A`` images (sorted by path); the remaining indices address ``B``.

    Args:
        root: Dataset root directory.
        transformations: Sequence of transforms, composed in order with
            ``transforms.Compose`` and applied to every loaded image.
        mode: Sub-directory of ``root`` to read, e.g. ``"train"``.
    """

    def __init__(self, root, transformations, mode="train"):
        self.transform = transforms.Compose(transformations)
        self.files_A = sorted(glob.glob(os.path.join(root, mode, "A", "*.*")))
        self.files_B = sorted(glob.glob(os.path.join(root, mode, "B", "*.*")))
        self.length = len(self.files_A) + len(self.files_B)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(
                f"index out of range for dataset of length {self.length}"
            )
        n_a = len(self.files_A)
        if index < n_a:
            img_path, lbl = self.files_A[index], 0
        else:
            img_path, lbl = self.files_B[index - n_a], 1
        # PIL opens files lazily and keeps the handle open; the context manager
        # releases it deterministically. convert("RGB") forces the pixel data to
        # load before the file is closed and guarantees 3 channels, so grayscale,
        # RGBA or palette files don't break a 3-channel Normalize downstream.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        return self.transform(img), lbl

    def __len__(self):
        """Return the total number of images in both classes."""
        return self.length
def tensor_to_image(tensor_image):
    """Convert a normalized image tensor to a NumPy array for ``imshow``.

    The tensor is detached and copied to the CPU. A 3-D ``(C, H, W)`` tensor is
    rearranged to ``(H, W, C)``; a 2-D ``(H, W)`` tensor keeps its layout.
    Values are mapped from ``[-1, 1]`` to ``[0, 1]`` via ``(x + 1) / 2`` and
    clipped to that range.

    Args:
        tensor_image: 2-D or 3-D torch tensor.

    Returns:
        A new ``numpy.ndarray`` with values in ``[0, 1]``.

    Raises:
        ValueError: if the tensor is neither 2-D nor 3-D.
    """
    image = tensor_image.detach().to('cpu').numpy()
    if image.ndim == 3:
        # Channels-first (C, H, W) -> channels-last (H, W, C), as imshow expects.
        image = np.moveaxis(image, 0, -1)
    elif image.ndim != 2:
        raise ValueError(
            f"expected a 2-D or 3-D image tensor, got {image.ndim} dimensions"
        )
    # Undo a mean=0.5 / std=0.5 normalization ([-1, 1] -> [0, 1]); clip because
    # out-of-range values would otherwise be clipped by imshow with a warning.
    return np.clip((image + 1) / 2, 0, 1)
def show_sample(dataset, class_a_name, class_b_name, n_samples=10, n_cols=5):
    """Show randomly chosen examples of ``dataset`` with their labels.

    Draws up to ``n_samples`` distinct items (without replacement) and displays
    them in one figure, at most ``n_cols`` per row, each titled with its class.

    Args:
        dataset: Indexable, sized dataset returning ``(image_tensor, label)``.
        class_a_name: Title for examples whose label is ``0``.
        class_b_name: Title for examples with any other label.
        n_samples: Number of examples to show; capped at ``len(dataset)``.
        n_cols: Maximum number of images per row.

    Raises:
        ValueError: if the dataset is empty, or ``n_samples``/``n_cols`` < 1.
    """
    if n_samples < 1 or n_cols < 1:
        raise ValueError("n_samples and n_cols must be positive")
    # Cap at the dataset size: random.sample raises if asked for more items
    # than the population holds.
    n_samples = min(n_samples, len(dataset))
    if n_samples == 0:
        raise ValueError("cannot show samples from an empty dataset")
    n_cols = min(n_cols, n_samples)
    n_rows = -(-n_samples // n_cols)  # ceiling division
    # 3 inches per cell reproduces the original 15x6 figure for the 2x5 grid;
    # squeeze=False keeps `axs` 2-D even for a single row or column.
    _, axs = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    indices = random.sample(range(len(dataset)), n_samples)
    for ax in axs.flat:
        ax.axis("off")  # also hides leftover cells of a partially filled grid
    for ax, idx in zip(axs.flat, indices):
        img, lbl = dataset[idx]
        ax.imshow(tensor_to_image(img))
        ax.set_title(class_a_name if lbl == 0 else class_b_name)
    plt.show()
if __name__ == '__main__':
    img_size = 64
    transformations = [
        # Resize the shorter edge to ~1.12x img_size, then random-crop back to
        # img_size: a cheap scale/translation augmentation.
        transforms.Resize(int(img_size * 1.12),
                          transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Maps [0, 1] to [-1, 1]; tensor_to_image maps it back for display.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset_root = os.path.join(".", "datasets", "apple2orange64")
    train_dataset = BinaryClassificationImageDataset(
        dataset_root,
        transformations=transformations,
        mode="train"
    )
    n_images = len(train_dataset)
    if n_images == 0:
        # Avoid a ZeroDivisionError below when the path is wrong or empty.
        raise SystemExit(f"No images found under {dataset_root!r} (mode='train')")

    # measure mean time for loading an image from the dataset;
    # perf_counter is monotonic and high-resolution, unlike time.time().
    total_time = 0.
    for i in range(n_images):
        tic = time.perf_counter()
        _ = train_dataset[i]
        total_time += time.perf_counter() - tic
    mean_time = total_time / n_images
    print(f"Mean time for loading an image: {mean_time:.10f} sec")
    show_sample(train_dataset, 'apple', 'orange')
    # from torch.utils.data import DataLoader
    # data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    # batch = next(iter(data_loader))
    # print(f"{batch[0].size() = }")
    # print(f"{batch[1].size() = }")