
[training] CogVideoX-I2V LoRA #9482

Merged
merged 12 commits into from
Oct 15, 2024

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Sep 20, 2024

What does this PR do?

Image-to-Video LoRA finetuning for CogVideoX

bash script
#!/bin/bash

export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
export TORCHDYNAMO_VERBOSE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export WANDB_MODE="offline"

GPU_IDS="2,3"
LEARNING_RATES=("1e-3" "1e-4")
LR_SCHEDULES=("constant" "cosine_with_restarts")
OPTIMIZERS=("adam" "adamw")
EPOCHS=("30")

for learning_rate in "${LEARNING_RATES[@]}"; do
  for lr_schedule in "${LR_SCHEDULES[@]}"; do
    for optimizer in "${OPTIMIZERS[@]}"; do
      for epochs in "${EPOCHS[@]}"; do
        cache_dir="/raid/aryan/cogvideox-lora/"
        output_dir="/raid/aryan/cogvideox-lora__optimizer_${optimizer}__epochs_${epochs}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"

        cmd="accelerate launch --gpu_ids $GPU_IDS --config_file accelerate_configs/simple_uncompiled_v2.yaml examples/cogvideo/train_cogvideox_image_to_video_lora.py \
          --pretrained_model_name_or_path /raid/aryan/CogVideoX-5b-I2V/ \
          --cache_dir $cache_dir \
          --instance_data_root /raid/aryan/dataset-cogvideox/ \
          --caption_column <CAPTION_FILENAME> \
          --video_column <VIDEOS_FILENAME> \
          --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
          --validation_images \"/raid/aryan/dataset-cogvideox/videos/frames_1_00.png:::/raid/aryan/dataset-cogvideox/videos/frames_2_00.png\" \
          --validation_prompt_separator ::: \
          --num_validation_videos 1 \
          --validation_epochs 10 \
          --seed 42 \
          --rank 64 \
          --lora_alpha 64 \
          --mixed_precision bf16 \
          --output_dir $output_dir \
          --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
          --train_batch_size 1 \
          --num_train_epochs $epochs \
          --checkpointing_steps 10000 \
          --gradient_accumulation_steps 1 \
          --learning_rate $learning_rate \
          --lr_scheduler $lr_schedule \
          --lr_warmup_steps 200 \
          --lr_num_cycles 1 \
          --enable_slicing \
          --enable_tiling \
          --gradient_checkpointing \
          --optimizer $optimizer \
          --adam_beta1 0.9 \
          --adam_beta2 0.95 \
          --max_grad_norm 1.0 \
          --report_to wandb \
          --nccl_timeout 1800"
        
        echo "Running command: $cmd"
        eval $cmd
        echo -e "-------------------- Finished executing script --------------------\n"
      done
    done
  done
done

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SHYuanBest
Contributor

When will the SFT version of CogVideoX-5B T2V be available?

@a-r-r-o-w
Member Author

When will the SFT version of CogVideoX-5B T2V be available?

We can add support for it soon. The changes to turn the CogVideoX T2V LoRA script into full SFT should actually be very simple: remove all the LoRA-related parts, make the transformer require gradients, and, instead of saving LoRA weights, save the full model with pipe.save_pretrained().
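The conversion described above can be sketched with a toy stand-in for the transformer (the Param/Model classes below are illustrative only, not the diffusers or torch API; variable names are assumptions):

```python
# Toy sketch of the LoRA -> full-SFT change. Param and Model are stand-ins
# for torch parameters and the CogVideoX transformer.

class Param:
    def __init__(self):
        self.requires_grad = False  # frozen by default, as in LoRA training

class Model:
    def __init__(self):
        self._params = [Param() for _ in range(3)]

    def requires_grad_(self, flag):
        # mirrors the behavior of torch.nn.Module.requires_grad_
        for p in self._params:
            p.requires_grad = flag

    def parameters(self):
        return self._params

transformer = Model()

# Full SFT: no LoRA adapters; make the whole transformer trainable.
transformer.requires_grad_(True)
trainable = [p for p in transformer.parameters() if p.requires_grad]
assert len(trainable) == 3

# At save time, instead of LoRA weights, the full pipeline would be saved,
# e.g. pipe.save_pretrained(output_dir) in diffusers.
```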

@963658029
Contributor

Why didn't you run the following two lines of code after calculating the loss?
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps

@a-r-r-o-w
Member Author

Gradient accumulation is handled by accelerate (you can see that we pass gradient_accumulation_steps when initializing Accelerator). We don't need to maintain total train loss because accelerator.backward(loss) implicitly maintains it for N - 1 gradient steps and performs averaging at N'th gradient accumulation step. Here's a helpful reference.
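The averaging described above can be illustrated with a small pure-Python simulation (the loss values are made up; this only mirrors the arithmetic accelerate performs, not its API):

```python
# Simulate how scaling each backward pass by 1/N over N micro-batches
# yields the gradient of the mean loss, which is what accelerate does
# internally when Accelerator(gradient_accumulation_steps=N) is used.

grad_accum_steps = 4
micro_batch_losses = [0.8, 0.6, 1.0, 0.6]  # hypothetical per-step losses

accumulated = 0.0
for loss in micro_batch_losses:
    # accelerator.backward(loss) effectively contributes loss / N
    accumulated += loss / grad_accum_steps

mean_loss = sum(micro_batch_losses) / len(micro_batch_losses)
assert abs(accumulated - mean_loss) < 1e-9  # identical up to float error
```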

@963658029
Contributor

But isn't avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() meant to gather the losses across all processes for logging (when using distributed training)? Will the losses be gathered automatically even if this line of code is not used?

@963658029
Contributor

963658029 commented Sep 24, 2024

I'm not sure, but I've seen other models' training scripts run these lines of code (for distributed training).

@a-r-r-o-w
Member Author

Hmm, I'm not too sure either what the case for distributed training should be. It seems like you might be right because I took a look at a few code bases and found this used too.

Just for a sanity check, pinging @SunMarc here. Do we need an accelerator.gather(loss) when using distributed training or is it handled internally? In the script, I only have a call to accelerator.backward(loss) and, as far as I've understood from the accelerate codebase, it looks like gradient accumulation in distributed settings is already considered

@963658029
Contributor

any update?

@lclichen

As far as I know, the gather function in torch only collects the numerical loss values (without gradients) for logging; the accumulation of gradients is done automatically in the backward function.

I am not sure if I am right; if there is a mistake, please correct me.
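A tiny pure-Python illustration of this point (gather below is a stand-in for accelerator.gather, just to show that it only changes the logged number, never the optimization):

```python
# Each rank computes its own loss and calls backward on it; gradients are
# averaged by DDP during backward. Gathering the loss values only affects
# what gets *logged*, not the gradients themselves.

per_process_losses = [1.0, 0.5]  # hypothetical losses on two ranks

def gather(values):
    # stand-in for accelerator.gather: collect scalars from all ranks
    return list(values)

# Without gather, rank 0 would log only its local loss (1.0);
# with gather, it logs the global mean across ranks.
logged_loss = sum(gather(per_process_losses)) / len(per_process_losses)
assert logged_loss == 0.75
```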

@963658029
Contributor

But if the loss logged without accelerator.gather(loss) is incorrect, will accelerator.backward(loss) still be correct?

@963658029
Contributor

any update?

@a-r-r-o-w
Member Author

We're working on a separate repository for memory-efficient and multiresolution finetuning of CogVideoX that will be open-sourced soon. This PR will probably not receive any further updates at the moment.

I think accelerator.backward(loss) is correct, and I think what @lclichen said makes the most sense to me but please correct me if I'm wrong so I can make all the necessary changes. Pinging @muellerzr, @SunMarc from the accelerate team and @sayakpaul to verify this.

@963658029
Contributor

Can the code of this PR run normally to fine-tune i2v (consistent with the official effect of CogVideoX)?

@963658029
Contributor

Setting aside multi-resolution and memory efficiency, is this code now the same as SAT?

@a-r-r-o-w
Member Author

The changes for I2V LoRA are as follows (in SAT):

  • Adding small amounts of noise to the image before VAE encode: here. This is supported in the current training script.
  • Noised image dropout (replacing the image condition with zeros): here. This is supported too.
  • Loss on the denoised video latent (the image latent is never considered in the loss, nor when adding the initial noise): here. This is handled in the script too.

Apart from these, if I'm missing anything, please let me know or feel free to open a PR with improvements. Maybe I can merge this as it is for now, and we can work on the other improvements later (it will be released as a separate repository in the near future).

Here are my training runs: wandb
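The three behaviors listed above can be sketched in a toy, pure-Python form (latents are plain lists here; the dropout probability and noise scale are illustrative values, not the ones used in SAT):

```python
import random

random.seed(0)

NOISED_IMAGE_DROPOUT_P = 0.1  # hypothetical value
NOISE_SCALE = 0.02            # hypothetical value

def prepare_image_condition(image_latent):
    # (1) add a small amount of noise to the image condition
    noised = [x + NOISE_SCALE * random.gauss(0.0, 1.0) for x in image_latent]
    # (2) noised image dropout: occasionally replace the condition with zeros
    if random.random() < NOISED_IMAGE_DROPOUT_P:
        noised = [0.0] * len(noised)
    return noised

def training_loss(pred_video_latent, target_video_latent, image_latent):
    # (3) the image latent never enters the loss; only video latents do
    return sum((p - t) ** 2
               for p, t in zip(pred_video_latent, target_video_latent))

cond = prepare_image_condition([0.5, 0.5, 0.5])
assert len(cond) == 3
assert training_loss([1.0, 2.0], [1.0, 2.0], cond) == 0.0
```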

Accelerate Config
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: 2,3
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Launch script
#!/bin/bash

export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
export TORCHDYNAMO_VERBOSE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export WANDB_MODE="offline"

GPU_IDS="2,3"
LEARNING_RATES=("1e-3" "1e-4")
LR_SCHEDULES=("constant" "cosine_with_restarts")
OPTIMIZERS=("adam" "adamw")
EPOCHS=("30")

for learning_rate in "${LEARNING_RATES[@]}"; do
  for lr_schedule in "${LR_SCHEDULES[@]}"; do
    for optimizer in "${OPTIMIZERS[@]}"; do
      for epochs in "${EPOCHS[@]}"; do
        cache_dir="/raid/aryan/cogvideox-lora/"
        output_dir="/raid/aryan/cogvideox-lora__optimizer_${optimizer}__epochs_${epochs}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
        tracker_name="cogvideox-lora__optimizer_${optimizer}__epochs_${epochs}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}"

        cmd="accelerate launch --gpu_ids $GPU_IDS --config_file accelerate_configs/simple_uncompiled_v2.yaml examples/cogvideo/train_cogvideox_image_to_video_lora.py \
          --pretrained_model_name_or_path /raid/aryan/CogVideoX-5b-I2V/ \
          --cache_dir $cache_dir \
          --instance_data_root /raid/aryan/video-dataset-disney/ \
          --caption_column prompts.txt \
          --video_column videos.txt \
          --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
          --validation_images \"/raid/aryan/dataset-cogvideox/videos/frames_1_00.png:::/raid/aryan/dataset-cogvideox/videos/frames_2_00.png\" \
          --validation_prompt_separator ::: \
          --num_validation_videos 1 \
          --validation_epochs 10 \
          --seed 42 \
          --rank 64 \
          --lora_alpha 64 \
          --mixed_precision bf16 \
          --output_dir $output_dir \
          --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
          --train_batch_size 1 \
          --num_train_epochs $epochs \
          --checkpointing_steps 10000 \
          --gradient_accumulation_steps 1 \
          --learning_rate $learning_rate \
          --lr_scheduler $lr_schedule \
          --lr_warmup_steps 200 \
          --lr_num_cycles 1 \
          --enable_slicing \
          --enable_tiling \
          --gradient_checkpointing \
          --optimizer $optimizer \
          --adam_beta1 0.9 \
          --adam_beta2 0.95 \
          --max_grad_norm 1.0 \
          --report_to wandb \
          --nccl_timeout 1800"
        
        echo "Running command: $cmd"
        eval $cmd
        echo -e "-------------------- Finished executing script --------------------\n"
      done
    done
  done
done

@Florenyci

The changes for I2V LoRA are as follows (in SAT):

  • Adding small amounts of noise to image before VAE encode: here. This is supported in current training script
  • Noised image dropout (replacing image condition with zeros): here. This is supported too
  • Loss with denoised video latent (image latent is never considered in loss, nor while adding initial noise): here. This is considered too in the script.

Apart from these, if I'm missing anything, please let me know or feel free to open a PR for improvements. Maybe I can merge this as it is for now, and we can work on the others improvements later (it will be released as a separate repository in the near future).

Here are my training runs: wandb

Accelerate Config
Launch script

@a-r-r-o-w Great work! I also want to fine-tune VAE model, is there any code pointer to do it?

@963658029
Contributor

The changes for I2V LoRA are as follows (in SAT):

  • Adding small amounts of noise to image before VAE encode: here. This is supported in current training script
  • Noised image dropout (replacing image condition with zeros): here. This is supported too
  • Loss with denoised video latent (image latent is never considered in loss, nor while adding initial noise): here. This is considered too in the script.

Apart from these, if I'm missing anything, please let me know or feel free to open a PR for improvements. Maybe I can merge this as it is for now, and we can work on the others improvements later (it will be released as a separate repository in the near future).

Here are my training runs: wandb

Accelerate Config
Launch script

I found a big problem: you missed the line image = image + image_noise in def encode_video().
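The reported fix can be sketched like this (toy code: the frame is a plain list, vae_encode is a placeholder, and noise_sigma is an illustrative value):

```python
import random

random.seed(42)

def encode_first_frame(image, noise_sigma=0.02, vae_encode=lambda x: x):
    # The bug: image_noise was computed but never added to the image.
    image_noise = [noise_sigma * random.gauss(0.0, 1.0) for _ in image]
    image = [i + n for i, n in zip(image, image_noise)]  # the missing line
    return vae_encode(image)

latent = encode_first_frame([0.5, 0.5, 0.5])
assert len(latent) == 3
```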

@a-r-r-o-w
Member Author

I found a big problem, you missed this line of code image = image + image_noise in def encode_video()

I'm extremely sorry for the time this would have cost you! I actually have this fixed in the latest repository that we will be releasing for finetuning cogvideox (hopefully this week as only some final tests are remaining), but I completely forgot to push the latest changes here 🫠 Thank you so much for reporting this though! Would you be okay if I added you as a co-author when pushing the fix here?

Great work! I also want to fine-tune VAE model, is there any code pointer to do it?

We haven't looked into it yet, but since there have been many requests (on the original CogVideo repo), we might consider providing a script for finetuning the VAE soon too. cc @zRzRzRzRzRzRzR

@963658029
Contributor

Thanks for your reply. @SHYuanBest informed me about this bug, and I would be delighted if he could also be listed as a co-author when you push the fix.

@963658029
Contributor

963658029 commented Oct 7, 2024

It seems that DeepSpeed cannot run normally:

1. use_deepspeed_scheduler = (accelerator.state.deepspeed_plugin is not None and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config) needs to be modified to
   use_deepspeed_scheduler = (accelerator.state.deepspeed_plugin is not None and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config)

2. if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: needs to be modified to
   if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if global_step % args.checkpointing_steps == 0:

We also found that the effect of the diffusers version of CogVideoX is slightly worse than that of SAT; I don't know why. Any other bugs you have fixed except these?

a-r-r-o-w and others added 3 commits October 8, 2024 09:27
Co-Authored-By: yuan-shenghai <[email protected]>
Co-Authored-By: Shenghai Yuan <[email protected]>
@a-r-r-o-w
Member Author

We also found that the effect of the diffusers version of CogVideoX is slightly worse than that of SAT; I don't know why. Any other bugs you have fixed except these?

Nope, the only bug we had fixed was the image + noisy_image part in our other repository too. Thanks for the DeepSpeed fixes! I'll try to investigate the SAT training run to see why the quality is slightly poorer in comparison

@a-r-r-o-w
Member Author

Are you sure that if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process is correct? This would mean that other processes will also perform validation/inference instead of just the main one, when using DeepSpeed? I've pushed it for now, but I can try to verify later as I'm occupied with a few other things.

@963658029
Contributor

963658029 commented Oct 8, 2024

Are you sure that if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process is correct? This would mean that other processes will also be performing validation/inference instead of the just the main, when using DeepSpeed? I've pushed it for now, but can try and verify later as I'm occupied with a few other things

DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process should be used for saving weights (line 1464), while line 1497 (performing validation/inference) only needs if accelerator.is_main_process. The code just pushed applied the change in the wrong position.
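The distinction can be written out as two separate conditions (a sketch; the DistributedType enum below mimics accelerate's, and the function names are illustrative):

```python
from enum import Enum

class DistributedType(Enum):  # stand-in for accelerate.utils.DistributedType
    DEEPSPEED = "DEEPSPEED"
    MULTI_GPU = "MULTI_GPU"

def should_save_checkpoint(distributed_type, is_main_process):
    # DeepSpeed must save on every rank; otherwise only the main process saves.
    return distributed_type == DistributedType.DEEPSPEED or is_main_process

def should_run_validation(is_main_process):
    # Validation/inference runs only on the main process in both cases.
    return is_main_process

assert should_save_checkpoint(DistributedType.DEEPSPEED, False) is True
assert should_save_checkpoint(DistributedType.MULTI_GPU, False) is False
assert should_run_validation(False) is False
```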

@a-r-r-o-w
Member Author

Oh, I see what you mean, sorry! Will fix

Co-Authored-By: yuan-shenghai <[email protected]>
@963658029
Contributor

963658029 commented Oct 8, 2024

We also found that the effect of the diffusers version of CogVideoX is slightly worse than that of SAT; I don't know why. Any other bugs you have fixed except these?

Nope, the only bug we had fixed was the image + noisy_image part in our other repository too. Thanks for the DeepSpeed fixes! I'll try to investigate the SAT training run to see why the quality is slightly poorer in comparison

I have located the problem. I strongly suspect there is a problem in the implementation of rescale_betas_zero_snr in the diffusers scheduler. It leads to the following abnormal behavior: inference with cogvideox-diffusers gradually gets worse as the number of steps increases, while inference with cogvideox-sat gradually gets better. I think this may also cause the training results to be worse than SAT's (when rescale_betas_zero_snr = True).

@SunMarc
Member

SunMarc commented Oct 8, 2024

Just for a sanity check, pinging @SunMarc here. Do we need an accelerator.gather(loss) when using distributed training or is it handled internally? In the script, I only have a call to accelerator.backward(loss) and, as far as I've understood from the accelerate codebase, it looks like gradient accumulation in distributed settings is already considered

Sorry for the wait! I see that you are going to merge this soon. You don't have to gather the loss before calling accelerator.backward(loss). As you said, gradient accumulation is already taken care of, and we don't really want to gather the loss anyway. More explanation on the DDP-specific case here.

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review October 15, 2024 00:47
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks!

@a-r-r-o-w a-r-r-o-w merged commit 2ffbb88 into main Oct 15, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the cogvideox/i2v-finetuning branch October 15, 2024 20:37
@scarbain

Hi everyone. I just trained a LoRA for I2V thanks to all your work, but now I'm having trouble using it for inference.
What should I use to add the LoRA to my I2V pipeline? I saw the README for running a LoRA with T2V, but it doesn't work for I2V; I get the error "AttributeError: 'CogVideoXImageToVideoPipeline' object has no attribute 'load_lora_weights'"

@a-r-r-o-w
Member Author

You will have to install the latest diffusers release with pip install -U diffusers, or install from the main branch, to be able to use load_lora_weights.

@Luo-Yihong

The changes for I2V LoRA are as follows (in SAT):

  • Adding small amounts of noise to image before VAE encode: here. This is supported in current training script
  • Noised image dropout (replacing image condition with zeros): here. This is supported too
  • Loss with denoised video latent (image latent is never considered in loss, nor while adding initial noise): here. This is considered too in the script.

Apart from these, if I'm missing anything, please let me know or feel free to open a PR for improvements. Maybe I can merge this as it is for now, and we can work on the others improvements later (it will be released as a separate repository in the near future).

Here are my training runs: wandb

Accelerate Config
Launch script

May I ask how much GPU memory is required to use this LoRA fine-tuning script?

@a-r-r-o-w
Member Author

You could do it in less than 24 GB for batch_size=1. More details are available here: https://github.com/a-r-r-o-w/cogvideox-factory

@SHYuanBest
Contributor

SHYuanBest commented Nov 27, 2024

@a-r-r-o-w hi, we just developed an identity-preserving text-to-video generation model, ConsisID (based on CogVideoX-5B), which can keep human identity consistent in the generated video. Can you help us integrate it into diffusers? Thanks.

https://github.com/PKU-YuanGroup/ConsisID

@a-r-r-o-w
Member Author

Of course, I would love to! I'm quite busy this week, but I will take a good look and start testing it this weekend. Thanks for the awesome work!

@SHYuanBest
Contributor

@a-r-r-o-w hi arrow, if you need any help, feel free to let me know.

@a-r-r-o-w
Member Author

@SHYuanBest Sorry, I was not able to find the time yet to try and integrate this. But I did try it out on ComfyUI and the results were very cool! If you could open a PR, we could help with reviews and try to integrate it faster. If that works, we can set up a communication channel on Slack with your team (cc @sayakpaul)

@SHYuanBest
Contributor

SHYuanBest commented Dec 6, 2024

Sure, I have created a PR here: #10140. I will update the code there as soon as possible.

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* update

* update

* update

* update

* update

* add coauthor

Co-Authored-By: yuan-shenghai <[email protected]>

* add coauthor

Co-Authored-By: Shenghai Yuan <[email protected]>

* update

Co-Authored-By: yuan-shenghai <[email protected]>

* update

---------

Co-authored-by: yuan-shenghai <[email protected]>
Co-authored-by: Shenghai Yuan <[email protected]>