Home
- `src`: source, e.g., raw images.
- `dst`: distorted, e.g., compressed images.
- `tar`: target, e.g., enhanced compressed images.
All PNG image data sets are supported, e.g., DIV2K and Flickr2K.
You may down-sample the images if GPU memory is limited.
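For reference, a data layout following these conventions might look like the following; the folder and file names here are illustrative only:

```
data/div2k/
  src/0001.png  # raw image
  dst/0001.png  # compressed counterpart with the same file name
```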
Please pay attention to all `# FIXME` items. Set `max_num=-1` and `start=0`.
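For example, in the dataset section of your YML (the surrounding keys are assumptions for illustration; only `max_num` and `start` are from the text above):

```yaml
dataset:
  train:
    max_num: -1  # assumed: -1 means using all images
    start: 0     # start from the first image
```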
Each stage is independent of the other stages. Thus, the loss functions, optimizers, and schedulers should be assigned for each stage, respectively.
Note that the network is the same across all stages. For example:
```yaml
train:
  name: [stage1, stage2]
  niter: [!!float 4e+5, !!float 4e+5]
  loss:
    stage1:
      ...
    stage2:
      ...
```
You can assign different optimizers to different loss functions.
For example, at the second stage of training GANs, the generator and the discriminator each possess their own optimizer. Say we have two groups of losses, named `dis` and `gen`:
```yaml
train:
  loss:
    stage1:  # calculating only group1 is OK
      group1:
        ...
    stage2:
      group1:
        ...
      group2:
        ...
```
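Correspondingly, each loss group can be paired with its own optimizer. The `optim` key and its contents below are illustrative assumptions, not necessarily this repo's exact schema:

```yaml
train:
  optim:
    group1:  # optimizer for the group1 losses
      type: Adam
      lr: !!float 1e-4
    group2:  # optimizer for the group2 losses
      type: Adam
      lr: !!float 1e-4
```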
Note that by default, every optimizer possesses all network parameters. If you want each optimizer to possess a different set of parameters, edit your own algorithm, in particular the `create_optimizer` function. Check the ESRGAN algorithm in this repo for an example.
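A minimal sketch of the idea, assuming the network exposes `generator` and `discriminator` sub-modules (only the function name `create_optimizer` is from the text above; everything else is illustrative):

```python
import torch

def create_optimizer(net, opts):
    # Sketch only: the `generator`/`discriminator` sub-modules and the
    # `opts` layout are assumptions, not this repo's actual API.
    optim_gen = torch.optim.Adam(net.generator.parameters(),
                                 lr=opts['gen']['lr'])
    optim_dis = torch.optim.Adam(net.discriminator.parameters(),
                                 lr=opts['dis']['lr'])
    # Without such a split, each optimizer would receive net.parameters(),
    # i.e., all parameters (the default behavior described above).
    return {'gen': optim_gen, 'dis': optim_dis}
```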
`nworker_pg` is recommended to be equal to `bs_pg`.

In general, `real_bs_pg` should be equal to `bs_pg`. If you want one GPU to act as N GPUs, simply set `real_bs_pg` to `N * bs_pg`, as in the sketch below.
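Acting as N GPUs amounts to gradient accumulation: run N forward/backward passes of `bs_pg` samples each, then make one optimizer step. A self-contained sketch with toy shapes (none of these names are from this repo's code):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

bs_pg, real_bs_pg = 4, 16            # per-step batch vs. effective batch
accum_steps = real_bs_pg // bs_pg    # N: how many GPUs one GPU emulates

model = nn.Conv2d(3, 3, 3, padding=1)             # stand-in network
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.L1Loss()
data = TensorDataset(torch.randn(64, 3, 8, 8), torch.randn(64, 3, 8, 8))
loader = DataLoader(data, batch_size=bs_pg)

optimizer.zero_grad()
for step, (inp, gt) in enumerate(loader):
    loss = criterion(model(inp), gt) / accum_steps  # average over N steps
    loss.backward()                                 # gradients accumulate
    if (step + 1) % accum_steps == 0:
        optimizer.step()        # one update per real_bs_pg samples
        optimizer.zero_grad()
```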
You can stop training at any time.
When you restart training with `load_state/if_load` set to `True`, the training program will automatically find and reload the checkpoint at `load_state/opts/ckp_load_path`, e.g., `exp/<exp_name>/ckp_last.pt` by default.
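In YML, the corresponding entries might look like the following; the key names come from the text above, while the nesting and the `~` fallback are assumptions:

```yaml
load_state:
  if_load: True
  opts:
    ckp_load_path: ~  # assumed: ~ falls back to exp/<exp_name>/ckp_last.pt
```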
If you want to obtain enhanced images without evaluation, you can simply set `crit` in the YML to `~` (None).
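For instance (the enclosing `test` key is an assumption, not necessarily this repo's layout):

```yaml
test:
  crit: ~  # None: skip evaluation and only save the enhanced images
```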
DnCNN is a modified version, in which batch normalization (BN) is turned off. Check here for details.
CBDNet is a modified version, which is trained in an end-to-end manner without the total variation (TV) loss on the noise estimation map. Check here for details.
In the network `forward` function, `idx_out` indicates the index of the exit:

- `-1`: output the images from all exits; used for training.
- `-2`: let the IQA module (IQAM) judge which exit to use.
- `0, 1, 2, 3, 4`: assign the exit manually.
For validation during training, we assign the exit according to the QP/QF.
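A minimal sketch of such a multi-exit `forward`, assuming five sequential exits; the class and everything in it are illustrative, not this repo's actual implementation:

```python
from torch import nn

class MultiExitNet(nn.Module):
    """Illustrative early-exit network; not this repo's actual model."""

    def __init__(self, n_exits=5):
        super().__init__()
        self.stages = nn.ModuleList(
            nn.Conv2d(3, 3, 3, padding=1) for _ in range(n_exits)
        )

    def forward(self, x, idx_out=-1):
        outputs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if idx_out == i:                 # 0, 1, 2, 3, 4: assigned exit
                return x
            if idx_out == -2 and self.good_enough(x):  # -2: IQAM decides
                return x
            outputs.append(x)
        if idx_out == -1:                    # -1: all exits, for training
            return outputs
        return x                             # fallback: the deepest exit

    def good_enough(self, img):
        # stand-in for the IQA module: a real scorer would compare an
        # image-quality estimate against a per-exit threshold
        return False
```

For example, `MultiExitNet()(torch.randn(1, 3, 64, 64), idx_out=2)` would return the image from the third exit only.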
In the official repo, the MATLAB-based IQAM takes the Y channel as input. Here, for simplicity, we re-implement IQAM in Python and feed it the R channel instead. Hyper-parameters such as the exit thresholds have been tuned on the current test set. Since the re-implemented Python-based IQAM is much slower than the official MATLAB-based one, we record the FPS without IQAM in this repo.