From 2bc9157c82b8a1c76cc44cd11f928d5a264563b4 Mon Sep 17 00:00:00 2001 From: zywvvd <30398809+zywvvd@users.noreply.github.com> Date: Sun, 8 Sep 2019 20:52:46 +0800 Subject: [PATCH 1/2] Update commit.py My torch environment is python 3.7, pytorch 1.2, cuda 10.0.13, cudnn 7.6.4. Under this condition, the original codes will give an error. So I fix function get_patch in common.py and function get_patch in srdata.py to solve this problem. --- src/data/common.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/data/common.py b/src/data/common.py index 26170bcd..e879ea7c 100644 --- a/src/data/common.py +++ b/src/data/common.py @@ -5,9 +5,10 @@ import torch -def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): - ih, iw = args[0].shape[:2] +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): + + ih, iw = args[0][0]['image'].shape[:2] if not input_large: p = scale if multi else 1 tp = p * patch_size @@ -23,10 +24,10 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): tx, ty = scale * ix, scale * iy else: tx, ty = ix, iy - + ret = [ - args[0][iy:iy + ip, ix:ix + ip, :], - *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] + args[0][0]['image'][iy:iy + ip, ix:ix + ip, :], + *[a[0]['image'][ty:ty + tp, tx:tx + tp, :] for a in args[1:]] ] return ret From 15df8109534a607895ba76c91ca9f73b6ab0c7f6 Mon Sep 17 00:00:00 2001 From: zywvvd <30398809+zywvvd@users.noreply.github.com> Date: Sun, 8 Sep 2019 20:56:52 +0800 Subject: [PATCH 2/2] Update srdata.py --- src/data/srdata.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/data/srdata.py b/src/data/srdata.py index b9109aad..bf7c7f4c 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -144,8 +144,9 @@ def get_patch(self, lr, hr): ) if not self.args.no_augment: lr, hr = common.augment(lr, hr) else: - ih, iw = lr.shape[:2] - hr = hr[0:ih * scale, 0:iw * scale] + ih, iw = lr[0]['image'].shape[:2] + lr=lr[0]['image'] + hr = hr[0]['image'][0:ih * scale, 0:iw * scale] return lr, hr