Training

Thanos Masouris edited this page Aug 28, 2022 · 1 revision
In order to train DiStyleGAN from scratch on CIFAR-10, follow the steps below:
- Clone the GitHub repository.
- Install the Python packages listed in the requirements file (Python 3.10).
- Download the FakeCIFAR10 dataset from here and extract the zip file, or recreate it using the `create_dataset.py` script, following the instructions here.
- Train the model using one of the following options:
- Using the default configurations with the example below in Python:

  ```python
  from distylegan import DiStyleGAN

  model = DiStyleGAN()
  model.train(
      dataset="./fakecifar/dataset",
      save="results"
  )
  ```
- Using the command line options for the corresponding Python script:

  ```
  $ python distylegan.py train -h
  usage: distylegan.py train [-h] --dataset DATASET --save SAVE [--c_dim C_DIM] [--lambda_ganD LAMBDA_GAND] [--lambda_ganG LAMBDA_GANG]
                             [--lambda_pixel LAMBDA_PIXEL] [--nc NC] [--ndf NDF] [--ngf NGF] [--project_dim PROJECT_DIM] [--transform TRANSFORM]
                             [--z_dim Z_DIM] [--adam_momentum ADAM_MOMENTUM] [--batch_size BATCH_SIZE] [--checkpoint_interval CHECKPOINT_INTERVAL]
                             [--checkpoint_path CHECKPOINT_PATH] [--device DEVICE] [--epochs EPOCHS] [--gstep GSTEP] [--lr_D LR_D] [--lr_G LR_G]
                             [--lr_decay LR_DECAY] [--num_test NUM_TEST] [--num_workers NUM_WORKERS] [--real_dataset REAL_DATASET]

  options:
    -h, --help            show this help message and exit

  Required arguments for the training procedure:
    --dataset DATASET     Path to the dataset directory of the fake CIFAR10
                          data generated by the teacher network
    --save SAVE           Path to save checkpoints and results

  Optional arguments about the network configuration:
    --c_dim C_DIM         Condition dimension (Default: 10)
    --lambda_ganD LAMBDA_GAND
                          Weight for the adversarial GAN loss of the
                          Discriminator (Default: 0.2)
    --lambda_ganG LAMBDA_GANG
                          Weight for the adversarial distillation loss
                          of the Generator (Default: 0.01)
    --lambda_pixel LAMBDA_PIXEL
                          Weight for the pixel loss of the Generator
                          (Default: 0.2)
    --nc NC               Number of channels for the images (Default: 3)
    --ndf NDF             Number of discriminator filters in the first
                          convolutional layer (Default: 128)
    --ngf NGF             Number of generator filters in the first
                          convolutional layer (Default: 256)
    --project_dim PROJECT_DIM
                          Dimension to project the input condition
                          (Default: 128)
    --transform TRANSFORM
                          Optional transform to be applied on a sample
                          image (Default: None)
    --z_dim Z_DIM         Noise dimension (Default: 512)

  Optional arguments about the training procedure:
    --adam_momentum ADAM_MOMENTUM
                          Momentum value for the Adam optimizers' betas
                          (Default: 0.5)
    --batch_size BATCH_SIZE
                          Number of samples per batch (Default: 128)
    --checkpoint_interval CHECKPOINT_INTERVAL
                          Checkpoints will be saved every
                          `checkpoint_interval` epochs (Default: 20)
    --checkpoint_path CHECKPOINT_PATH
                          Path to previous checkpoint
    --device DEVICE       Device to use for training ('cpu' or 'cuda')
                          (Default: if a CUDA device is available, it
                          will be used for training)
    --epochs EPOCHS       Number of training epochs (Default: 150)
    --gstep GSTEP         The number of discriminator updates after
                          which the generator is updated using the
                          full loss (Default: 10)
    --lr_D LR_D           Learning rate for the discriminator's Adam
                          optimizer (Default: 0.0002)
    --lr_G LR_G           Learning rate for the generator's Adam
                          optimizer (Default: 0.0002)
    --lr_decay LR_DECAY   Iteration to start decaying the learning
                          rates for the Generator and the Discriminator
                          (Default: 350000)
    --num_test NUM_TEST   Number of generated images for evaluation
                          (Default: 30)
    --num_workers NUM_WORKERS
                          Number of subprocesses to use for data
                          loading (Default: 0, which means that the
                          data will be loaded in the main process)
    --real_dataset REAL_DATASET
                          Path to the dataset directory of the real
                          CIFAR10 data (Default: None, it will be
                          downloaded and saved in the parent directory
                          of the input `dataset` path)
  ```
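For reference, a command-line run equivalent to the Python example above might look like the following. The flag names come from the help output; the `--batch_size`, `--epochs`, and `--device` values shown are simply the documented defaults made explicit:

```shell
$ python distylegan.py train \
    --dataset ./fakecifar/dataset \
    --save results \
    --batch_size 128 \
    --epochs 150 \
    --device cuda
```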
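The `--gstep` option deserves a closer look: the generator receives a full-loss update only once per `gstep` discriminator updates (default 10). A minimal sketch of that schedule in plain Python — this is illustrative, not code from the repository, and the function name is made up:

```python
# Illustrative sketch of the update schedule implied by --gstep:
# the generator is updated with the full loss only after every
# `gstep` discriminator updates.

def full_loss_updates(total_d_updates, gstep=10):
    """Return the 1-based discriminator-update indices at which the
    generator would also receive a full-loss update."""
    return [i for i in range(1, total_d_updates + 1) if i % gstep == 0]

# With the default gstep of 10, over 35 discriminator updates:
print(full_loss_updates(35))  # [10, 20, 30]
```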