diff --git a/src/data/common.py b/src/data/common.py index 26170bcd..e879ea7c 100644 --- a/src/data/common.py +++ b/src/data/common.py @@ -5,9 +5,10 @@ import torch -def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): - ih, iw = args[0].shape[:2] +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): + + ih, iw = args[0][0]['image'].shape[:2] if not input_large: p = scale if multi else 1 tp = p * patch_size @@ -23,10 +24,10 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): tx, ty = scale * ix, scale * iy else: tx, ty = ix, iy - + ret = [ - args[0][iy:iy + ip, ix:ix + ip, :], - *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] + args[0][0]['image'][iy:iy + ip, ix:ix + ip, :], + *[a[0]['image'][ty:ty + tp, tx:tx + tp, :] for a in args[1:]] ] return ret diff --git a/src/data/srdata.py b/src/data/srdata.py index b9109aad..bf7c7f4c 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -144,8 +144,9 @@ def get_patch(self, lr, hr): ) if not self.args.no_augment: lr, hr = common.augment(lr, hr) else: - ih, iw = lr.shape[:2] - hr = hr[0:ih * scale, 0:iw * scale] + ih, iw = lr[0]['image'].shape[:2] + lr=lr[0]['image'] + hr = hr[0]['image'][0:ih * scale, 0:iw * scale] return lr, hr