G_prompt for cut_turbo for dataset with single prompt #662

Open · wants to merge 7 commits into base: master
14 changes: 14 additions & 0 deletions docs/source/training.rst
@@ -86,6 +86,20 @@ Dataset: https://joligen.com/datasets/noglasses2glasses_ffhq.zip
 
     python3 train.py --dataroot /path/to/noglasses2glasses_ffhq/ --checkpoints_dir /path/to/checkpoints/ --name noglasses2glasses --config_json examples/example_cut_turbo_noglasses2glasses.json
 
+.. _training-im2im-without-mask-semantics:
+
+***************************************
+CUT_Turbo Training without semantics
+***************************************
+
+Trains a GAN model using a pretrained SD-Turbo model with a LoRA adapter
+
+Dataset: https://joligen.com/datasets/horse2zebra.zip
+
+.. code:: bash
+
+    python3 train.py --dataroot /path/to/horse2zebra/ --checkpoints_dir /path/to/checkpoints/ --name horse2zebra --config_json examples/example_cut_turbo_horse2zebra.json
+
 .. _training-im2im-with-bbox-semantics-and-online-sampling-boxes-dataaug:
 
 ************************************************
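
A possible invocation of the new option (a sketch only; it assumes --G_prompt can be passed alongside --config_json, and the prompt string is illustrative, not from the PR):

.. code:: bash

    python3 train.py --dataroot /path/to/horse2zebra/ --checkpoints_dir /path/to/checkpoints/ --name horse2zebra --config_json examples/example_cut_turbo_horse2zebra.json --G_prompt "a picture of a zebra"
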
132 changes: 132 additions & 0 deletions examples/example_cut_turbo_horse2zebra.json
@@ -0,0 +1,132 @@
{
    "D": {
        "dropout": false,
        "n_layers": 3,
        "ndf": 64,
        "netDs": ["projected_d", "basic"],
        "norm": "instance",
        "proj_interp": -1,
        "proj_network_type": "efficientnet"
    },
    "G": {
        "attn_nb_mask_attn": 10,
        "attn_nb_mask_input": 1,
        "dropout": false,
        "nblocks": 9,
        "netG": "img2img_turbo",
        "ngf": 64,
        "norm": "instance",
        "padding_type": "reflect"
    },
    "alg": {
        "gan": {"lambda": 1.0},
        "cut": {
            "HDCE_gamma": 1.0,
            "HDCE_gamma_min": 1.0,
            "MSE_idt": false,
            "flip_equivariance": false,
            "lambda_MSE_idt": 1.0,
            "lambda_NCE": 1.0,
            "lambda_SRC": 0.0,
            "nce_T": 0.07,
            "nce_idt": true,
            "nce_includes_all_negatives_from_minibatch": false,
            "nce_layers": "0,4,8,12,16",
            "nce_loss": "monce",
            "netF": "mlp_sample",
            "netF_dropout": false,
            "netF_nc": 256,
            "netF_norm": "instance",
            "num_patches": 256
        }
    },
    "data": {
        "crop_size": 256,
        "dataset_mode": "unaligned",
        "direction": "AtoB",
        "load_size": 256,
        "max_dataset_size": 1000000000,
        "num_threads": 4,
        "preprocess": "resize_and_crop"
    },
    "output": {
        "display": {
            "freq": 400,
            "id": 1,
            "ncols": 0,
            "type": ["visdom"],
            "visdom_port": 8097,
            "visdom_server": "http://localhost",
            "winsize": 256
        },
        "no_html": false,
        "print_freq": 100,
        "update_html_freq": 1000,
        "verbose": false
    },
    "model": {
        "init_gain": 0.02,
        "init_type": "normal",
        "input_nc": 3,
        "multimodal": false,
        "output_nc": 3
    },
    "train": {
        "D_lr": 0.0001,
        "G_ema": false,
        "G_ema_beta": 0.999,
        "G_lr": 0.0001,
        "batch_size": 4,
        "beta1": 0.9,
        "beta2": 0.999,
        "continue": false,
        "epoch": "latest",
        "epoch_count": 1,
        "export_jit": false,
        "gan_mode": "lsgan",
        "iter_size": 8,
        "load_iter": 0,
        "metrics_every": 1000,
        "n_epochs": 200,
        "n_epochs_decay": 100,
        "nb_img_max_fid": 1000000000,
        "optim": "adam",
        "pool_size": 50,
        "save_by_iter": false,
        "save_epoch_freq": 1,
        "save_latest_freq": 5000
    },
    "dataaug": {
        "APA": false,
        "APA_every": 4,
        "APA_nimg": 50,
        "APA_p": 0,
        "APA_target": 0.6,
        "D_diffusion": false,
        "D_diffusion_every": 4,
        "D_label_smooth": false,
        "D_noise": 0.0,
        "affine": 0.0,
        "affine_scale_max": 1.2,
        "affine_scale_min": 0.8,
        "affine_shear": 45,
        "affine_translate": 0.2,
        "diff_aug_policy": "",
        "diff_aug_proba": 0.5,
        "imgaug": false,
        "no_flip": false,
        "no_rotate": true
    },
    "checkpoints_dir": "/path/to/checkpoints",
    "dataroot": "/path/to/horse2zebra",
    "ddp_port": "12355",
    "gpu_ids": "0",
    "model_type": "cut",
    "name": "horse2zebra",
    "phase": "train",
    "test_batch_size": 1,
    "warning_mode": false,
    "with_amp": false,
    "with_tf32": false,
    "with_torch_compile": false
}
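
As a quick sanity check that the example config parses before launching a run (a minimal sketch; the path assumes the repository root as working directory):

import json

with open("examples/example_cut_turbo_horse2zebra.json") as f:
    cfg = json.load(f)  # raises json.JSONDecodeError if the file is malformed

# a few of the values the training command above relies on
print(cfg["G"]["netG"])            # img2img_turbo
print(cfg["model_type"])           # cut
print(cfg["train"]["batch_size"])  # 4
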
30 changes: 24 additions & 6 deletions models/cut_model.py
@@ -238,9 +238,8 @@ def __init__(self, opt, rank):
             self.opt.model_input_nc += self.opt.train_mm_nz
         self.netG_A = gan_networks.define_G(**vars(opt))
 
-        # XXX: early prompt support
-        # if self.opt.G_prompt:
-        #     self.netG_A.prompt = self.opt.G_prompt
+        if self.opt.G_netG == "img2img_turbo" and self.opt.G_prompt:
+            self.prompt_opt = [self.opt.G_prompt] * opt.train_batch_size
 
         self.netG_A.lora_rank_unet = self.opt.G_lora_unet
         self.netG_A.lora_rank_vae = self.opt.G_lora_vae
@@ -556,7 +555,13 @@ def inference(self, nb_imgs, offset=0):
             self.real_with_z = self.real
 
         if self.opt.G_netG == "img2img_turbo":
-            self.fake = self.netG_A(self.real_with_z, self.real_B_prompt)
+            prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt
+            if len(prompt) != self.real_with_z.shape[0]:
+                prompt = prompt * (
+                    self.real_with_z.shape[0] // self.opt.train_batch_size
+                )
+
+            self.fake = self.netG_A(self.real_with_z, prompt)
         else:
             self.fake = self.netG_A(self.real_with_z)
 
@@ -602,7 +607,13 @@ def forward_cut(self):
             self.real_with_z = self.real
 
         if self.opt.G_netG == "img2img_turbo":
-            self.fake = self.netG_A(self.real_with_z, self.real_B_prompt)
+            prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt
+            if len(prompt) != self.real_with_z.shape[0]:
+                prompt = prompt * (
+                    self.real_with_z.shape[0] // self.opt.train_batch_size
+                )
+
+            self.fake = self.netG_A(self.real_with_z, prompt)
         else:
             self.fake = self.netG_A(self.real_with_z)
 
@@ -659,10 +670,16 @@ def forward_E(self):
                 self.real_A.size(3),
             )
             real_A_with_z = torch.cat([self.real_A, z_real], 1)
+
         if self.opt.G_netG == "img2img_turbo":
-            fake_B = self.netG_A(real_A_with_z, self.real_B_prompt)
+            prompt = self.prompt_opt if self.opt.G_prompt else self.real_B_prompt
+            if len(prompt) != real_A_with_z.shape[0]:
+                prompt = prompt * (
+                    real_A_with_z.shape[0] // self.opt.train_batch_size
+                )
+
+            fake_B = self.netG_A(real_A_with_z, prompt)
         else:
             fake_B = self.netG_A(real_A_with_z)
         self.mu2 = self.netE(fake_B)
 
     def compute_G_loss_cut(self):
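
The same prompt-broadcast logic appears three times above (inference, forward_cut, forward_E). A standalone sketch of the intended behavior, with illustrative names that are not part of the PR:

def broadcast_prompt(prompt_opt, batch_dim, train_batch_size):
    # prompt_opt holds train_batch_size copies of the single --G_prompt string;
    # tile it whenever the effective batch dimension differs, e.g. when
    # nce_idt concatenates real_A and real_B and doubles the batch.
    prompt = list(prompt_opt)
    if len(prompt) != batch_dim:
        # integer division assumes batch_dim is a multiple of train_batch_size
        prompt = prompt * (batch_dim // train_batch_size)
    return prompt

assert broadcast_prompt(["a zebra"] * 4, 8, 4) == ["a zebra"] * 8

One edge case worth checking: a final partial batch with batch_dim smaller than train_batch_size makes the quotient zero and leaves an empty prompt list.
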
13 changes: 1 addition & 12 deletions models/modules/img2img_turbo/img2img_turbo.py
@@ -201,17 +201,14 @@ def forward(self, x, prompt):
         ).input_ids.cuda()
         caption_enc = self.text_encoder(caption_tokens)[0]
Review thread on the line above:

Contributor: Are you sure about the [0]?

Collaborator (Author): With refs:
1. https://huggingface.co/transformers/v4.8.0/model_doc/clip.html#flaxcliptextmodel
2. https://github.com/huggingface/transformers/blob/f91c16d270e5e3ff32fdb32ccf286d05c03dfa66/src/transformers/models/clip/modeling_clip.py#L759

For outputs = self.text_encoder(caption_tokens):
type(outputs) = <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'>
len(outputs) = 2
outputs[0].shape = torch.Size([4, 77, 1024]), this is the last_hidden_state
outputs[1].shape = torch.Size([4, 1024]), this is the pooler_output

According to the explanation in ref 1, it should be outputs[0].

-        # match batch size
-        captions_enc = caption_enc.repeat(x.shape[0], 1, 1)
 
         # deterministic forward
         encoded_control = (
             self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
         )
         model_pred = self.unet(
             encoded_control,
             self.timesteps,
-            encoder_hidden_states=captions_enc,
+            encoder_hidden_states=caption_enc,
         ).sample
         x_denoised = self.sched.step(
             model_pred, self.timesteps, encoded_control, return_dict=True
@@ -223,14 +220,6 @@ def forward(self, x, prompt):
         return x
 
     def compute_feats(self, input, extract_layer_ids=[]):
-        # caption_tokens = self.tokenizer(
-        #     #self.prompt, # XXX: set externally
-        #     prompt,
-        #     max_length=self.tokenizer.model_max_length,
-        #     padding="max_length",
-        #     truncation=True,
-        #     return_tensors="pt",
-        # ).input_ids.cuda()
 
         # deterministic forward
         encoded_control = (
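
Regarding the review thread above, a minimal sketch that reproduces the shape check with the stock transformers CLIP classes (the checkpoint name is illustrative; hidden sizes vary by checkpoint):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

name = "openai/clip-vit-large-patch14"  # illustrative checkpoint
tokenizer = CLIPTokenizer.from_pretrained(name)
text_encoder = CLIPTextModel.from_pretrained(name)

tokens = tokenizer(
    ["a photo of a zebra"],
    max_length=tokenizer.model_max_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    outputs = text_encoder(tokens)

print(type(outputs).__name__)  # BaseModelOutputWithPooling
print(outputs[0].shape)        # last_hidden_state: (1, 77, hidden_size)
print(outputs[1].shape)        # pooler_output: (1, hidden_size)
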
12 changes: 6 additions & 6 deletions options/common_options.py
@@ -387,12 +387,12 @@ def initialize(self, parser):
             help="Patch size for HDIT, e.g. 4 for 4x4 patches",
         )
 
-        # parser.add_argument(
-        #     "--G_prompt",
-        #     type=str,
-        #     default="",
-        #     help="Text prompt for G",
-        # )
+        parser.add_argument(
+            "--G_prompt",
+            type=str,
+            default="",
+            help="Text prompt for G",
+        )
         parser.add_argument(
             "--G_lora_unet",
             type=int,