Training

Thanos Masouris edited this page Aug 28, 2022 · 1 revision

To train DiStyleGAN from scratch on CIFAR-10, follow these steps:

  1. Clone the GitHub repository.
  2. Install the Python packages listed in the requirements file (Python 3.10).
  3. Download the FakeCIFAR10 dataset from here and extract the zip file, or recreate it using the create_dataset.py script, following the instructions here.
  4. Train the model using one of the following options:
  • In Python, using the default configuration:
from distylegan import DiStyleGAN

model = DiStyleGAN()
model.train(
    dataset="./fakecifar/dataset",  # fake CIFAR10 data generated by the teacher network
    save="results",  # directory for checkpoints and results
)
  • From the command line, using the options of the corresponding Python script:
$ python distylegan.py train -h
usage: distylegan.py train [-h] --dataset DATASET --save SAVE [--c_dim C_DIM] [--lambda_ganD LAMBDA_GAND] [--lambda_ganG LAMBDA_GANG] 
[--lambda_pixel LAMBDA_PIXEL] [--nc NC] [--ndf NDF] [--ngf NGF] [--project_dim PROJECT_DIM] [--transform TRANSFORM] [--z_dim Z_DIM] 
[--adam_momentum ADAM_MOMENTUM] [--batch_size BATCH_SIZE] [--checkpoint_interval CHECKPOINT_INTERVAL] [--checkpoint_path CHECKPOINT_PATH] 
[--device DEVICE] [--epochs EPOCHS] [--gstep GSTEP] [--lr_D LR_D] [--lr_G LR_G] [--lr_decay LR_DECAY] [--num_test NUM_TEST] 
[--num_workers NUM_WORKERS] [--real_dataset REAL_DATASET]

options:
  -h, --help            show this help message and exit

Required arguments for the training procedure:
  --dataset DATASET     Path to the dataset directory of the fake CIFAR10 data generated by the teacher network
  --save SAVE           Path to save checkpoints and results

Optional arguments about the network configuration:
  --c_dim C_DIM         Condition dimension (Default: 10)
  --lambda_ganD LAMBDA_GAND
                        Weight for the adversarial GAN loss of the 
                        Discriminator (Default: 0.2)
  --lambda_ganG LAMBDA_GANG
                        Weight for the adversarial distillation loss 
                        of the Generator (Default: 0.01)
  --lambda_pixel LAMBDA_PIXEL
                        Weight for the pixel loss of the Generator
                        (Default: 0.2)
  --nc NC               Number of channels for the images 
                        (Default: 3)
  --ndf NDF             Number of discriminator filters in the first 
                        convolutional layer (Default: 128)
  --ngf NGF             Number of generator filters in the first 
                        convolutional layer (Default: 256)
  --project_dim PROJECT_DIM
                        Dimension to project the input condition
                        (Default: 128)
  --transform TRANSFORM
                        Optional transform to be applied on a sample 
                        image (Default: None)
  --z_dim Z_DIM         Noise dimension (Default: 512)

Optional arguments about the training procedure:
  --adam_momentum ADAM_MOMENTUM
                        Momentum value for the Adam optimizers' 
                        betas (Default: 0.5)
  --batch_size BATCH_SIZE
                        Number of samples per batch (Default: 128)
  --checkpoint_interval CHECKPOINT_INTERVAL
                        Checkpoints will be saved every
                        `checkpoint_interval` epochs (Default: 20)
  --checkpoint_path CHECKPOINT_PATH
                        Path to previous checkpoint
  --device DEVICE       Device to use for training ('cpu' or 'cuda')
                        (Default: If there is a CUDA device
                        available, it will be used for training)
  --epochs EPOCHS       Number of training epochs (Default: 150)
  --gstep GSTEP         The number of discriminator updates after
                        which the generator is updated using the
                        full loss (Default: 10)
  --lr_D LR_D           Learning rate for the discriminator's Adam 
                        optimizer (Default: 0.0002)
  --lr_G LR_G           Learning rate for the generator's Adam 
                        optimizer (Default: 0.0002)
  --lr_decay LR_DECAY   Iteration to start decaying the learning
                        rates for the Generator and the
                        Discriminator (Default: 350000)
  --num_test NUM_TEST   Number of generated images for evaluation
                        (Default: 30)
  --num_workers NUM_WORKERS
                        Number of subprocesses to use for data
                        loading (Default: 0, which means that the
                        data will be loaded in the main process)
  --real_dataset REAL_DATASET
                        Path to the dataset directory of the real
                        CIFAR10 data (Default: None, in which case it
                        will be downloaded and saved in the parent
                        directory of the input `dataset` path)
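
The command-line options above can also be assembled programmatically, which may be convenient when launching several runs with different settings. The snippet below is a minimal sketch and not part of the repository: `build_train_command` is a hypothetical helper, and the option values (100 epochs, batch size 64, a checkpoint every 10 epochs) are illustrative only. It assumes `distylegan.py` is in the current working directory and uses only flags listed in the help text above.

```python
import shlex
import sys


def build_train_command(dataset, save, script="distylegan.py", **options):
    """Build the argument list for `python distylegan.py train ...`.

    Hypothetical helper: keyword options are mapped one-to-one to the
    flags from the help text, e.g. epochs=100 becomes `--epochs 100`.
    """
    command = [
        sys.executable, script, "train",
        "--dataset", str(dataset),
        "--save", str(save),
    ]
    for name, value in options.items():
        if value is None:
            continue  # omit the flag so the script uses its own default
        command += [f"--{name}", str(value)]
    return command


# Illustrative values only; the required paths mirror the Python example.
command = build_train_command(
    dataset="./fakecifar/dataset",
    save="results",
    epochs=100,
    batch_size=64,
    checkpoint_interval=10,
    device="cuda",
)

# Show the command as it would be typed in a shell.
print(shlex.join(command))
```

To start training, pass the resulting list to `subprocess.run(command, check=True)`. Options set to `None` are left out, so the script falls back to the defaults documented above.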