-
Notifications
You must be signed in to change notification settings - Fork 122
/
test_image.py
executable file
·51 lines (43 loc) · 1.88 KB
/
test_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os
from os import listdir
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor
from tqdm import tqdm
from data_utils import is_image_file
from model import Net
def tensor_to_luma(out):
    """Convert a network output tensor to an 8-bit luminance array.

    Takes channel 0 of batch item 0 of ``out`` (assumed laid out as
    (N, C, H, W) with values nominally in [0, 1]), scales it to [0, 255],
    clips out-of-range values, and rounds to the nearest integer before
    casting. Rounding matters: plain truncation maps e.g. 127.5 -> 127 and
    78.9999 -> 78, biasing every pixel downward.

    Returns:
        A 2-D ``numpy.uint8`` array of shape (H, W).
    """
    y = out[0, 0].detach().cpu().numpy() * 255.0
    return np.clip(y, 0, 255).round().astype(np.uint8)


def super_resolve(model, img, device):
    """Upscale a PIL image by running ``model`` on its luminance channel only.

    The image is converted to YCbCr. The Y channel is fed to the network as a
    (1, 1, H, W) float tensor produced by ``ToTensor``. The Cb/Cr channels are
    resized bicubically to the network's output size and merged back.

    Args:
        model: network mapping a (1, 1, H, W) luma tensor to an upscaled one.
        img: input PIL image (any mode convertible to YCbCr).
        device: torch device the model lives on.

    Returns:
        An RGB PIL image with the spatial size of the model's output.
    """
    y, cb, cr = img.convert('YCbCr').split()
    # ToTensor yields (1, H, W) for a single-channel image; add a batch axis.
    lr_y = ToTensor()(y).unsqueeze(0).to(device)
    # Inference only: no_grad avoids building an autograd graph per image.
    with torch.no_grad():
        sr_y = model(lr_y)
    # A 2-D uint8 array is inferred as mode 'L'; no explicit mode needed.
    out_img_y = Image.fromarray(tensor_to_luma(sr_y))
    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    return Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')


if __name__ == "__main__":
    # Super-resolve every test image under data/test/SRF_<factor>/data/ and
    # write the results, under the same file names, to results/SRF_<factor>/.
    parser = argparse.ArgumentParser(description='Test Super Resolution')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--model_name', default='epoch_3_100.pt', type=str, help='super resolution model name')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    MODEL_NAME = opt.model_name

    # Decide the device once instead of querying CUDA per image.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    path = os.path.join('data', 'test', 'SRF_' + str(UPSCALE_FACTOR), 'data')
    # listdir order is arbitrary; sort for reproducible processing order.
    images_name = sorted(x for x in listdir(path) if is_image_file(x))

    model = Net(upscale_factor=UPSCALE_FACTOR).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(os.path.join('epochs', MODEL_NAME), map_location=device))
    # Switch any train-only layer behaviour (dropout, batch-norm stats) off.
    model.eval()

    out_path = os.path.join('results', 'SRF_' + str(UPSCALE_FACTOR))
    os.makedirs(out_path, exist_ok=True)

    for image_name in tqdm(images_name, desc='convert LR images to HR images'):
        with Image.open(os.path.join(path, image_name)) as img:
            out_img = super_resolve(model, img, device)
        out_img.save(os.path.join(out_path, image_name))