-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support DIY models/archs; test benchmarks
- Loading branch information
1 parent
aed540d
commit 1e0a17c
Showing
17 changed files
with
212 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# general settings | ||
name: test_identity_DIV2K_G1 | ||
model_type: QEModel | ||
scale: 1 | ||
num_gpu: 1 # set num_gpu: 0 for cpu mode | ||
manual_seed: 0 | ||
|
||
# dataset settings | ||
datasets: | ||
test: # multiple test datasets are acceptable | ||
name: DIV2K | ||
type: PairedImageDataset | ||
dataroot_gt: datasets/DIV2K/valid | ||
dataroot_lq: datasets/DIV2K/valid_BPG_QP37 | ||
io_backend: | ||
type: disk | ||
|
||
# network structures | ||
network_g: | ||
type: IdentityNet | ||
scale: 1 # default scale=4 | ||
|
||
# path | ||
path: | ||
pretrain_network_g: ~ | ||
strict_load_g: ~ | ||
|
||
# validation settings | ||
val: | ||
save_img: true # save img -> tensor -> img version, which is lossy | ||
suffix: ~ # add suffix to saved images, if None, use exp name | ||
|
||
metrics: | ||
psnr: | ||
type: calculate_psnr | ||
crop_border: 0 | ||
test_y_channel: false | ||
ssim: | ||
type: calculate_ssim | ||
crop_border: 0 | ||
test_y_channel: false | ||
fid: | ||
type: pyiqa | ||
better: lower | ||
lpips: | ||
type: pyiqa | ||
better: lower |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .builder import build_network | ||
from .registry import ARCH_REGISTRY | ||
from .identitynet_arch import IdentityNet | ||
|
||
__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from copy import deepcopy | ||
|
||
from basicsr.utils import get_root_logger | ||
|
||
from .registry import ARCH_REGISTRY | ||
|
||
|
||
def build_network(opt): | ||
opt = deepcopy(opt) | ||
network_type = opt.pop("type") | ||
net = ARCH_REGISTRY.get(network_type)(**opt) | ||
logger = get_root_logger() | ||
logger.info(f"Network [{net.__class__.__name__}] is created.") | ||
return net |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
from .registry import ARCH_REGISTRY | ||
|
||
|
||
@ARCH_REGISTRY.register() | ||
class IdentityNet(nn.Module): | ||
"""Identity network used for testing benchmarks (in tensors). Support up-scaling.""" | ||
|
||
def __init__(self, scale=1, upscale_mode="nearest"): | ||
super(IdentityNet, self).__init__() | ||
self.scale = scale | ||
self.upscale_mode = upscale_mode | ||
|
||
def forward(self, x): | ||
if self.scale != 1: | ||
x = F.interpolate(x, scale_factor=self.scale, mode=self.upscale_mode) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from basicsr.utils.registry import ARCH_REGISTRY as ARCH_REGISTRY_BASICSR | ||
|
||
ARCH_REGISTRY = ARCH_REGISTRY_BASICSR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .builder import build_model | ||
from .registry import MODEL_REGISTRY | ||
from .qe_model import QEModel | ||
|
||
__all__ = ["build_model", "MODEL_REGISTRY", "QEModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from copy import deepcopy | ||
|
||
from basicsr.utils import get_root_logger | ||
|
||
from .registry import MODEL_REGISTRY | ||
|
||
|
||
def build_model(opt): | ||
"""Build model from options. | ||
Args: | ||
opt (dict): Configuration. It must contain: | ||
model_type (str): Model type. | ||
""" | ||
opt = deepcopy(opt) | ||
model = MODEL_REGISTRY.get(opt["model_type"])(opt) | ||
logger = get_root_logger() | ||
logger.info(f"Model [{model.__class__.__name__}] is created.") | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from basicsr.models.sr_model import SRModel | ||
|
||
from powerqe.archs import build_network | ||
|
||
from .registry import MODEL_REGISTRY | ||
|
||
|
||
@MODEL_REGISTRY.register() | ||
class QEModel(SRModel): | ||
"""Base QE model for single image quality enhancement.""" | ||
|
||
def __init__(self, opt): | ||
super(SRModel, self).__init__(opt) | ||
|
||
# define network | ||
self.net_g = build_network(opt["network_g"]) | ||
self.net_g = self.model_to_device(self.net_g) | ||
self.print_network(self.net_g) | ||
|
||
# load pretrained models | ||
load_path = self.opt["path"].get("pretrain_network_g", None) | ||
if load_path is not None: | ||
param_key = self.opt["path"].get("param_key_g", "params") | ||
self.load_network( | ||
self.net_g, | ||
load_path, | ||
self.opt["path"].get("strict_load_g", True), | ||
param_key, | ||
) | ||
|
||
if self.is_train: | ||
self.init_training_settings() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from basicsr.utils.registry import MODEL_REGISTRY as MODEL_REGISTRY_BASICSR | ||
|
||
|
||
MODEL_REGISTRY = MODEL_REGISTRY_BASICSR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#!/usr/bin/env bash | ||
|
||
GPUS=$1 | ||
CONFIG=$2 | ||
PORT=${PORT:-4321} | ||
|
||
# usage | ||
if [ $# -lt 2 ] ;then | ||
echo "usage:" | ||
echo "./scripts/test.sh [number of gpu] [path to option file]" | ||
exit | ||
fi | ||
|
||
# check if GPUS is 1 for single-GPU, otherwise run multi-GPU | ||
if [ "$GPUS" -eq 1 ]; then | ||
# if only one GPU, run the simple version | ||
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ | ||
python powerqe/test.py -opt $CONFIG ${@:3} | ||
else | ||
# if multiple GPUs, run the distributed version | ||
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ | ||
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ | ||
powerqe/test.py -opt $CONFIG --launcher pytorch ${@:3} | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#!/usr/bin/env bash | ||
|
||
GPUS=$1 | ||
CONFIG=$2 | ||
PORT=${PORT:-4321} | ||
|
||
# usage | ||
if [ $# -lt 2 ] ;then | ||
echo "usage:" | ||
echo "./scripts/train.sh [number of gpu] [path to option file]" | ||
exit | ||
fi | ||
|
||
# check if GPUS is 1 for single-GPU, otherwise run multi-GPU | ||
if [ "$GPUS" -eq 1 ]; then | ||
# single GPU version | ||
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ | ||
python powerqe/train.py -opt $CONFIG ${@:3} | ||
else | ||
# multi-GPU version | ||
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ | ||
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ | ||
powerqe/train.py -opt $CONFIG --launcher pytorch ${@:3} | ||
fi |