-
Notifications
You must be signed in to change notification settings - Fork 10
/
evaluation_dataset.py
60 lines (49 loc) · 2.01 KB
/
evaluation_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
import numpy as np
import glob
import time
from config import Config
# Project configuration object; this module reads `data_dir` and `input_size` from it.
config = Config()
# Only directory entries whose names end with this extension are indexed as images.
file_ext = ".jpg"
# Shared torchvision transform instances, built once at import time.
# Both crops are constructed with config.input_size as their size argument.
# NOTE(review): randomCrop, centerCrop and toPIL are not referenced in the visible
# part of this file - confirm whether other modules import them before removing.
randomCrop = transforms.RandomCrop(config.input_size)
centerCrop = transforms.CenterCrop(config.input_size)
# Applied to both the image and the segmentation in EvaluationDataset.__getitem__.
toTensor = transforms.ToTensor()
toPIL = transforms.ToPILImage()
# Assumes <config.data_dir>/<mode> contains a directory called "images_<type>" and one
# called "segmentations_npy_<type>", where each .jpg image has a same-named .npy file.
# Loads the image as input and the segmentation as output.
# Transforms are specified in this file.
class EvaluationDataset(Dataset):
    """Dataset yielding (image tensor, segmentation tensor, image path) triples.

    Images are read from ``<config.data_dir>/<mode>/images_<type>`` and the
    matching segmentations from
    ``<config.data_dir>/<mode>/segmentations_npy_<type>``. An image file
    ``name<file_ext>`` is paired with the segmentation file ``name.npy``.
    """

    def __init__(self, mode, type):
        """Resolve the data directories and index the available images.

        Args:
            mode: Name of the sub-directory of ``config.data_dir`` to read
                (the "test" directory name).
            type: Suffix appended to ``images_`` and ``segmentations_npy_``
                to select the directories. The name shadows the builtin but
                is kept because callers may pass it by keyword.
        """
        self.mode = mode  # The "test" directory name
        self.data_path = os.path.join(config.data_dir, mode)
        self.images_dir = os.path.join(self.data_path, f'images_{type}')
        self.seg_dir = os.path.join(self.data_path, f'segmentations_npy_{type}')
        self.image_list = self.get_image_list()

    def __len__(self):
        """Return the number of indexed (image, segmentation) pairs."""
        return len(self.image_list)

    def __getitem__(self, i):
        """Load the ``i``-th sample.

        Returns:
            A tuple ``(image, segmentation, image_filepath)`` where the first
            two items are the results of applying the module-level
            ``toTensor`` transform.
        """
        image_filepath, segmentation_filepath = self.image_list[i]
        image = self.load_pil_image(image_filepath)
        segmentation = self.load_segmentation(segmentation_filepath)
        # NOTE(review): torchvision's ToTensor rescales uint8 arrays to [0, 1];
        # confirm that is intended for the segmentation array.
        return toTensor(image), toTensor(segmentation), image_filepath

    def get_image_list(self):
        """Build the list of ``(image_path, segmentation_path)`` pairs.

        Entries of ``self.images_dir`` that do not end with ``file_ext`` are
        skipped. The listing is sorted so that a given index always maps to
        the same sample; ``os.listdir`` returns entries in arbitrary order.

        Raises:
            OSError: If ``self.images_dir`` cannot be listed.
        """
        image_list = []
        for file in sorted(os.listdir(self.images_dir)):
            if not file.endswith(file_ext):
                continue
            # Strip only the final extension so that names containing extra
            # dots (e.g. "scan.v2.jpg") keep their full stem ("scan.v2").
            stem = os.path.splitext(file)[0]
            image_path = os.path.join(self.images_dir, file)
            seg_path = os.path.join(self.seg_dir, stem + '.npy')
            image_list.append((image_path, seg_path))
        return image_list

    def load_pil_image(self, path):
        """Open the image at ``path`` and return it converted to RGB."""
        # Open path as a file to avoid ResourceWarning; convert() is called
        # while the file is still open and its result is what is returned.
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def load_segmentation(self, path):
        """Load the segmentation stored in the ``.npy`` file at ``path``."""
        # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects,
        # which can execute code. Only use this on trusted files. It is kept
        # because existing segmentation files may depend on it.
        return np.load(path, allow_pickle=True)