From 299033f8127d38dfa06b86a6cc798dd96fc16bea Mon Sep 17 00:00:00 2001
From: ryanxingql <34084019+ryanxingql@users.noreply.github.com>
Date: Tue, 24 Sep 2024 16:27:02 +0800
Subject: [PATCH] add arch: CBDNet and UNet

---
 options/test/CBDNet/DIV2K_LMDB_G1_latest.yml |  51 ++++
 options/train/CBDNet/DIV2K_LMDB_G1.yml       |  96 ++++++++
 powerqe/archs/__init__.py                    |   4 +-
 powerqe/archs/cbdnet_arch.py                 |  96 ++++++++
 powerqe/archs/unet_arch.py                   | 243 +++++++++++++++++++
 5 files changed, 489 insertions(+), 1 deletion(-)
 create mode 100644 options/test/CBDNet/DIV2K_LMDB_G1_latest.yml
 create mode 100644 options/train/CBDNet/DIV2K_LMDB_G1.yml
 create mode 100644 powerqe/archs/cbdnet_arch.py
 create mode 100644 powerqe/archs/unet_arch.py

diff --git a/options/test/CBDNet/DIV2K_LMDB_G1_latest.yml b/options/test/CBDNet/DIV2K_LMDB_G1_latest.yml
new file mode 100644
index 0000000..beb88b5
--- /dev/null
+++ b/options/test/CBDNet/DIV2K_LMDB_G1_latest.yml
@@ -0,0 +1,51 @@
+# general settings
+name: test_CBDNet_DIV2K_LMDB_G1_latest
+model_type: SRModel
+scale: 1
+num_gpu: 1  # set num_gpu: 0 for CPU mode
+manual_seed: 0
+
+# dataset settings
+datasets:
+  test:  # multiple test datasets are acceptable
+    name: DIV2K
+    type: PairedImageDataset
+    dataroot_gt: datasets/DIV2K/valid
+    dataroot_lq: datasets/DIV2K/valid_BPG_QP37
+    io_backend:
+      type: disk
+
+# network structures
+network_g:
+  type: CBDNet
+  io_channels: 3
+  estimate_channels: 32
+  nlevel_denoise: 3
+  nf_base_denoise: 64
+
+# path
+path:
+  pretrain_network_g: experiments/train_CBDNet_DIV2K_LMDB_G1/models/net_g_latest.pth
+  param_key_g: params_ema  # load the EMA model
+  strict_load_g: true
+
+# validation settings
+val:
+  save_img: false
+  suffix: ~  # suffix for saved images; if None, use the exp name
+
+  metrics:
+    psnr:
+      type: pyiqa
+    ssim:
+      type: pyiqa
+    lpips:
+      type: pyiqa
+    clipiqa+:
+      type: pyiqa
+    topiq_fr:
+      type: pyiqa
+    musiq:
+      type: pyiqa
+    wadiqam_fr:
+      type: pyiqa
diff --git a/options/train/CBDNet/DIV2K_LMDB_G1.yml b/options/train/CBDNet/DIV2K_LMDB_G1.yml
new file mode 100644
index 0000000..6895bb8
--- /dev/null
+++ b/options/train/CBDNet/DIV2K_LMDB_G1.yml
@@ -0,0 +1,96 @@
+# general settings
+name: train_CBDNet_DIV2K_LMDB_G1
+model_type: QEModel
+scale: 1
+num_gpu: 1  # set num_gpu: 0 for CPU mode
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: DIV2K
+    type: PairedImageDataset
+    # dataroot_gt: datasets/DIV2K/train
+    # dataroot_lq: datasets/DIV2K/train_BPG_QP37
+    # io_backend:
+    #   type: disk
+    dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb
+    dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb
+    io_backend:
+      type: lmdb
+
+    gt_size: 128  # must match the LMDB patch size
+    use_hflip: true
+    use_rot: true
+
+    # data loader
+    num_worker_per_gpu: 16
+    batch_size_per_gpu: 16
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  val:
+    name: DIV2K
+    type: PairedImageDataset
+    dataroot_gt: datasets/DIV2K/valid
+    dataroot_lq: datasets/DIV2K/valid_BPG_QP37
+    io_backend:
+      type: disk
+
+# network structures
+network_g:
+  type: CBDNet
+  io_channels: 3
+  estimate_channels: 32
+  nlevel_denoise: 3
+  nf_base_denoise: 64
+
+# path
+path:
+  pretrain_network_g: ~
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 2e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: CosineAnnealingRestartLR
+    periods: [500000]
+    restart_weights: [1]
+    eta_min: !!float 1e-7
+
+  total_iter: 500000
+  warmup_iter: -1  # no warm-up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+# validation settings
+val:
+  val_freq: !!float 5e4
+  save_img: false
+
+  metrics:
+    psnr:
+      type: pyiqa
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 1e4
+  use_tb_logger: true
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
diff --git a/powerqe/archs/__init__.py b/powerqe/archs/__init__.py
index c4b9c6e..ab73b0c 100644
--- a/powerqe/archs/__init__.py
+++ b/powerqe/archs/__init__.py
@@ -4,8 +4,10 @@
 from .identitynet_arch import IdentityNet
 from .registry import ARCH_REGISTRY
+from .cbdnet_arch import CBDNet
+from .unet_arch import UNet
 
-__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet"]
+__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet", "CBDNet", "UNet"]
 
 
 def build_network(opt):
diff --git a/powerqe/archs/cbdnet_arch.py b/powerqe/archs/cbdnet_arch.py
new file mode 100644
index 0000000..d327b78
--- /dev/null
+++ b/powerqe/archs/cbdnet_arch.py
@@ -0,0 +1,96 @@
+import torch
+from torch import nn
+
+from .registry import ARCH_REGISTRY
+from .unet_arch import UNet
+
+
+@ARCH_REGISTRY.register()
+class CBDNet(nn.Module):
+    """CBDNet network structure.
+
+    Args:
+        io_channels (int): Number of I/O channels.
+        estimate_channels (int): Channel number of the features in the
+            estimation module.
+        nlevel_denoise (int): Level number of the UNet for denoising.
+        nf_base_denoise (int): Base channel number of the features in the
+            denoising module.
+        nf_gr_denoise (int): Growth rate of the channel number in the
+            denoising module.
+        nl_base_denoise (int): Base convolution layer number in the
+            denoising module.
+        nl_gr_denoise (int): Growth rate of the convolution layer number in
+            the denoising module.
+        down_denoise (str): Downsampling method in the denoising module.
+        up_denoise (str): Upsampling method in the denoising module.
+        reduce_denoise (str): Reduction method for the guidance/feature maps
+            in the denoising module.
+    """
+
+    def __init__(
+        self,
+        io_channels=3,
+        estimate_channels=32,
+        nlevel_denoise=3,
+        nf_base_denoise=64,
+        nf_gr_denoise=2,
+        nl_base_denoise=1,
+        nl_gr_denoise=2,
+        down_denoise="avepool2d",
+        up_denoise="transpose2d",
+        reduce_denoise="add",
+    ):
+        super().__init__()
+
+        # noise-level estimation module: five 3x3 conv + ReLU blocks
+        estimate_list = [
+            nn.Conv2d(
+                in_channels=io_channels,
+                out_channels=estimate_channels,
+                kernel_size=3,
+                padding=3 // 2,
+            ),
+            nn.ReLU(inplace=True),
+        ]
+        for _ in range(3):
+            estimate_list += [
+                nn.Conv2d(
+                    in_channels=estimate_channels,
+                    out_channels=estimate_channels,
+                    kernel_size=3,
+                    padding=3 // 2,
+                ),
+                nn.ReLU(inplace=True),
+            ]
+        estimate_list += [
+            nn.Conv2d(estimate_channels, io_channels, 3, padding=3 // 2),
+            nn.ReLU(inplace=True),
+        ]
+        self.estimate = nn.Sequential(*estimate_list)
+
+        # non-blind denoising module: a UNet conditioned on the estimated
+        # noise map, hence io_channels * 2 input channels
+        self.denoise = UNet(
+            nf_in=io_channels * 2,
+            nf_out=io_channels,
+            nlevel=nlevel_denoise,
+            nf_base=nf_base_denoise,
+            nf_gr=nf_gr_denoise,
+            nl_base=nl_base_denoise,
+            nl_gr=nl_gr_denoise,
+            down=down_denoise,
+            up=up_denoise,
+            reduce=reduce_denoise,
+            residual=False,
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input tensor with the shape of (N, C, H, W).
+
+        Returns:
+            Tensor: Output tensor with the shape of (N, C, H, W).
+        """
+        # estimate a noise map, then denoise conditioned on it;
+        # the UNet predicts a residual that is added back to the input
+        estimated_noise_map = self.estimate(x)
+        res = self.denoise(torch.cat([x, estimated_noise_map], dim=1))
+        out = res + x
+        return out
diff --git a/powerqe/archs/unet_arch.py b/powerqe/archs/unet_arch.py
new file mode 100644
index 0000000..731fa48
--- /dev/null
+++ b/powerqe/archs/unet_arch.py
@@ -0,0 +1,243 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import ARCH_REGISTRY
+
+
+class Up(nn.Module):
+    """Upsample the input and align its size to a reference feature map."""
+
+    def __init__(self, method, nf_in=None):
+        super().__init__()
+
+        supported_methods = ["upsample", "transpose2d"]
+        if method not in supported_methods:
+            raise NotImplementedError(
+                f'Upsampling method should be in "{supported_methods}";'
+                f' received "{method}".'
+            )
+
+        if method == "upsample":
+            self.up = nn.Upsample(scale_factor=2, mode="bicubic", align_corners=False)
+        elif method == "transpose2d":
+            # halve the channel number to match the next level up
+            self.up = nn.ConvTranspose2d(
+                in_channels=nf_in,
+                out_channels=nf_in // 2,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            )
+
+    def forward(self, inp_t, ref_big):
+        feat = self.up(inp_t)
+
+        diff_h = ref_big.size()[2] - feat.size()[2]  # (N, C, H, W); H
+        diff_w = ref_big.size()[3] - feat.size()[3]  # W
+
+        # crop if the upsampled feature is larger than the reference
+        if diff_h < 0:
+            feat = feat[:, :, : ref_big.size()[2], :]
+            diff_h = 0
+        if diff_w < 0:
+            feat = feat[:, :, :, : ref_big.size()[3]]
+            diff_w = 0
+
+        # pad H and W only: left (diff_w // 2),
+        # right the remainder (diff_w - diff_w // 2);
+        # pad with constant 0
+        out_t = F.pad(
+            input=feat,
+            pad=[
+                diff_w // 2,
+                (diff_w - diff_w // 2),
+                diff_h // 2,
+                (diff_h - diff_h // 2),
+            ],
+            mode="constant",
+            value=0,
+        )
+
+        return out_t
+
+
+@ARCH_REGISTRY.register()
+class UNet(nn.Module):
+    """UNet with configurable depth, width, and sampling methods.
+
+    Args:
+        nf_in (int): Number of input channels.
+        nf_out (int): Number of output channels.
+        nlevel (int): Number of levels.
+        nf_base (int): Base channel number of the features.
+        nf_max (int): Maximum channel number of the features.
+        nf_gr (int): Growth rate of the channel number across levels.
+        nl_base (int): Base convolution layer number per level.
+        nl_max (int): Maximum convolution layer number per level.
+        nl_gr (int): Growth rate of the convolution layer number across
+            levels.
+        down (str): Downsampling method.
+        up (str): Upsampling method.
+        reduce (str): Reduction method for the guidance/feature maps.
+        residual (bool): Whether to add the input to the output; requires
+            nf_in == nf_out.
+    """
+
+    def __init__(
+        self,
+        nf_in,
+        nf_out,
+        nlevel,
+        nf_base,
+        nf_max=1024,
+        nf_gr=2,
+        nl_base=1,
+        nl_max=8,
+        nl_gr=2,
+        down="avepool2d",
+        up="transpose2d",
+        reduce="concat",
+        residual=True,
+    ):
+        super().__init__()
+
+        supported_up_methods = ["upsample", "transpose2d"]
+        if up not in supported_up_methods:
+            raise NotImplementedError(
+                f'Upsampling method should be in "{supported_up_methods}";'
+                f' received "{up}".'
+            )
+
+        supported_down_methods = ["avepool2d", "strideconv"]
+        if down not in supported_down_methods:
+            raise NotImplementedError(
+                f'Downsampling method should be in "{supported_down_methods}";'
+                f' received "{down}".'
+            )
+
+        supported_reduce_methods = ["add", "concat"]
+        if reduce not in supported_reduce_methods:
+            raise NotImplementedError(
+                f'Reduce method should be in "{supported_reduce_methods}";'
+                f' received "{reduce}".'
+            )
+
+        if residual and (nf_in != nf_out):
+            raise ValueError(
+                "For residual learning, the input channel number should be"
+                " equal to the output channel number."
+ ) + + self.nlevel = nlevel + self.reduce = reduce + self.residual = residual + + self.inc = nn.Sequential( + nn.Conv2d( + in_channels=nf_in, out_channels=nf_base, kernel_size=3, padding=1 + ), + nn.ReLU(inplace=True), + ) + + nf_lst = [nf_base] + nl_lst = [nl_base] + for idx_level in range(1, nlevel): + nf_new = nf_lst[-1] * nf_gr if (nf_lst[-1] * nf_gr) <= nf_max else nf_max + nf_lst.append(nf_new) + nl_new = nl_lst[-1] * nl_gr if ((nl_lst[-1] * nl_gr) <= nl_max) else nl_max + nl_lst.append(nl_new) + + # define downsampling operator + + if down == "avepool2d": + setattr(self, f"down_{idx_level}", nn.AvgPool2d(kernel_size=2)) + elif down == "strideconv": + setattr( + self, + f"down_{idx_level}", + nn.Sequential( + nn.Conv2d( + in_channels=nf_lst[-2], + out_channels=nf_lst[-2], + kernel_size=3, + stride=2, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + ), + ) + + # define encoding operator + + module_lst = [ + nn.Conv2d( + in_channels=nf_lst[-2], + out_channels=nf_lst[-1], + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + ] + for _ in range(nl_lst[-1]): + module_lst += [ + nn.Conv2d( + in_channels=nf_lst[-1], + out_channels=nf_lst[-1], + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + ] + setattr(self, f"enc_{idx_level}", nn.Sequential(*module_lst)) + + for idx_level in range((nlevel - 2), -1, -1): + # define upsampling operator + setattr(self, f"up_{idx_level}", Up(nf_in=nf_lst[idx_level + 1], method=up)) + + # define decoding operator + + if reduce == "add": + module_lst = [ + nn.Conv2d( + in_channels=nf_lst[idx_level], + out_channels=nf_lst[idx_level], + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + ] + else: + module_lst = [ + nn.Conv2d( + in_channels=nf_lst[idx_level + 1], + out_channels=nf_lst[idx_level], + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + ] + for _ in range(nl_lst[idx_level]): + module_lst += [ + nn.Conv2d( + in_channels=nf_lst[idx_level], + out_channels=nf_lst[idx_level], + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + ] + setattr(self, f"dec_{idx_level}", nn.Sequential(*module_lst)) + + self.outc = nn.Conv2d( + in_channels=nf_base, out_channels=nf_out, kernel_size=3, padding=1 + ) + + def forward(self, inp_t): + feat = self.inc(inp_t) + + # down + + map_lst = [] # guidance maps + for idx_level in range(1, self.nlevel): + map_lst.append(feat) # from level 0, 1, ..., (nlevel-1) + down = getattr(self, f"down_{idx_level}") + enc = getattr(self, f"enc_{idx_level}") + feat = enc(down(feat)) + + # up + + for idx_level in range((self.nlevel - 2), -1, -1): + up = getattr(self, f"up_{idx_level}") + dec = getattr(self, f"dec_{idx_level}") + g_map = map_lst[idx_level] + up_feat = up(inp_t=feat, ref_big=g_map) + + if self.reduce == "add": + feat = up_feat + g_map + elif self.reduce == "concat": + feat = torch.cat((up_feat, g_map), dim=1) + feat = dec(feat) + + out_t = self.outc(feat) + + if self.residual: + out_t += inp_t + + return out_t