diff --git a/docs/source/training.rst b/docs/source/training.rst
index c9fceb05f..dcdebff83 100644
--- a/docs/source/training.rst
+++ b/docs/source/training.rst
@@ -86,6 +86,20 @@ Dataset: https://joligen.com/datasets/noglasses2glasses_ffhq.zip
 
     python3 train.py --dataroot /path/to/noglasses2glasses_ffhq/ --checkpoints_dir /path/to/checkpoints/ --name noglasses2glasses --config_json examples/example_cut_turbo_noglasses2glasses.json
 
+.. _training-im2im-without-mask-semantics:
+
+***************************************
+ CUT_Turbo Training without semantics
+***************************************
+
+Trains a GAN model using a pretrained SD-Turbo model with a LoRA adapter.
+
+Dataset: https://joligen.com/datasets/horse2zebra.zip
+
+.. code:: bash
+
+    python3 train.py --dataroot /path/to/horse2zebra/ --checkpoints_dir /path/to/checkpoints/ --name horse2zebra --config_json examples/example_cut_turbo_horse2zebra.json
+
 .. _training-im2im-with-bbox-semantics-and-online-sampling-boxes-dataaug:
 
 ************************************************
diff --git a/examples/example_cut_turbo_horse2zebra.json b/examples/example_cut_turbo_horse2zebra.json
new file mode 100644
index 000000000..deadbb201
--- /dev/null
+++ b/examples/example_cut_turbo_horse2zebra.json
@@ -0,0 +1,132 @@
+{
+    "D": {
+        "dropout": false,
+        "n_layers": 3,
+        "ndf": 64,
+        "netDs": ["projected_d", "basic"],
+        "norm": "instance",
+        "proj_interp": -1,
+        "proj_network_type": "efficientnet"
+    },
+    "G": {
+        "attn_nb_mask_attn": 10,
+        "attn_nb_mask_input": 1,
+        "dropout": false,
+        "nblocks": 9,
+        "netG": "img2img_turbo",
+        "ngf": 64,
+        "norm": "instance",
+        "padding_type": "reflect"
+    },
+    "alg": {
+        "gan": {"lambda": 1.0},
+        "cut": {
+            "HDCE_gamma": 1.0,
+            "HDCE_gamma_min": 1.0,
+            "MSE_idt": false,
+            "flip_equivariance": false,
+            "lambda_MSE_idt": 1.0,
+            "lambda_NCE": 1.0,
+            "lambda_SRC": 0.0,
+            "nce_T": 0.07,
+            "nce_idt": true,
+            "nce_includes_all_negatives_from_minibatch": false,
+            "nce_layers": "0,4,8,12,16",
+            "nce_loss": "monce",
+            "netF": "mlp_sample",
+            "netF_dropout": false,
+            "netF_nc": 256,
+            "netF_norm": "instance",
+            "num_patches": 256
+        }
+    },
+    "data": {
+        "crop_size": 256,
+        "dataset_mode": "unaligned",
+        "direction": "AtoB",
+        "load_size": 256,
+        "max_dataset_size": 1000000000,
+        "num_threads": 4,
+        "preprocess": "resize_and_crop"
+    },
+    "output": {
+        "display": {
+            "freq": 400,
+            "id": 1,
+            "ncols": 0,
+            "type": ["visdom"],
+            "visdom_port": 8097,
+            "visdom_server": "http://localhost",
+            "winsize": 256
+        },
+        "no_html": false,
+        "print_freq": 100,
+        "update_html_freq": 1000,
+        "verbose": false
+    },
+    "model": {
+        "init_gain": 0.02,
+        "init_type": "normal",
+        "input_nc": 3,
+        "multimodal": false,
+        "output_nc": 3
+    },
+    "train": {
+        "D_lr": 0.0001,
+        "G_ema": false,
+        "G_ema_beta": 0.999,
+        "G_lr": 0.0001,
+        "batch_size": 4,
+        "beta1": 0.9,
+        "beta2": 0.999,
+        "continue": false,
+        "epoch": "latest",
+        "epoch_count": 1,
+        "export_jit": false,
+        "gan_mode": "lsgan",
+        "iter_size": 8,
+        "load_iter": 0,
+        "metrics_every": 1000,
+        "n_epochs": 200,
+        "n_epochs_decay": 100,
+        "nb_img_max_fid": 1000000000,
+        "optim": "adam",
+        "pool_size": 50,
+        "save_by_iter": false,
+        "save_epoch_freq": 1,
+        "save_latest_freq": 5000
+    },
+    "dataaug": {
+        "APA": false,
+        "APA_every": 4,
+        "APA_nimg": 50,
+        "APA_p": 0,
+        "APA_target": 0.6,
+        "D_diffusion": false,
+        "D_diffusion_every": 4,
+        "D_label_smooth": false,
+        "D_noise": 0.0,
+        "affine": 0.0,
+        "affine_scale_max": 1.2,
+        "affine_scale_min": 0.8,
+        "affine_shear": 45,
"affine_translate": 0.2, + "diff_aug_policy": "", + "diff_aug_proba": 0.5, + "imgaug": false, + "no_flip": false, + "no_rotate": true, + }, + "checkpoints_dir": "/path/to/checkpoints", + "dataroot": "/path/to/horse2zebra", + "ddp_port": "12355", + "gpu_ids": "0", + "model_type": "cut", + "name": "horse2zebra", + "phase": "train", + "test_batch_size": 1, + "warning_mode": false, + "with_amp": false, + "with_tf32": false, + "with_torch_compile": false, +} diff --git a/models/cut_model.py b/models/cut_model.py index f4784ea88..a83768950 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -238,9 +238,8 @@ def __init__(self, opt, rank): self.opt.model_input_nc += self.opt.train_mm_nz self.netG_A = gan_networks.define_G(**vars(opt)) - # XXX: early prompt support - # if self.opt.G_prompt: - # self.netG_A.prompt = self.opt.G_prompt + if self.opt.G_netG == "img2img_turbo" and self.opt.G_prompt: + self.prompt_opt = [self.opt.G_prompt] * opt.train_batch_size self.netG_A.lora_rank_unet = self.opt.G_lora_unet self.netG_A.lora_rank_vae = self.opt.G_lora_vae @@ -556,7 +555,13 @@ def inference(self, nb_imgs, offset=0): self.real_with_z = self.real if self.opt.G_netG == "img2img_turbo": - self.fake = self.netG_A(self.real_with_z, self.real_B_prompt) + prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt + if len(prompt) != self.real_with_z.shape[0]: + prompt = prompt * ( + self.real_with_z.shape[0] // self.opt.train_batch_size + ) + + self.fake = self.netG_A(self.real_with_z, prompt) else: self.fake = self.netG_A(self.real_with_z) @@ -602,7 +607,13 @@ def forward_cut(self): self.real_with_z = self.real if self.opt.G_netG == "img2img_turbo": - self.fake = self.netG_A(self.real_with_z, self.real_B_prompt) + prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt + if len(prompt) != self.real_with_z.shape[0]: + prompt = prompt * ( + self.real_with_z.shape[0] // self.opt.train_batch_size + ) + + self.fake = self.netG_A(self.real_with_z, prompt) else: self.fake = self.netG_A(self.real_with_z) @@ -659,10 +670,17 @@ def forward_E(self): self.real_A.size(3), ) real_A_with_z = torch.cat([self.real_A, z_real], 1) + if self.opt.G_netG == "img2img_turbo": - fake_B = self.netG_A(real_A_with_z, self.real_B_prompt) + prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt + if len(prompt) != self.real_with_z.shape[0]: + prompt = prompt * ( + self.real_with_z.shape[0] // self.opt.train_batch_size + ) + + fake_B = self.netG_A(self.real_with_z, prompt) else: - fake_B = self.netG_A(real_A_with_z) + fake_B = self.netG_A(self.real_with_z) self.mu2 = self.netE(fake_B) def compute_G_loss_cut(self): diff --git a/models/modules/img2img_turbo/img2img_turbo.py b/models/modules/img2img_turbo/img2img_turbo.py index 2efb3889a..d04ea6455 100644 --- a/models/modules/img2img_turbo/img2img_turbo.py +++ b/models/modules/img2img_turbo/img2img_turbo.py @@ -201,9 +201,6 @@ def forward(self, x, prompt): ).input_ids.cuda() caption_enc = self.text_encoder(caption_tokens)[0] - # match batch size - captions_enc = caption_enc.repeat(x.shape[0], 1, 1) - # deterministic forward encoded_control = ( self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor @@ -211,7 +208,7 @@ def forward(self, x, prompt): model_pred = self.unet( encoded_control, self.timesteps, - encoder_hidden_states=captions_enc, + encoder_hidden_states=caption_enc, ).sample x_denoised = self.sched.step( model_pred, self.timesteps, encoded_control, return_dict=True @@ -223,14 +220,6 @@ def forward(self, x, prompt): 
         return x
 
     def compute_feats(self, input, extract_layer_ids=[]):
-        # caption_tokens = self.tokenizer(
-        #     #self.prompt, # XXX: set externally
-        #     prompt,
-        #     max_length=self.tokenizer.model_max_length,
-        #     padding="max_length",
-        #     truncation=True,
-        #     return_tensors="pt",
-        # ).input_ids.cuda()
 
         # deterministic forward
         encoded_control = (
diff --git a/options/common_options.py b/options/common_options.py
index a4bc5b481..26e6a6198 100644
--- a/options/common_options.py
+++ b/options/common_options.py
@@ -387,12 +387,12 @@ def initialize(self, parser):
             help="Patch size for HDIT, e.g. 4 for 4x4 patches",
         )
 
-        # parser.add_argument(
-        #     "--G_prompt",
-        #     type=str,
-        #     default="",
-        #     help="Text prompt for G",
-        # )
+        parser.add_argument(
+            "--G_prompt",
+            type=str,
+            default="",
+            help="Text prompt for G",
+        )
         parser.add_argument(
             "--G_lora_unet",
             type=int,
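Reviewer note on the batch math above: the sketch below (not part of the diff) walks through the prompt tiling that `inference`, `forward_cut`, and `forward_E` now share, and why dropping `caption_enc.repeat(...)` in `img2img_turbo.py` is safe once prompts arrive as a per-sample list. The helper name `broadcast_prompt` and the example prompt string are assumptions for illustration only.

```python
# Hypothetical standalone sketch of the prompt tiling added in
# models/cut_model.py; broadcast_prompt is not a function in this PR.

def broadcast_prompt(prompt, batch, train_batch_size):
    # prompt starts as [G_prompt] * train_batch_size (see __init__);
    # with gradient accumulation (train.iter_size = 8 in the example
    # config) the tensors fed to G can hold several train batches,
    # so the list is tiled to keep one prompt per sample.
    if len(prompt) != batch:
        prompt = prompt * (batch // train_batch_size)
    return prompt

prompt_opt = ["a photo of a zebra"] * 4  # illustrative --G_prompt, batch_size 4
assert broadcast_prompt(prompt_opt, 8, 4) == ["a photo of a zebra"] * 8

# With one prompt per sample, tokenizer(prompt, ...).input_ids already has
# shape [batch, model_max_length], so the text_encoder output is batch-sized
# and the old caption_enc.repeat(x.shape[0], 1, 1) became redundant.
```

One caveat the integer division leaves open: a final batch that is not a multiple of `train_batch_size` would still reach the generator with a mismatched prompt list.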