G_prompt for cut_turbo for dataset with single prompt #662

Open

wants to merge 7 commits into master

Conversation

@wr0124 (Collaborator) commented Jun 20, 2024

Add G_prompt for cut_turbo for unaligned datasets; it works with batch sizes larger than 1.

  • inference
  • unit tests
  • documentation

Training works with the following command line:

python3 train.py \
    --dataroot /data1/juliew/dataset/horse2zebra \
    --checkpoints_dir /data1/juliew/checkpoints \
    --name horse2zebra_turbo \
    --config_json examples/example_cut_turbo_horse2zebra.json \
    --train_batch_size 2 \
    --output_print_freq 10 \
    --data_crop_size 64 \
    --data_load_size 64 \
    --G_prompt zebra

(The --G_prompt option is mandatory if there is no prompt file in the dataset.)

Inference works with the following command line:

cd scripts
python3 gen_single_image.py \
    --model_in_file /data1/juliew/checkpoints/horse2zebra_turbo/latest_net_G_A.pth \
    --img_in /data1/juliew/dataset/horse2zebra/testA/n02381460_1000.jpg \
    --img_out /data1/juliew/target.jpg \
    --prompt zebra \
    --gpuid 0

@beniz beniz changed the title G_prompt for cut_turbo for unaligned horse2zebra dataset G_prompt for cut_turbo for dataset with single prompts Jun 21, 2024
else:
    fake_B = self.netG_A(real_A_with_z)
    self.fake_B = self.netG_A(self.real_with_z)
Contributor:

Remove self

# match batch size
captions_enc = caption_enc.repeat(x.shape[0], 1, 1)
batch_size = caption_enc.shape[0]
Contributor:

Unneeded?

Collaborator (Author):

Inside cut_model, x is created by concatenating real_A and real_B, which doubles the batch size of the x tensor. The prompt tensor has the normal batch size, so to make the two tensors match I made this modification. A detailed toy example is here: https://colab.research.google.com/drive/1RMvHt2PuQufH4zEc2Lrds561L9NEYzYf?usp=sharing
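A toy reproduction of the mismatch and of the repeat fix (shapes follow the thread: 77 tokens, hidden size 1024, a single encoded prompt; all tensors here are random stand-ins, not the actual model code):

    import torch

    batch_size = 2
    real_A = torch.randn(batch_size, 3, 64, 64)
    real_B = torch.randn(batch_size, 3, 64, 64)

    # cut_model concatenates A and B, doubling the batch dimension
    x = torch.cat([real_A, real_B], dim=0)  # (4, 3, 64, 64)

    # a single prompt is encoded once
    caption_enc = torch.randn(1, 77, 1024)

    # repeat the prompt embedding along the batch dimension to match x
    captions_enc = caption_enc.repeat(x.shape[0], 1, 1)
    assert captions_enc.shape == (x.shape[0], 77, 1024)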

Contributor:

This should be fixed by modifying the prompt tensor outside turbo, in cut when A & B are concatenated for inference, not here.
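A sketch of that suggestion, with a hypothetical helper and illustrative names (concat_with_prompt and the tensor shapes are assumptions, not the actual cut_model code):

    import torch

    def concat_with_prompt(real_A, real_B, caption_enc):
        """Hypothetical helper: concatenate A and B for the joint forward
        pass and repeat the prompt tensor to match, so the turbo generator
        itself needs no batch-size fix-up."""
        x = torch.cat([real_A, real_B], dim=0)  # (2 * B, C, H, W)
        caption_enc = caption_enc.repeat(x.shape[0] // caption_enc.shape[0], 1, 1)
        return x, caption_enc

    x, prompt = concat_with_prompt(
        torch.randn(2, 3, 64, 64),
        torch.randn(2, 3, 64, 64),
        torch.randn(1, 77, 1024),  # single encoded prompt
    )
    assert prompt.shape[0] == x.shape[0]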

"D_lr": 0.0001,
"G_ema": false,
"G_ema_beta": 0.999,
"G_lr": 0.0002,
Contributor:

Set G_lr to 0.0001.

@beniz beniz changed the title G_prompt for cut_turbo for dataset with single prompts G_prompt for cut_turbo for dataset with single prompt Jun 26, 2024
@@ -201,17 +201,14 @@ def forward(self, x, prompt):
).input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]
Contributor:

Are you sure about the [0]?

Collaborator (Author):

With refs:

1. https://huggingface.co/transformers/v4.8.0/model_doc/clip.html#flaxcliptextmodel
2. https://github.com/huggingface/transformers/blob/f91c16d270e5e3ff32fdb32ccf286d05c03dfa66/src/transformers/models/clip/modeling_clip.py#L759

outputs = self.text_encoder(caption_tokens)
type(outputs) = <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'>
len(outputs) = 2
outputs[0].shape = torch.Size([4, 77, 1024])  # last_hidden_state
outputs[1].shape = torch.Size([4, 1024])      # pooler_output

According to the explanation in ref 1, it should be outputs[0].
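A minimal standalone check of that structure (using the public openai/clip-vit-base-patch32 checkpoint as a stand-in; its hidden size is 512 rather than the 1024 shown above, but the output layout is identical):

    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    caption_tokens = tokenizer(
        ["a photo of a zebra"],
        padding="max_length",
        max_length=tokenizer.model_max_length,  # 77 for CLIP
        return_tensors="pt",
    ).input_ids

    with torch.no_grad():
        outputs = text_encoder(caption_tokens)

    # BaseModelOutputWithPooling indexes like a tuple:
    # [0] = last_hidden_state, [1] = pooler_output
    print(outputs[0].shape)  # torch.Size([1, 77, 512])
    print(outputs[1].shape)  # torch.Size([1, 512])
    assert torch.equal(outputs[0], outputs.last_hidden_state)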
