diff --git a/options/test/ARCNN/DIV2K_LMDB_G1_latest.yml b/options/test/ARCNN/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..484e47f --- /dev/null +++ b/options/test/ARCNN/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,51 @@ +# general settings +name: test_ARCNN_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: ARCNN + io_channels: 3 + mid_channels_1: 64 + mid_channels_2: 32 + mid_channels_3: 16 + +# path +path: + pretrain_network_g: experiments/train_ARCNN_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff --git a/options/test/DCAD/DIV2K_LMDB_G1_latest.yml b/options/test/DCAD/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..e49666f --- /dev/null +++ b/options/test/DCAD/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,50 @@ +# general settings +name: test_DCAD_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: DCAD + io_channels: 3 + mid_channels: 64 + num_blocks: 8 + +# path +path: + pretrain_network_g: experiments/train_DCAD_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff --git a/options/test/DnCNN/DIV2K_LMDB_G1_latest.yml b/options/test/DnCNN/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..6e5e507 --- /dev/null +++ b/options/test/DnCNN/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,50 @@ +# general settings +name: test_DnCNN_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: DnCNN + io_channels: 3 + mid_channels: 64 + num_blocks: 15 + +# path +path: + pretrain_network_g: experiments/train_DnCNN_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff 
--git a/options/test/MPRNet/DIV2K_LMDB_G1_latest.yml b/options/test/MPRNet/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..08ccca2 --- /dev/null +++ b/options/test/MPRNet/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,52 @@ +# general settings +name: test_MPRNet_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: MPRNet + io_channels: 3 + n_feat: 16 + scale_unetfeats: 16 + scale_orsnetfeats: 16 + num_cab: 4 + +# path +path: + pretrain_network_g: experiments/train_MPRNet_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff --git a/options/test/RBQE/DIV2K_LMDB_G1_latest.yml b/options/test/RBQE/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..a93094a --- /dev/null +++ b/options/test/RBQE/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,51 @@ +# general settings +name: test_RBQE_nonblind_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: RBQE + nf_io: 3 + nf_base: 32 + if_only_last_output: True +# comp_type: hevc + +# path +path: + pretrain_network_g: experiments/train_RBQE_nonblind_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff --git a/options/test/RDN/DIV2K_LMDB_G1_latest.yml b/options/test/RDN/DIV2K_LMDB_G1_latest.yml new file mode 100644 index 0000000..a389908 --- /dev/null +++ b/options/test/RDN/DIV2K_LMDB_G1_latest.yml @@ -0,0 +1,53 @@ +# general settings +name: test_RDN_DIV2K_LMDB_G1_latest +model_type: SRModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset settings +datasets: + test: # multiple test datasets are acceptable + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: RDN + rescale: 1 + io_channels: 3 + mid_channels: 32 + num_blocks: 4 + num_layers: 4 + channel_growth: 32 + +# path +path: + pretrain_network_g: experiments/train_RDN_DIV2K_LMDB_G1/models/net_g_latest.pth + param_key_g: params_ema # load the ema model + strict_load_g: true + +# validation settings +val: + save_img: false + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: + type: pyiqa + ssim: + type: pyiqa + lpips: + type: pyiqa + clipiqa+: + type: pyiqa + 
topiq_fr: + type: pyiqa + musiq: + type: pyiqa + wadiqam_fr: + type: pyiqa diff --git a/options/train/ARCNN/DIV2K_LMDB_G1.yml b/options/train/ARCNN/DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..08ca771 --- /dev/null +++ b/options/train/ARCNN/DIV2K_LMDB_G1.yml @@ -0,0 +1,96 @@ +# general settings +name: train_ARCNN_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: ARCNN + io_channels: 3 + mid_channels_1: 64 + mid_channels_2: 32 + mid_channels_3: 16 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/DCAD/DIV2K_LMDB_G1.yml b/options/train/DCAD/DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..37fb7b4 --- /dev/null +++ b/options/train/DCAD/DIV2K_LMDB_G1.yml @@ -0,0 +1,95 @@ +# general settings +name: train_DCAD_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: DCAD + io_channels: 3 + mid_channels: 64 + num_blocks: 8 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + 
warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/DnCNN/DIV2K_LMDB_G1.yml b/options/train/DnCNN/DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..5423a19 --- /dev/null +++ b/options/train/DnCNN/DIV2K_LMDB_G1.yml @@ -0,0 +1,95 @@ +# general settings +name: train_DnCNN_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: DnCNN + io_channels: 3 + mid_channels: 64 + num_blocks: 15 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/MPRNet/DIV2K_LMDB_G1.yml b/options/train/MPRNet/DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..6153762 --- /dev/null +++ b/options/train/MPRNet/DIV2K_LMDB_G1.yml @@ -0,0 +1,97 @@ +# general settings +name: train_MPRNet_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: MPRNet + io_channels: 3 + n_feat: 16 + scale_unetfeats: 16 + scale_orsnetfeats: 16 + num_cab: 4 + +# path +path: + 
pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/RBQE/nonblind_DIV2K_LMDB_G1.yml b/options/train/RBQE/nonblind_DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..f0e4203 --- /dev/null +++ b/options/train/RBQE/nonblind_DIV2K_LMDB_G1.yml @@ -0,0 +1,96 @@ +# general settings +name: train_RBQE_nonblind_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: RBQE + nf_io: 3 + nf_base: 32 + if_only_last_output: True +# comp_type: hevc + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/RDN/DIV2K_LMDB_G1.yml b/options/train/RDN/DIV2K_LMDB_G1.yml new file mode 100644 index 0000000..01260ff --- /dev/null +++ b/options/train/RDN/DIV2K_LMDB_G1.yml @@ -0,0 +1,98 @@ +# general settings +name: train_RDN_DIV2K_LMDB_G1 +model_type: QEModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset +# dataroot_gt: datasets/DIV2K/train +# dataroot_lq: datasets/DIV2K/train_BPG_QP37 +# io_backend: +# type: disk + dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb + dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb + io_backend: + type: lmdb + + gt_size: 128 # in accord with LMDB + use_hflip: true + use_rot: true + + # data loader + num_worker_per_gpu: 16 + 
batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/valid + dataroot_lq: datasets/DIV2K/valid_BPG_QP37 + io_backend: + type: disk + +# network structures +network_g: + type: RDN + rescale: 1 + io_channels: 3 + mid_channels: 32 + num_blocks: 4 + num_layers: 4 + channel_growth: 32 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: CosineAnnealingRestartLR + periods: [500000] + restart_weights: [1] + eta_min: !!float 1e-7 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e4 + save_img: false + + metrics: + psnr: + type: pyiqa + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/powerqe/archs/__init__.py b/powerqe/archs/__init__.py index ab73b0c..aef0280 100644 --- a/powerqe/archs/__init__.py +++ b/powerqe/archs/__init__.py @@ -2,12 +2,30 @@ from basicsr.utils import get_root_logger +from .arcnn_arch import ARCNN +from .cbdnet_arch import CBDNet +from .dcad_arch import DCAD +from .dncnn_arch import DnCNN from .identitynet_arch import IdentityNet +from .mprnet_arch import MPRNet +from .rbqe_arch import RBQE +from .rdn_arch import RDN from .registry import ARCH_REGISTRY -from .cbdnet_arch import CBDNet from .unet_arch import UNet -__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet", "CBDNet", "UNet"] +__all__ = [ + "ARCNN", + "CBDNet", + "DCAD", + "DnCNN", + "IdentityNet", + "MPRNet", + "RBQE", + "RDN", + "build_network", + "ARCH_REGISTRY", + "UNet", +] def build_network(opt): diff --git a/powerqe/archs/arcnn_arch.py b/powerqe/archs/arcnn_arch.py new file mode 100644 index 0000000..20a1ea5 --- /dev/null +++ b/powerqe/archs/arcnn_arch.py @@ -0,0 +1,75 @@ +import torch.nn as nn + +from .registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class ARCNN(nn.Module): + """AR-CNN network structure. + + Args: + io_channels (int): Number of I/O channels. + mid_channels_1 (int): Channel number of the first intermediate + features. + mid_channels_2 (int): Channel number of the second intermediate + features. + mid_channels_3 (int): Channel number of the third intermediate + features. + in_kernel_size (int): Kernel size of the first convolution. + mid_kernel_size_1 (int): Kernel size of the first intermediate + convolution. + mid_kernel_size_2 (int): Kernel size of the second intermediate + convolution. + out_kernel_size (int): Kernel size of the last convolution. 
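+ + Example (a minimal usage sketch, assuming the default arguments; the residual output keeps the input shape): + >>> import torch + >>> net = ARCNN() + >>> net(torch.rand(1, 3, 64, 64)).shape + torch.Size([1, 3, 64, 64])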
+ """ + + def __init__( + self, + io_channels=3, + mid_channels_1=64, + mid_channels_2=32, + mid_channels_3=16, + in_kernel_size=9, + mid_kernel_size_1=7, + mid_kernel_size_2=1, + out_kernel_size=5, + ): + super().__init__() + + self.layers = nn.Sequential( + nn.Conv2d( + io_channels, mid_channels_1, in_kernel_size, padding=in_kernel_size // 2 + ), + nn.ReLU(inplace=False), + nn.Conv2d( + mid_channels_1, + mid_channels_2, + mid_kernel_size_1, + padding=mid_kernel_size_1 // 2, + ), + nn.ReLU(inplace=False), + nn.Conv2d( + mid_channels_2, + mid_channels_3, + mid_kernel_size_2, + padding=mid_kernel_size_2 // 2, + ), + nn.ReLU(inplace=False), + nn.Conv2d( + mid_channels_3, + io_channels, + out_kernel_size, + padding=out_kernel_size // 2, + ), + ) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with the shape of (N, C, H, W). + + Returns: + Tensor + """ + return self.layers(x) + x diff --git a/powerqe/archs/dcad_arch.py b/powerqe/archs/dcad_arch.py new file mode 100644 index 0000000..9859b99 --- /dev/null +++ b/powerqe/archs/dcad_arch.py @@ -0,0 +1,46 @@ +import torch.nn as nn + +from .registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class DCAD(nn.Module): + """DCAD network structure. + + Args: + io_channels (int): Number of I/O channels. + mid_channels (int): Channel number of intermediate features. + num_blocks (int): Block number in the trunk network. + """ + + def __init__(self, io_channels=3, mid_channels=64, num_blocks=8): + super().__init__() + + # input conv + layers = [nn.Conv2d(io_channels, mid_channels, 3, padding=1)] + + # body + for _ in range(num_blocks): + layers += [ + nn.ReLU(inplace=False), + nn.Conv2d(mid_channels, mid_channels, 3, padding=1), + ] + + # output conv + layers += [ + nn.ReLU(inplace=False), + nn.Conv2d(mid_channels, io_channels, 3, padding=1), + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with the shape of (N, C, H, W). + + Returns: + Tensor + """ + return self.layers(x) + x diff --git a/powerqe/archs/dncnn_arch.py b/powerqe/archs/dncnn_arch.py new file mode 100644 index 0000000..0030e8e --- /dev/null +++ b/powerqe/archs/dncnn_arch.py @@ -0,0 +1,59 @@ +import torch.nn as nn + +from .registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class DnCNN(nn.Module): + """DnCNN network structure. + + Momentum for nn.BatchNorm2d is 0.9 in + "https://github.com/cszn/KAIR/blob + /7e51c16c6f55ff94b59c218c2af8e6b49fe0668b/models/basicblock.py#L69", + but is 0.1 default in PyTorch. + + Args: + io_channels (int): Number of I/O channels. + mid_channels (int): Channel number of intermediate features. + num_blocks (int): Block number in the trunk network. + if_bn (bool): Whether to use BN layer. Default: False. 
+ """ + + def __init__(self, io_channels=3, mid_channels=64, num_blocks=15, if_bn=False): + super().__init__() + + # input conv + layers = [nn.Conv2d(io_channels, mid_channels, 3, padding=1)] + + # body + for _ in range(num_blocks): + layers.append(nn.ReLU(inplace=True)) + if if_bn: + layers += [ + # bias is unnecessary and off due to the following BN + nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False), + nn.BatchNorm2d( + num_features=mid_channels, momentum=0.9, eps=1e-04, affine=True + ), + ] + else: + layers.append(nn.Conv2d(mid_channels, mid_channels, 3, padding=1)) + + # output conv + layers += [ + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, io_channels, 3, padding=1), + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with the shape of (N, C, H, W). + + Returns: + Tensor + """ + return self.layers(x) + x diff --git a/powerqe/archs/mprnet_arch.py b/powerqe/archs/mprnet_arch.py new file mode 100644 index 0000000..01d6c72 --- /dev/null +++ b/powerqe/archs/mprnet_arch.py @@ -0,0 +1,533 @@ +""" +Source: https://github.com/swz30/MPRNet/blob/main/Deblurring/MPRNet.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as nn_func + +from .registry import ARCH_REGISTRY + + +def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + padding=(kernel_size // 2), + bias=bias, + stride=stride, + ) + + +# Channel Attention Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16, bias=False): + super().__init__() + + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), + nn.Sigmoid(), + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +# Channel Attention Block (CAB) +class CAB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, bias, act): + super().__init__() + + modules_body = [ + conv(n_feat, n_feat, kernel_size, bias=bias), + act, + conv(n_feat, n_feat, kernel_size, bias=bias), + ] + + self.CA = CALayer(n_feat, reduction, bias=bias) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res = self.CA(res) + res += x + return res + + +# Supervised Attention Module +class SAM(nn.Module): + def __init__(self, n_feat, kernel_size, bias): + super().__init__() + + self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) + self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) + self.conv3 = conv(3, n_feat, kernel_size, bias=bias) + + def forward(self, x, x_img): + x1 = self.conv1(x) + img = self.conv2(x) + x_img + x2 = torch.sigmoid(self.conv3(img)) + x1 = x1 * x2 + x1 = x1 + x + return x1, img + + +# U-Net + + +def pad_and_add(x, y): + x_pads = [0, 0, 0, 0] + y_pads = [0, 0, 0, 0] + + # h + diff = x.shape[2] - y.shape[2] + if diff > 0: + y_pads[2] = diff // 2 + y_pads[3] = diff - diff // 2 + elif diff < 0: + x_pads[2] = (-diff) // 2 + x_pads[3] = (-diff) - (-diff) // 2 + + # w + diff = x.shape[3] - y.shape[3] + if diff > 0: + y_pads[0] = diff // 2 + y_pads[1] = diff - diff // 2 + elif diff < 0: + x_pads[0] = (-diff) // 2 + x_pads[1] = (-diff) - (-diff) // 2 + + x = nn_func.pad(input=x, pad=x_pads, 
mode="constant", value=0) + y = nn_func.pad(input=y, pad=y_pads, mode="constant", value=0) + return x + y + + +class Encoder(nn.Module): + def __init__( + self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff + ): + super().__init__() + + self.encoder_level1 = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) + ] + self.encoder_level2 = [ + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) + for _ in range(2) + ] + self.encoder_level3 = [ + CAB( + n_feat + (scale_unetfeats * 2), + kernel_size, + reduction, + bias=bias, + act=act, + ) + for _ in range(2) + ] + + self.encoder_level1 = nn.Sequential(*self.encoder_level1) + self.encoder_level2 = nn.Sequential(*self.encoder_level2) + self.encoder_level3 = nn.Sequential(*self.encoder_level3) + + self.down12 = DownSample(n_feat, scale_unetfeats) + self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats) + + # Cross Stage Feature Fusion (CSFF) + if csff: + self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_enc2 = nn.Conv2d( + n_feat + scale_unetfeats, + n_feat + scale_unetfeats, + kernel_size=1, + bias=bias, + ) + self.csff_enc3 = nn.Conv2d( + n_feat + (scale_unetfeats * 2), + n_feat + (scale_unetfeats * 2), + kernel_size=1, + bias=bias, + ) + + self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_dec2 = nn.Conv2d( + n_feat + scale_unetfeats, + n_feat + scale_unetfeats, + kernel_size=1, + bias=bias, + ) + self.csff_dec3 = nn.Conv2d( + n_feat + (scale_unetfeats * 2), + n_feat + (scale_unetfeats * 2), + kernel_size=1, + bias=bias, + ) + + def forward(self, x, encoder_outs=None, decoder_outs=None): + enc1 = self.encoder_level1(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc1 = ( + enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) + ) + + x = self.down12(enc1) + + enc2 = self.encoder_level2(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc2 = ( + enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) + ) + + x = self.down23(enc2) + + enc3 = self.encoder_level3(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc3 = pad_and_add(enc3, self.csff_enc3(encoder_outs[2])) + enc3 = pad_and_add(enc3, self.csff_dec3(decoder_outs[2])) + + return [enc1, enc2, enc3] + + +class Decoder(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): + super().__init__() + + self.decoder_level1 = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) + ] + self.decoder_level2 = [ + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) + for _ in range(2) + ] + self.decoder_level3 = [ + CAB( + n_feat + (scale_unetfeats * 2), + kernel_size, + reduction, + bias=bias, + act=act, + ) + for _ in range(2) + ] + + self.decoder_level1 = nn.Sequential(*self.decoder_level1) + self.decoder_level2 = nn.Sequential(*self.decoder_level2) + self.decoder_level3 = nn.Sequential(*self.decoder_level3) + + self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) + self.skip_attn2 = CAB( + n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act + ) + + self.up21 = SkipUpSample(n_feat, scale_unetfeats) + self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats) + + def forward(self, outs): + enc1, enc2, enc3 = outs + dec3 = self.decoder_level3(enc3) + + x = self.up32(dec3, self.skip_attn2(enc2)) + dec2 = self.decoder_level2(x) + + x = 
self.up21(dec2, self.skip_attn1(enc1)) + dec1 = self.decoder_level1(x) + + return [dec1, dec2, dec3] + + +# Resizing Modules +class DownSample(nn.Module): + def __init__(self, in_channels, s_factor): + super().__init__() + + self.down = nn.Sequential( + nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x): + x = self.down(x) + return x + + +class UpSample(nn.Module): + def __init__(self, in_channels, s_factor): + super().__init__() + + self.up = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x): + x = self.up(x) + return x + + +class SkipUpSample(nn.Module): + def __init__(self, in_channels, s_factor): + super().__init__() + + self.up = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x, y): + x = self.up(x) + return pad_and_add(x, y) + + +# Original Resolution Block (ORB) +class ORB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): + super().__init__() + + modules_body = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) + for _ in range(num_cab) + ] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class ORSNet(nn.Module): + def __init__( + self, + n_feat, + scale_orsnetfeats, + kernel_size, + reduction, + act, + bias, + scale_unetfeats, + num_cab, + ): + super().__init__() + + self.orb1 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + self.orb2 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + self.orb3 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + + self.up_enc1 = UpSample(n_feat, scale_unetfeats) + self.up_dec1 = UpSample(n_feat, scale_unetfeats) + + self.up_enc2 = nn.Sequential( + UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats), + ) + self.up_dec2 = nn.Sequential( + UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats), + ) + + self.conv_enc1 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_enc2 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_enc3 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + + self.conv_dec1 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_dec2 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_dec3 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + + def forward(self, x, encoder_outs, decoder_outs): + x = self.orb1(x) + x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) + + x = self.orb2(x) + x = pad_and_add(x, self.conv_enc2(self.up_enc1(encoder_outs[1]))) + x = pad_and_add(x, self.conv_dec2(self.up_dec1(decoder_outs[1]))) + + x = self.orb3(x) + x = pad_and_add(x, self.conv_enc3(self.up_enc2(encoder_outs[2]))) + x = pad_and_add(x, self.conv_dec3(self.up_dec2(decoder_outs[2]))) + return x + + +@ARCH_REGISTRY.register() 
+class MPRNet(nn.Module): + def __init__( + self, + io_channels=3, + n_feat=96, + scale_unetfeats=48, + scale_orsnetfeats=32, + num_cab=8, + kernel_size=3, + reduction=4, + bias=False, + ): + super().__init__() + + act = nn.PReLU() + self.shallow_feat1 = nn.Sequential( + conv(io_channels, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + self.shallow_feat2 = nn.Sequential( + conv(io_channels, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + self.shallow_feat3 = nn.Sequential( + conv(io_channels, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + + # Cross Stage Feature Fusion (CSFF) + self.stage1_encoder = Encoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False + ) + self.stage1_decoder = Decoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats + ) + + self.stage2_encoder = Encoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True + ) + self.stage2_decoder = Decoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats + ) + + self.stage3_orsnet = ORSNet( + n_feat, + scale_orsnetfeats, + kernel_size, + reduction, + act, + bias, + scale_unetfeats, + num_cab, + ) + + self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) + self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) + + self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias) + self.concat23 = conv( + n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias + ) + self.tail = conv( + n_feat + scale_orsnetfeats, io_channels, kernel_size, bias=bias + ) + + def forward(self, x3_img): + # Original-resolution Image for Stage 3 + hgt = x3_img.size(2) + wdt = x3_img.size(3) + + # Multi-Patch Hierarchy: Split Image into four non-overlapping patches + + # Two Patches for Stage 2 + x2top_img = x3_img[:, :, 0 : int(hgt / 2), :] + x2bot_img = x3_img[:, :, int(hgt / 2) : hgt, :] + + # Four Patches for Stage 1 + x1ltop_img = x2top_img[:, :, :, 0 : int(wdt / 2)] + x1rtop_img = x2top_img[:, :, :, int(wdt / 2) : wdt] + x1lbot_img = x2bot_img[:, :, :, 0 : int(wdt / 2)] + x1rbot_img = x2bot_img[:, :, :, int(wdt / 2) : wdt] + + # Stage 1 + + # Compute Shallow Features + x1ltop = self.shallow_feat1(x1ltop_img) + x1rtop = self.shallow_feat1(x1rtop_img) + x1lbot = self.shallow_feat1(x1lbot_img) + x1rbot = self.shallow_feat1(x1rbot_img) + + # Process features of all 4 patches with Encoder of Stage 1 + feat1_ltop = self.stage1_encoder(x1ltop) + feat1_rtop = self.stage1_encoder(x1rtop) + feat1_lbot = self.stage1_encoder(x1lbot) + feat1_rbot = self.stage1_encoder(x1rbot) + + # Concat deep features + feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)] + feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)] + + # Pass features through Decoder of Stage 1 + res1_top = self.stage1_decoder(feat1_top) + res1_bot = self.stage1_decoder(feat1_bot) + + # Apply Supervised Attention Module (SAM) + # x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) + x2top_samfeats, _ = self.sam12(res1_top[0], x2top_img) + # x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) + x2bot_samfeats, _ = self.sam12(res1_bot[0], x2bot_img) + + # Output image at Stage 1 + # stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2) + + # Stage 2 + + # Compute Shallow Features + x2top = self.shallow_feat2(x2top_img) + x2bot = self.shallow_feat2(x2bot_img) + + # Concatenate SAM features of Stage 1 with 
shallow features of Stage 2 + x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) + x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) + + # Process features of both patches with Encoder of Stage 2 + feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) + feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) + + # Concat deep features + feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)] + + # Pass features through Decoder of Stage 2 + res2 = self.stage2_decoder(feat2) + + # Apply SAM + # x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) + x3_samfeats, _ = self.sam23(res2[0], x3_img) + + # Stage 3 + + # Compute Shallow Features + x3 = self.shallow_feat3(x3_img) + + # Concatenate SAM features of Stage 2 with shallow features of Stage 3 + x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) + + x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) + + stage3_img = self.tail(x3_cat) + + return stage3_img + x3_img diff --git a/powerqe/archs/rbqe_arch.py b/powerqe/archs/rbqe_arch.py new file mode 100644 index 0000000..98ca53b --- /dev/null +++ b/powerqe/archs/rbqe_arch.py @@ -0,0 +1,750 @@ +import math +import numbers + +import torch +import torch.nn as nn +import torch.nn.functional as nn_func + +from .registry import ARCH_REGISTRY + + +class ECA(nn.Module): + """Efficient Channel Attention. + + Ref: "https://github.com/BangguWu/ECANet/blob + /3adf7a99f829ffa2e94a0de1de8a362614d66958/models/eca_module.py#L5" + + Args: + k_size: Kernel size. + """ + + def __init__(self, k_size=3): + super().__init__() + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + in_channels=1, + out_channels=1, + kernel_size=k_size, + padding=(k_size - 1) // 2, + bias=False, + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # (N, C, H, W) -> (N, C, 1, 1) + # -> (N, C, 1) -> (N, 1, C) -> Conv (just like FC, but ks=3) + # -> (N, 1, C) -> (N, C, 1) -> (N, C, 1, 1) + logic = self.avg_pool(x) + logic = ( + self.conv(logic.squeeze(-1).transpose(-1, -2)) + .transpose(-1, -2) + .unsqueeze(-1) + ) + logic = self.sigmoid(logic) + return x * logic.expand_as(x) + + +class SeparableConv2d(nn.Module): + def __init__(self, nf_in, nf_out): + super().__init__() + + self.separable_conv = nn.Sequential( + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_in, + kernel_size=3, + padding=3 // 2, + groups=nf_in, # each channel is convolved with its own filter + ), + nn.Conv2d( + in_channels=nf_in, out_channels=nf_out, kernel_size=1, groups=1 + ), # then point-wise + ) + + def forward(self, x): + return self.separable_conv(x) + + +class GaussianSmoothing(nn.Module): + """Apply gaussian smoothing on a 1d, 2d or 3d tensor. + + Filtering is performed separately for each channel + in the input using a depthwise convolution. + + Args: + channels (int, sequence): Number of channels of the input tensors. + Output will have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + """ + + def __init__(self, channels, kernel_size, sigma, padding, dim=2): + super().__init__() + + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. 
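+ # Starting from the scalar 1, the 1-D Gaussian of each dimension + # (evaluated on the meshgrid) is multiplied in, giving a separable kernel.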
+ kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 + / (std * math.sqrt(2 * math.pi)) + * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) + ) # ignore the warning: it is a tensor + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) # ignore the warning: it is a tensor + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + self.padding = padding + + if dim == 1: + self.conv = nn_func.conv1d + elif dim == 2: + self.conv = nn_func.conv2d + elif dim == 3: + self.conv = nn_func.conv3d + else: + raise ValueError( + "Data with 1/2/3 dimensions is supported;" + f" received {dim} dimensions." + ) + + def forward(self, x): + """Apply gaussian filter to input. + + Args: + x (Tensor): Input to apply gaussian filter on. + + Returns: + Tensor: Filtered output. + """ + return self.conv( + x, weight=self.weight, groups=self.groups, padding=self.padding + ) + + +class IQAM: + def __init__(self, comp_type="jpeg"): + if comp_type == "jpeg": + self.patch_sz = 8 + + self.tche_poly = torch.tensor( + [ + [0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536], + [ + -0.5401, + -0.3858, + -0.2315, + -0.0772, + 0.0772, + 0.2315, + 0.3858, + 0.5401, + ], + [ + 0.5401, + 0.0772, + -0.2315, + -0.3858, + -0.3858, + -0.2315, + 0.0772, + 0.5401, + ], + [ + -0.4308, + 0.3077, + 0.4308, + 0.1846, + -0.1846, + -0.4308, + -0.3077, + 0.4308, + ], + [ + 0.2820, + -0.5238, + -0.1209, + 0.3626, + 0.3626, + -0.1209, + -0.5238, + 0.2820, + ], + [ + -0.1498, + 0.4922, + -0.3638, + -0.3210, + 0.3210, + 0.3638, + -0.4922, + 0.1498, + ], + [ + 0.0615, + -0.3077, + 0.5539, + -0.3077, + -0.3077, + 0.5539, + -0.3077, + 0.0615, + ], + [ + -0.0171, + 0.1195, + -0.3585, + 0.5974, + -0.5974, + 0.3585, + -0.1195, + 0.0171, + ], + ], + dtype=torch.float32, + ).cuda() + + self.thr_out = 0.855 + + elif comp_type == "hevc": + self.patch_sz = 4 + + self.tche_poly = torch.tensor( + [ + [0.5000, 0.5000, 0.5000, 0.5000], + [-0.6708, -0.2236, 0.2236, 0.6708], + [0.5000, -0.5000, -0.5000, 0.5000], + [-0.2236, 0.6708, -0.6708, 0.2236], + ], + dtype=torch.float32, + ).cuda() + + self.thr_out = 0.900 + + self.tche_poly_transposed = self.tche_poly.permute(1, 0) # h <-> w + + self.thr_smooth = torch.tensor(0.004) + self.thr_jnd = torch.tensor(0.05) + self.bigc = torch.tensor(1e-5) # numerical stability + self.alpha_block = 0.9 # [0, 1] + + self.gaussian_filter = GaussianSmoothing( + channels=1, kernel_size=3, sigma=5, padding=3 // 2 + ).cuda() + + def cal_tchebichef_moments(self, x): + x = x.clone() + x /= torch.sqrt( + self.patch_sz * self.patch_sz * (x.reshape((-1,)).pow(2).mean()) + ) + x -= x.reshape((-1,)).mean() + moments = torch.mm(torch.mm(self.tche_poly, x), self.tche_poly_transposed) + return moments + + def forward(self, x): + """Forward. + + Only test one channel, e.g., red. + + Args: + x (Tensor): Image with the shape of (B=1, C, H, W). 
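+ + Returns: + bool: True if the estimated quality score reaches the exit threshold (thr_out), i.e., enhancement can stop early.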
+ """ + h, w = x.shape[2:] + h_cut = h // self.patch_sz * self.patch_sz + w_cut = w // self.patch_sz * self.patch_sz + x = x[0, 0, :h_cut, :w_cut] # (h_cut, w_cut) + + num_smooth = 0.0 + num_textured = 0.0 + score_blocky_smooth = 0.0 + score_blurred_textured = 0.0 + + start_h = self.patch_sz // 2 - 1 + while start_h + self.patch_sz <= h_cut: + start_w = self.patch_sz // 2 - 1 + + while start_w + self.patch_sz <= w_cut: + patch = x[ + start_h : (start_h + self.patch_sz), + start_w : (start_w + self.patch_sz), + ] + + sum_patch = torch.sum(torch.abs(patch)) + # will lead to NAN score of blocky smooth patch + if sum_patch == 0: + num_smooth += 1 + score_blocky_smooth = score_blocky_smooth + 1.0 + + else: + moments_patch = self.cal_tchebichef_moments(patch) + + # smooth/textured patch + ssm = torch.sum(moments_patch.pow(2)) - moments_patch[0, 0].pow(2) + if ssm > self.thr_smooth: + num_textured += 1 + + patch_blurred = torch.squeeze( + self.gaussian_filter( + patch.clone().view(1, 1, self.patch_sz, self.patch_sz) + ) + ) + moments_patch_blurred = self.cal_tchebichef_moments( + patch_blurred + ) + similarity_matrix = torch.div( + ( + torch.mul(moments_patch, moments_patch_blurred) * 2.0 + + self.bigc + ), + (moments_patch.pow(2)) + + moments_patch_blurred.pow(2) + + self.bigc, + ) + score_blurred_textured += 1 - torch.mean( + similarity_matrix.reshape((-1)) + ) + + else: + num_smooth += 1 + + sum_moments = torch.sum(torch.abs(moments_patch)) + strength_vertical = ( + torch.sum(torch.abs(moments_patch[self.patch_sz - 1, :])) + / sum_moments + - torch.abs(moments_patch[0, 0]) + + self.bigc + ) + strength_horizontal = ( + torch.sum(torch.abs(moments_patch[:, self.patch_sz - 1])) + / sum_moments + - torch.abs(moments_patch[0, 0]) + + self.bigc + ) + + if strength_vertical > self.thr_jnd: + strength_vertical = self.thr_jnd + if strength_horizontal > self.thr_jnd: + strength_horizontal = self.thr_jnd + score_ = torch.log( + 1 - ((strength_vertical + strength_horizontal) / 2) + ) / torch.log(1 - self.thr_jnd) + + score_blocky_smooth = score_blocky_smooth + score_ + + start_w += self.patch_sz + start_h += self.patch_sz + + if num_textured != 0: + score_blurred_textured /= num_textured + else: + score_blurred_textured = torch.tensor(1.0, dtype=torch.float32) + if num_smooth != 0: + score_blocky_smooth /= num_smooth + else: + score_blocky_smooth = torch.tensor(1.0, dtype=torch.float32) + + score_quality = (score_blocky_smooth.pow(self.alpha_block)) * ( + score_blurred_textured.pow(1 - self.alpha_block) + ) + if score_quality >= self.thr_out: + return True + else: + return False + + +class Down(nn.Module): + # downsample for one time, e.g., from C2,1 to C3,2 + + def __init__(self, nf_in, nf_out, method, if_separable, if_eca): + super().__init__() + + supported_methods = ["avepool2d", "strideconv"] + if method not in supported_methods: + raise NotImplementedError( + f'Downsampling method should be in "{supported_methods}";' + f' received "{method}".' 
+ ) + + if if_separable and if_eca: + layers = nn.ModuleList( + [ECA(k_size=3), SeparableConv2d(nf_in=nf_in, nf_out=nf_in)] + ) + elif if_separable and (not if_eca): + layers = nn.ModuleList([SeparableConv2d(nf_in=nf_in, nf_out=nf_in)]) + elif (not if_separable) and if_eca: + layers = nn.ModuleList( + [ + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_in, + kernel_size=3, + padding=3 // 2, + ), + ] + ) + else: + layers = nn.ModuleList( + [ + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_in, + kernel_size=3, + padding=3 // 2, + ) + ] + ) + + if method == "avepool2d": + layers.append(nn.AvgPool2d(kernel_size=2)) + elif method == "strideconv": + layers.append( + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + stride=2, + ) + ) + + if if_separable and if_eca: + layers += [ + ECA(k_size=3), + SeparableConv2d(nf_in=nf_out, nf_out=nf_in), + ] + elif if_separable and (not if_eca): + layers.append(SeparableConv2d(nf_in=nf_out, nf_out=nf_in)) + elif (not if_separable) and if_eca: + layers += [ + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + ] + else: + layers.append( + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ) + ) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class Up(nn.Module): + # upsample for one time, e.g., from C3,1 and C2,1 to C2,2 + + def __init__(self, nf_in_s, nf_in, nf_out, method, if_separable, if_eca): + super().__init__() + + supported_methods = ["upsample", "transpose2d"] + if method not in supported_methods: + raise NotImplementedError( + f'Upsampling method should be in "{supported_methods}";' + f' received "{method}".' 
+ ) + + if method == "upsample": + self.up = nn.Upsample(scale_factor=2) + elif method == "transpose2d": + self.up = nn.ConvTranspose2d( + in_channels=nf_in_s, + out_channels=nf_out, + kernel_size=3, + stride=2, + padding=1, + ) + + if if_separable and if_eca: + layers = nn.ModuleList( + [ + ECA(k_size=3), + SeparableConv2d(nf_in=nf_in, nf_out=nf_out), + nn.ReLU(inplace=True), + ECA(k_size=3), + SeparableConv2d(nf_in=nf_out, nf_out=nf_out), + ] + ) + elif if_separable and (not if_eca): + layers = nn.ModuleList( + [ + SeparableConv2d(nf_in=nf_in, nf_out=nf_out), + nn.ReLU(inplace=True), + SeparableConv2d(nf_in=nf_out, nf_out=nf_out), + ] + ) + elif (not if_separable) and if_eca: + layers = nn.ModuleList( + [ + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + ] + ) + else: + layers = nn.ModuleList( + [ + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + ] + ) + self.layers = nn.Sequential(*layers) + + def forward(self, small_t, *normal_t_list): + feat = self.up(small_t) + + # pad feat according to a normal_t + if len(normal_t_list) > 0: + h_s, w_s = feat.size()[2:] # (N, C, H, W) + h, w = normal_t_list[0].size()[2:] + dh = h - h_s + dw = w - w_s + + if dh < 0: + feat = feat[:, :, :h, :] + dh = 0 + if dw < 0: + feat = feat[:, :, :, :w] + dw = 0 + feat = nn_func.pad( + input=feat, + pad=[dw // 2, (dw - dw // 2), dh // 2, (dh - dh // 2)], + mode="constant", + value=0, + ) + + feat = torch.cat((feat, *normal_t_list), dim=1) + + return self.layers(feat) + + +@ARCH_REGISTRY.register() +class RBQE(nn.Module): + def __init__( + self, + nf_io=3, + nf_base=32, + nlevel=5, + down_method="strideconv", + up_method="transpose2d", + if_separable=False, + if_eca=False, + if_only_last_output=True, + comp_type="hevc", + ): + super().__init__() + + self.nlevel = nlevel + self.if_only_last_output = if_only_last_output + + # input conv + if if_separable: + self.in_conv_seq = nn.Sequential( + SeparableConv2d(nf_in=nf_io, nf_out=nf_base), + nn.ReLU(inplace=True), + SeparableConv2d(nf_in=nf_base, nf_out=nf_base), + ) + else: + self.in_conv_seq = nn.Sequential( + nn.Conv2d( + in_channels=nf_io, + out_channels=nf_base, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=nf_base, + out_channels=nf_base, + kernel_size=3, + padding=3 // 2, + ), + ) + + # down then up at each nested u-net + for idx_unet in range(nlevel): + setattr( + self, + f"down_{idx_unet}", + Down( + nf_in=nf_base, + nf_out=nf_base, + method=down_method, + if_separable=if_separable, + if_eca=if_eca, + ), + ) + for idx_up in range(idx_unet + 1): + setattr( + self, + f"up_{idx_unet}_{idx_up}", + Up( + nf_in_s=nf_base, + nf_in=nf_base * (2 + idx_up), # dense connection + nf_out=nf_base, + method=up_method, + if_separable=if_separable, + if_eca=if_eca, + ), + ) + + # output side + self.out_layers = nn.ModuleList() + if if_only_last_output: # single exit + repeat_times = 1 + else: # multi exits + repeat_times = nlevel + for _ in range(repeat_times): + if if_separable and if_eca: + self.out_layers.append( + nn.Sequential( + ECA(k_size=3), SeparableConv2d(nf_in=nf_base, nf_out=nf_io) + ) + ) + elif if_separable and (not if_eca): + 
self.out_layers.append(SeparableConv2d(nf_in=nf_base, nf_out=nf_io)) + elif (not if_separable) and if_eca: + self.out_layers.append( + nn.Sequential( + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_base, + out_channels=nf_io, + kernel_size=3, + padding=3 // 2, + ), + ) + ) + else: + self.out_layers.append( + nn.Conv2d( + in_channels=nf_base, + out_channels=nf_io, + kernel_size=3, + padding=3 // 2, + ) + ) + + # IQA module + # no trainable parameters + if not if_only_last_output: # multi-exit network + self.iqam = IQAM(comp_type=comp_type) + + def forward(self, x, idx_out=None): + """Forward. + + Args: + x (Tensor): Image with the shape of (B=1, C, H, W). + idx_out (int): + -2: Determined by IQAM. + -1: Output all images from all outputs for training. + 0 | 1 | ... | self.nlevel-1: Output from the assigned exit. + None: Output from the last exit. + """ + if self.if_only_last_output: + if idx_out is not None: + raise ValueError( + "Exit cannot be indicated" " since there is only one exit." + ) + idx_out = self.nlevel - 1 + + feat = self.in_conv_seq(x) + feat_level_unet = [[feat]] # the first level feature of the first U-Net + + if idx_out == -1: # to record output images from all exits + out_img_list = [] + + for idx_unet in range(self.nlevel): # per U-Net + down = getattr(self, f"down_{idx_unet}") + feat = down(feat_level_unet[-1][0]) # the previous U-Net, the first level + feat_up_list = [feat] + + # for the first u-net (idx=0), up one time + for idx_up in range(idx_unet + 1): + dense_inp_list = [] + # To obtain C2,4 + # It is the second upsampling, idx_up == 2. + # It needs C2,1 to C2,3 at feat_level_unet[1][0], + # feat_level_unet[2][1] and feat_level_unet[3][2]. + # feat_level_unet now contains 4 lists. + for idx_, feat_level in enumerate(feat_level_unet[-(idx_up + 1) :]): + dense_inp_list.append( + feat_level[idx_] + ) # append features from previous U-Nets at the same level + + up = getattr(self, f"up_{idx_unet}_{idx_up}") + feat_up = up(feat_up_list[-1], *dense_inp_list) + feat_up_list.append(feat_up) + + if idx_out in [-1, -2, idx_unet]: # if go to the output side + if self.if_only_last_output: + out_conv_seq = self.out_layers[0] + else: + out_conv_seq = self.out_layers[idx_unet] + out_img = out_conv_seq(feat_up_list[-1]) + x + + if idx_out == -1: + out_img_list.append(out_img) + + # if at the last level, no need to IQA + if (idx_out == -2) and (idx_unet < (self.nlevel - 1)): + if_out = self.iqam.forward(out_img) + if if_out: + break + + feat_level_unet.append(feat_up_list) + + if idx_out == -1: + return torch.stack(out_img_list, dim=0) # (self.nlevel, N, C, H, W) + else: + return out_img # (B=1, C, H, W) diff --git a/powerqe/archs/rdn_arch.py b/powerqe/archs/rdn_arch.py new file mode 100644 index 0000000..b42a200 --- /dev/null +++ b/powerqe/archs/rdn_arch.py @@ -0,0 +1,196 @@ +import math + +import torch +from torch import nn + +from .registry import ARCH_REGISTRY + + +class DenseLayer(nn.Module): + """Dense layer. + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=3 // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c_in, h, w). + + Returns: + Tensor: Forward results, tensor with shape (n, c_in+c_out, h, w). 
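+ + The new features are concatenated to the input along the channel dimension, giving the dense connectivity used inside each RDB.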
+ """ + return torch.cat([x, self.relu(self.conv(x))], 1) + + +class RDB(nn.Module): + """Residual Dense Block of Residual Dense Network. + + Args: + in_channels (int): Channel number of inputs. + channel_growth (int): Channels growth in each layer. + num_layers (int): Layer number in the Residual Dense Block. + """ + + def __init__(self, in_channels, channel_growth, num_layers): + super().__init__() + self.layers = nn.Sequential( + *[ + DenseLayer(in_channels + channel_growth * i, channel_growth) + for i in range(num_layers) + ] + ) + + # local feature fusion + self.lff = nn.Conv2d( + in_channels + channel_growth * num_layers, in_channels, kernel_size=1 + ) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + return x + self.lff(self.layers(x)) # local residual learning + + +class Interpolate(nn.Module): + # Ref: "https://discuss.pytorch.org/t + # /using-nn-function-interpolate-inside-nn-sequential/23588/2" + + def __init__(self, scale_factor, mode): + super().__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False + ) + return x + + +@ARCH_REGISTRY.register() +class RDN(nn.Module): + """RDN for quality enhancement. + + Differences to the RDN in MMEditing: + Support rescaling before/after enhancement. + + Args: + rescale (int): Rescaling factor. + io_channels (int): Number of I/O channels. + mid_channels (int): Channel number of intermediate features. + num_blocks (int): Block number in the trunk network. + num_layers (int): Layer number in the Residual Dense Block. + channel_growth (int): Channels growth in each layer of RDB. + """ + + def __init__( + self, + rescale=1, + io_channels=3, + mid_channels=64, + num_blocks=8, + num_layers=8, + channel_growth=64, + ): + super().__init__() + + self.rescale = rescale + self.mid_channels = mid_channels + self.channel_growth = channel_growth + self.num_blocks = num_blocks + self.num_layers = num_layers + + if not math.log2(rescale).is_integer(): + raise ValueError(f"Rescale factor ({rescale}) should be a power of 2.") + + if rescale == 1: + self.downscale = nn.Identity() + else: + self.downscale = Interpolate(scale_factor=1.0 / rescale, mode="bicubic") + + # shallow feature extraction + self.sfe1 = nn.Conv2d(io_channels, mid_channels, kernel_size=3, padding=3 // 2) + self.sfe2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=3 // 2) + + # residual dense blocks + self.rdbs = nn.ModuleList() + for _ in range(self.num_blocks): + self.rdbs.append( + RDB(self.mid_channels, self.channel_growth, self.num_layers) + ) + + # global feature fusion + self.gff = nn.Sequential( + nn.Conv2d( + self.mid_channels * self.num_blocks, self.mid_channels, kernel_size=1 + ), + nn.Conv2d( + self.mid_channels, self.mid_channels, kernel_size=3, padding=3 // 2 + ), + ) + + # upsampling + if rescale == 1: + self.upscale = nn.Identity() + else: + self.upscale = [] + for _ in range(rescale // 2): + self.upscale.extend( + [ + nn.Conv2d( + self.mid_channels, + self.mid_channels * (2**2), + kernel_size=3, + padding=3 // 2, + ), + nn.PixelShuffle(2), + ] + ) + self.upscale = nn.Sequential(*self.upscale) + + self.output = nn.Conv2d( + self.mid_channels, io_channels, kernel_size=3, padding=3 // 2 + ) + + def forward(self, x): + """Forward. 
+ + Args: + x (Tensor): Input tensor with the shape of (N, C, H, W). + + Returns: + Tensor + """ + x = self.downscale(x) + + sfe1 = self.sfe1(x) + sfe2 = self.sfe2(sfe1) + + x = sfe2 + local_features = [] + for i in range(self.num_blocks): + x = self.rdbs[i](x) + local_features.append(x) + + x = self.gff(torch.cat(local_features, 1)) + sfe1 # global residual learning + + x = self.upscale(x) + x = self.output(x) + return x