-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
105 lines (85 loc) · 4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
import torch
import numpy as np
import cv2
def convert_rgb_to_y(img):
    """Return the luma (Y) plane of an RGB image using the ITU-R BT.601 8-bit formula.

    Y = 16 + (65.738 * R + 129.057 * G + 25.064 * B) / 256, so for inputs in
    [0, 255] (assumed -- TODO confirm at call sites) the result lies in [16, 235].

    Args:
        img: either a NumPy array indexed channel-last (``img[:, :, c]``) or a
            torch tensor indexed channel-first (``img[c, :, :]``). A 4-D tensor
            has its leading dimension squeezed first (presumably a batch of one).

    Returns:
        The Y plane, with the spatial shape of the input and the same library type.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor
            (a subclass of ``Exception``, so existing handlers still catch it).
    """
    if isinstance(img, np.ndarray):
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))
    # 65.738 is the BT.601 coefficient (65.481 * 256 / 255); the former value
    # 64.738 was a typo that mapped pure white to 234 instead of 235.
    return 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
def convert_rgb_to_ycbcr(img):
    """Convert an RGB image to Y, Cb, Cr planes using the ITU-R BT.601 8-bit formulas.

    For inputs in [0, 255] (assumed -- TODO confirm at call sites) Y lies in
    [16, 235] and Cb/Cr are centred on 128.

    Args:
        img: either a NumPy array indexed channel-last (``img[:, :, c]``) or a
            torch tensor indexed channel-first (``img[c, :, :]``). A 4-D tensor
            has its leading dimension squeezed first (presumably a batch of one).

    Returns:
        An (H, W, 3) array/tensor holding Y, Cb, Cr along the last axis. Both
        branches return channel-last, even though the torch input is channel-first.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor
            (a subclass of ``Exception``, so existing handlers still catch it).
    """
    if isinstance(img, np.ndarray):
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))
    # 65.738 is the BT.601 luma coefficient for R (the former 64.738 was a typo).
    y = 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.
    if isinstance(img, np.ndarray):
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    # stack (not cat) adds a new channel axis -> (3, H, W); cat would give a
    # 2-D (3H, W) tensor on which permute(1, 2, 0) raises.
    return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
def convert_ycbcr_to_rgb(img):
    """Convert Y, Cb, Cr planes back to RGB using the inverse ITU-R BT.601 8-bit formulas.

    Legal-range input (Y in [16, 235], Cb/Cr around 128) maps to roughly
    [0, 255]; the output is not clipped.

    Args:
        img: either a NumPy array indexed channel-last (``img[:, :, c]``) or a
            torch tensor indexed channel-first (``img[c, :, :]``), with channels
            ordered Y, Cb, Cr. A 4-D tensor has its leading dimension squeezed
            first (presumably a batch of one).

    Returns:
        An (H, W, 3) array/tensor holding R, G, B along the last axis. Both
        branches return channel-last, even though the torch input is channel-first.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor
            (a subclass of ``Exception``, so existing handlers still catch it).
    """
    if isinstance(img, np.ndarray):
        y, cb, cr = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y, cb, cr = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))
    r = 298.082 * y / 256. + 408.583 * cr / 256. - 222.921
    g = 298.082 * y / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    b = 298.082 * y / 256. + 516.412 * cb / 256. - 276.836
    if isinstance(img, np.ndarray):
        return np.array([r, g, b]).transpose([1, 2, 0])
    # stack (not cat) adds a new channel axis -> (3, H, W); cat would give a
    # 2-D (3H, W) tensor on which permute(1, 2, 0) raises.
    return torch.stack([r, g, b], 0).permute(1, 2, 0)
def calc_psnr(img1, img2, max_val=1.):
    """Return the peak signal-to-noise ratio between two tensors, in decibels.

    PSNR = 10 * log10(max_val ** 2 / MSE), with the MSE taken over all elements.
    The default ``max_val=1.`` suits images scaled to [0, 1]; pass
    ``max_val=255.`` for 8-bit value ranges.

    Args:
        img1: first tensor.
        img2: second tensor, broadcast-compatible with ``img1``.
        max_val: largest possible pixel value (the "peak").

    Returns:
        A 0-d tensor. Identical inputs give MSE 0 and therefore ``inf``.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(max_val ** 2 / mse)
def adjust_learning_rate(optimizer, shrink_factor):
    """Scale the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: the optimizer to adjust; every entry of its ``param_groups``
            must carry an ``'lr'`` key.
        shrink_factor: multiplier applied to each group's current learning rate;
            intended to lie in (0, 1).
    """
    # Announce the adjustment ("adjusting learning rate").
    print("\n调整学习率.")
    for group in optimizer.param_groups:
        current_lr = group['lr']
        group['lr'] = current_lr * shrink_factor
    # Report only the first group's new rate ("the new learning rate is ...").
    first_group_lr = optimizer.param_groups[0]['lr']
    print("新的学习率为 %f\n" % (first_group_lr, ))
class AverageMeter:
    """Track the latest value and a weighted running mean of a series of numbers.

    Attributes (all reset to 0):
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count``, recomputed on every update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        For plain Python numbers, an accumulated ``count`` of 0 makes the
        division raise ZeroDivisionError.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def _load_image_as_tensor(path):
    """Read ``path`` with OpenCV and return it as a float32 tensor divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the failure surfaces as an opaque AttributeError.
        raise FileNotFoundError(path)
    return torch.from_numpy(image.astype(np.float32)) / 255.0


def test():
    """Print the PSNR of each degraded sample image against the reference ``he-test0.jpg``."""
    workdir = "./sample/q10/"
    # Load the reference once instead of re-reading it for every comparison.
    reference = _load_image_as_tensor(workdir + "he-test0.jpg")
    for name in ('he-test0-q10-ir.jpg', 'he-test0-q10-nir.jpg', 'he-test0-q10.jpg'):
        candidate = _load_image_as_tensor(workdir + name)
        print(calc_psnr(reference, candidate))
if __name__ == '__main__':
    # Run the sample PSNR comparison only when executed as a script, not on import.
    test()