From 02cbe972c3013a596c422b8cc0ca1e872f2eb647 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Dec 2024 08:51:55 +0530 Subject: [PATCH 01/88] [Tests] update always test pipelines list. (#10143) update always test pipelines list. --- utils/fetch_torch_cuda_pipeline_test_matrix.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py index e6a9c4b6a3bd..227a60bc596f 100644 --- a/utils/fetch_torch_cuda_pipeline_test_matrix.py +++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py @@ -16,12 +16,8 @@ "stable_diffusion_2", "stable_diffusion_xl", "stable_diffusion_adapter", - "deepfloyd_if", "ip_adapters", - "kandinsky", "kandinsky2_2", - "text_to_video_synthesis", - "wuerstchen", ] PIPELINE_USAGE_CUTOFF = int(os.getenv("PIPELINE_USAGE_CUTOFF", 50000)) From 3bf5400a64c847e070f332aed7d7a56d89bb22e3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Dec 2024 10:26:06 +0530 Subject: [PATCH 02/88] Update sana.md with minor corrections (#10232) --- docs/source/en/api/pipelines/sana.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md index f65faf46c2b9..1894aa55757e 100644 --- a/docs/source/en/api/pipelines/sana.md +++ b/docs/source/en/api/pipelines/sana.md @@ -26,7 +26,7 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m -This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model]https://huggingface.co/Efficient-Large-Model). +This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model). Available models: From e68092a4718a775568fae009e50162425eefbb1e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Dec 2024 12:24:14 +0530 Subject: [PATCH 03/88] [docs] minor stuff to ltx video docs. (#10229) minor stuff to ltx video docs. --- docs/source/en/api/pipelines/ltx_video.md | 24 +++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 162e1334ce9a..ac2b1c95b5b1 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -31,14 +31,18 @@ import torch from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" -transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +transformer = LTXVideoTransformer3DModel.from_single_file( + single_file_url, torch_dtype=torch.bfloat16 +) vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16) -pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) +pipe = LTXImageToVideoPipeline.from_pretrained( + "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16 +) # ... inference code ... ``` -Alternatively, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`]. +Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`]. ```python import torch @@ -46,11 +50,19 @@ from diffusers import LTXImageToVideoPipeline from transformers import T5EncoderModel, T5Tokenizer single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" -text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16) -tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16) -pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16) +text_encoder = T5EncoderModel.from_pretrained( + "Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16 +) +tokenizer = T5Tokenizer.from_pretrained( + "Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16 +) +pipe = LTXImageToVideoPipeline.from_single_file( + single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16 +) ``` +Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption. + ## LTXPipeline [[autodoc]] LTXPipeline From 8957324363d8b239d82db4909fbf8c0875683e3d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Dec 2024 12:28:36 +0530 Subject: [PATCH 04/88] Fix format issue in push_test yml (#10235) update --- .github/workflows/push_tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 055c282e7c1e..cc0cd3da0218 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -165,7 +165,8 @@ jobs: group: gcp-ct5lp-hightpu-8t container: image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults: + options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache + defaults: run: shell: bash steps: From aace1f412bc41f521b699a3228f4ec3339160c98 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 16 Dec 2024 13:56:18 +0530 Subject: [PATCH 05/88] [core] Hunyuan Video (#10136) * copy transformer * copy vae * copy pipeline * make fix-copies * refactor; make original code work with diffusers; test latents for comparison generated with this commit * move rope into pipeline; remove flash attention; refactor * begin conversion script * make style * refactor attention * refactor * refactor final layer * their mlp -> our feedforward * make style * add docs * refactor layer names * refactor modulation * cleanup * refactor norms * refactor activations * refactor single blocks attention * refactor attention processor * make style * cleanup a bit * refactor double transformer block attention * update mochi attn proc * use diffusers attention implementation in all modules; checkpoint for all values matching original * remove helper functions in vae * refactor upsample * refactor causal conv * refactor resnet * refactor * refactor * refactor * grad checkpointing * autoencoder test * fix scaling factor * refactor clip * refactor llama text encoding * add coauthor Co-Authored-By: "Gregory D. Hunkins" * refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device Note: The following line diverges from original behaviour. We create the grid on the device, whereas original implementation creates it on CPU and then moves it to device. This results in numerical differences in layerwise debugging outputs, but visually it is the same. * use diffusers timesteps embedding; diff: 0.10205078125 * rename * convert * update * add tests for transformer * add pipeline tests; text encoder 2 is not optional * fix attention implementation for torch * add example * update docs * update docs * apply suggestions from review * refactor vae * update * Apply suggestions from code review Co-authored-by: hlky * Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky * Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky * make fix-copies * update --------- Co-authored-by: "Gregory D. Hunkins" Co-authored-by: hlky --- docs/source/en/_toctree.yml | 6 + .../models/autoencoder_kl_hunyuan_video.md | 32 + .../models/hunyuan_video_transformer_3d.md | 30 + docs/source/en/api/pipelines/hunyuan_video.md | 43 + scripts/convert_hunyuan_video_to_diffusers.py | 257 ++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/activations.py | 12 + src/diffusers/models/attention.py | 4 +- src/diffusers/models/attention_processor.py | 16 +- src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_hunyuan_video.py | 1175 +++++++++++++++++ src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_hunyuan_video.py | 723 ++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/hunyuan_video/__init__.py | 48 + .../hunyuan_video/pipeline_hunyuan_video.py | 675 ++++++++++ .../hunyuan_video/pipeline_output.py | 20 + src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_autoencoder_hunyuan_video.py | 159 +++ .../test_models_transformer_hunyuan_video.py | 89 ++ tests/pipelines/hunyuan_video/__init__.py | 0 .../hunyuan_video/test_hunyuan_video.py | 331 +++++ 24 files changed, 3676 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/api/models/autoencoder_kl_hunyuan_video.md create mode 100644 docs/source/en/api/models/hunyuan_video_transformer_3d.md create mode 100644 docs/source/en/api/pipelines/hunyuan_video.md create mode 100644 scripts/convert_hunyuan_video_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video.py create mode 100644 src/diffusers/pipelines/hunyuan_video/__init__.py create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_output.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py create mode 100644 tests/models/transformers/test_models_transformer_hunyuan_video.py create mode 100644 tests/pipelines/hunyuan_video/__init__.py create mode 100644 tests/pipelines/hunyuan_video/test_hunyuan_video.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f4eb32cf63a8..d1404a1d6ea6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -270,6 +270,8 @@ title: FluxTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel + - local: api/models/hunyuan_video_transformer_3d + title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d @@ -316,6 +318,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoder_kl_hunyuan_video + title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_mochi @@ -394,6 +398,8 @@ title: Flux - local: api/pipelines/hunyuandit title: Hunyuan-DiT + - local: api/pipelines/hunyuan_video + title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/pix2pix diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md new file mode 100644 index 000000000000..f69c14814d3d --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -0,0 +1,32 @@ + + +# AutoencoderKLHunyuanVideo + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLHunyuanVideo + +vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16) +``` + +## AutoencoderKLHunyuanVideo + +[[autodoc]] AutoencoderKLHunyuanVideo + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md new file mode 100644 index 000000000000..73aea9832fc0 --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -0,0 +1,30 @@ + + +# HunyuanVideoTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import HunyuanVideoTransformer3DModel + +transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16) +``` + +## HunyuanVideoTransformer3DModel + +[[autodoc]] HunyuanVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md new file mode 100644 index 000000000000..86ef816fcd4d --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -0,0 +1,43 @@ + + +# HunyuanVideo + +[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent. + +*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +Recommendations for inference: +- Both text encoders should be in `torch.float16`. +- Transformer should be in `torch.bfloat16`. +- VAE should be in `torch.float16`. +- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. +- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. +- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/). + +## HunyuanVideoPipeline + +[[autodoc]] HunyuanVideoPipeline + - all + - __call__ + +## HunyuanVideoPipelineOutput + +[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py new file mode 100644 index 000000000000..464c9e0fb954 --- /dev/null +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -0,0 +1,257 @@ +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) + + +def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + +def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + +def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + +def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + +def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, +} + +VAE_KEYS_RENAME_DICT = {} + +VAE_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def convert_transformer(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + transformer = HunyuanVideoTransformer3DModel() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def convert_vae(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") + parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") + parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None + assert args.text_encoder_2_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + pipe = HunyuanVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 20914442b84a..dfa7a4df2d08 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLHunyuanVideo", "AutoencoderKLLTXVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", @@ -102,6 +103,7 @@ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", + "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", @@ -287,6 +289,7 @@ "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", + "HunyuanVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -590,6 +593,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -608,6 +612,7 @@ HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, + HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, @@ -772,6 +777,7 @@ HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, + HunyuanVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 687c555e0ce2..01e67b01d91a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,6 +31,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] @@ -67,6 +68,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -97,6 +99,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -130,6 +133,7 @@ DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, + HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index c1d4f0b46e15..c61baefa08f4 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -164,3 +164,15 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..6749c7f17254 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -1222,6 +1222,8 @@ def __init__( act_fn = ApproximateGELU(dim, inner_dim, bias=bias) elif activation_fn == "swiglu": act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 77e35364ab09..ee6b010519e2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -254,14 +254,22 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -782,7 +790,11 @@ def fuse_projections(self, fuse=True): self.to_kv.bias.copy_(concatenated_bias) # handle added projections for SD3 and others. - if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] ) @@ -3938,7 +3950,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # dropout hidden_states = attn.to_out[1](hidden_states) - if hasattr(attn, "to_add_out"): + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index d08e67c40975..bb750a4410f2 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -3,6 +3,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py new file mode 100644 index 000000000000..bded90a8bcff --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -0,0 +1,1175 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None +) -> torch.Tensor: + seq_len = num_frames * height_width + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // height_width + mask[i, : (i_frame + 1) * height_width] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: Tuple[float, float, float] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, **ckpt_kwargs + ) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class HunyuanVideoDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block + ) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + + up_block = HunyuanVideoUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, **ckpt_kwargs + ) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + scaling_factor: float = 0.476986, + spatial_compression_ratio: int = 8, + temporal_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ) -> None: + super().__init__() + + self.time_compression_ratio = temporal_compression_ratio + + self.encoder = HunyuanVideoEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + ) + + self.decoder = HunyuanVideoDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = True + self.use_framewise_decoding = True + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + self.tile_sample_min_num_frames = 64 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + self.tile_sample_stride_num_frames = 48 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6a13e80772e3..3a33c8070c08 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py new file mode 100644 index 000000000000..d8f9834ea61c --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -0,0 +1,723 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + # 5. Attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + # 3. Attention mask preparation + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N, N] + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6f1b842f92f2..e7fd7ec78bed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -214,6 +214,7 @@ "IFSuperResolutionPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -549,6 +550,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hunyuan_video import HunyuanVideoPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..978ed7f96110 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video import HunyuanVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 000000000000..bd3d3c1e8485 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,675 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "tencent/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer_2 (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = ( + self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None and pooled_prompt_embeds is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_latent_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 000000000000..c5cb853e3932 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0f2aad5c5000..4b6ac10385cf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8aefce9d624e..e148c025d191 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -572,6 +572,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py new file mode 100644 index 000000000000..826ac30d5f2f --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLHunyuanVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLHunyuanVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_hunyuan_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_hunyuan_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.5, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "HunyuanVideoDecoder3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoEncoder3D", + "HunyuanVideoMidBlock3D", + "HunyuanVideoUpBlock3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py new file mode 100644 index 000000000000..e8ea8cecbb9e --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -0,0 +1,89 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/hunyuan_video/__init__.py b/tests/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py new file mode 100644 index 000000000000..567002268106 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -0,0 +1,331 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideoPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=1, + num_single_layers=1, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot test with dummy prompt because tokenizers are not configured correctly. + # TODO(aryan): create dummy tokenizers and using from hub + inputs = { + "prompt": "", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 4 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass From 5fb3a985173efaae7ff381b9040c386751d643da Mon Sep 17 00:00:00 2001 From: fancy45daddy <124528204+fancy45daddy@users.noreply.github.com> Date: Mon, 16 Dec 2024 01:05:50 -0800 Subject: [PATCH 06/88] Update pipeline_controlnet.py add support for pytorch_xla (#10222) * Update pipeline_controlnet.py * make style --------- Co-authored-by: hlky --- .../pipelines/controlnet/pipeline_controlnet.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 486f9fb764d1..582f51ab480e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -31,6 +31,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1323,6 +1331,8 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: From ea893a9ae73fa3913472f1056358869fa33c46a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Dec 2024 22:20:27 +0530 Subject: [PATCH 07/88] [Docs] add rest of the lora loader mixins to the docs. (#10230) add rest of the lora loader mixins to the docs. --- docs/source/en/api/loaders/lora.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 2060a1eefd52..5dde55ada562 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -17,6 +17,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model. - [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model. - [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3). +- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux). +- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox). +- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -38,6 +41,18 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +## FluxLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin + +## CogVideoXLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin + +## Mochi1LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin + ## AmusedLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin From 672bd495733ed306ff86fe377d3f75156ece69a6 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:24:16 +0000 Subject: [PATCH 08/88] Use `t` instead of `timestep` in `_apply_perturbed_attention_guidance` (#10243) --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 081dbef21e5c..c6e7554e6b69 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -840,7 +840,7 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, guidance_scale, timestep + noise_pred, self.do_classifier_free_guidance, guidance_scale, t ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) From a7d50524ddf4454ccb5d37f2ec21a7a53bb5c1b7 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:25:21 +0000 Subject: [PATCH 09/88] Add `dynamic_shifting` to SD3 (#10236) * Add `dynamic_shifting` to SD3 * calculate_shift * FlowMatchHeunDiscreteScheduler doesn't support mu * Inpaint/img2img --- .../pipeline_stable_diffusion_3.py | 50 ++++++++++++++++--- .../pipeline_stable_diffusion_3_img2img.py | 35 ++++++++++++- .../pipeline_stable_diffusion_3_inpaint.py | 35 ++++++++++++- 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 513f86441c3a..0a51dcbc1261 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -68,6 +68,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -702,6 +716,7 @@ def __call__( skip_layer_guidance_scale: int = 2.8, skip_layer_guidance_stop: int = 0.2, skip_layer_guidance_start: int = 0.01, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -802,6 +817,7 @@ def __call__( `skip_guidance_layers` will start. The guidance will be applied to the layers specified in `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -882,12 +898,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -900,6 +911,33 @@ def __call__( latents, ) + # 5. Prepare timesteps + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + _, _, height, width = latents.shape + image_seq_len = (height // self.transformer.config.patch_size) * ( + width // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 013c31c18e34..c10401324430 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -75,6 +75,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -748,6 +762,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -832,6 +847,7 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -913,7 +929,24 @@ def __call__( image = self.image_processor.preprocess(image, height=height, width=width) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 2b6e42aa5081..ca32880d0df2 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -74,6 +74,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -838,6 +852,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -947,6 +962,7 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -1023,7 +1039,24 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: From 3f421fe09fa47512618287c0b1d306dde93ba9ec Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:27:22 +0000 Subject: [PATCH 10/88] Fix `use_flow_sigmas` (#10242) use_flow_sigmas copy --- src/diffusers/schedulers/scheduling_deis_multistep.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- src/diffusers/schedulers/scheduling_sasolver.py | 2 +- src/diffusers/schedulers/scheduling_unipc_multistep.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 17d3c25761f0..3350c3373ecf 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -287,7 +287,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 3547b3edd543..64b702bc0e32 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -412,7 +412,7 @@ def set_timesteps( elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 540f7fd84bd7..19399a724a41 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -297,7 +297,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index c300f966dbfb..bf68d6c99bd6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -361,7 +361,7 @@ def set_timesteps( elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index bef6d11973a2..41a471275fa2 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -316,7 +316,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2f6883c5da6b..c6434c6f87c6 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -379,7 +379,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) From 87e8157437be4f80e2bbbc68f177281820e6f3b4 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:29:12 +0000 Subject: [PATCH 11/88] Fix ControlNetUnion _callback_tensor_inputs (#10218) --- .../pipeline_controlnet_union_inpaint_sd_xl.py | 3 --- .../controlnet/pipeline_controlnet_union_sd_xl.py | 9 --------- .../pipeline_controlnet_union_sd_xl_img2img.py | 8 -------- 3 files changed, 20 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index bfc28615e8b4..7012f3b95458 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -205,11 +205,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( _callback_tensor_inputs = [ "latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "add_neg_time_ids", "mask", "masked_image_latents", ] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 78395243f6e4..dcd885f7d604 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -221,12 +221,8 @@ class StableDiffusionXLControlNetUnionPipeline( _callback_tensor_inputs = [ "latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "negative_add_time_ids", - "image", ] def __init__( @@ -1451,13 +1447,8 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index f36212d70755..95cf067fce12 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -244,11 +244,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( _callback_tensor_inputs = [ "latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "add_neg_time_ids", ] def __init__( @@ -1566,13 +1563,8 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From 438bd6054992061a78dd2f470064e16cf7b71abc Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:30:26 +0000 Subject: [PATCH 12/88] Use non-human subject in StableDiffusion3ControlNetPipeline example (#10214) * Use non-human subject in StableDiffusion3ControlNetPipeline example * make style --- .../pipeline_stable_diffusion_3_controlnet.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 983fff307755..1de7ba424d54 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -66,9 +66,13 @@ ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") - >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") - >>> prompt = "A girl holding a sign that says InstantX" - >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0] + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe( + ... prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7 + ... ).images[0] >>> image.save("sd3.png") ``` """ From 7186bb45f00adb36a880bd30d41cfddb12faae11 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:31:02 +0000 Subject: [PATCH 13/88] Add enable_vae_tiling to AllegroPipeline, fix example (#10212) --- .../pipelines/allegro/pipeline_allegro.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 2be596cf8eb3..b3650dc6cee1 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -59,6 +59,7 @@ >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + >>> pipe.enable_vae_tiling() >>> prompt = ( ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " @@ -636,6 +637,35 @@ def _prepare_rotary_positional_embeddings( return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + @property def guidance_scale(self): return self._guidance_scale From e9a3911b676fa0ec309999fb89fd5fd686495c42 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:31:22 +0000 Subject: [PATCH 14/88] Fix checkpoint in CogView3PlusPipeline example (#10211) --- src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py index 64fff61d2c32..8bed88c275cf 100644 --- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -38,7 +38,7 @@ >>> import torch >>> from diffusers import CogView3PlusPipeline - >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16) + >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A photo of an astronaut riding a horse on mars" From 2f023d7b84c2a62f5809c0a370ab4f37c4aaef54 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:38:13 +0000 Subject: [PATCH 15/88] Fix RePaint Scheduler (#10185) Fix repaint scheduler --- src/diffusers/schedulers/scheduling_repaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 97665bb5277b..ae953cfb966b 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -319,7 +319,7 @@ def step( prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf - prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise + prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part From 5ed761a6f2a6dad56031f4e3e32223bfbe2dda01 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 20:25:08 +0000 Subject: [PATCH 16/88] Add ControlNetUnion to AutoPipeline from_pretrained (#10219) --- src/diffusers/pipelines/auto_pipeline.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 1d6686e64271..a0f95fe6cdc1 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -18,6 +18,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin +from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline @@ -28,6 +29,9 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetUnionImg2ImgPipeline, + StableDiffusionXLControlNetUnionInpaintPipeline, + StableDiffusionXLControlNetUnionPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( @@ -108,6 +112,7 @@ ("kandinsky3", Kandinsky3Pipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline), ("wuerstchen", WuerstchenCombinedPipeline), ("cascade", StableCascadeCombinedPipeline), ("lcm", LatentConsistencyModelPipeline), @@ -139,6 +144,7 @@ ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), @@ -158,6 +164,7 @@ ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline), @@ -396,7 +403,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): orig_class_name = config["_class_name"] if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") + else: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: @@ -688,7 +698,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" if "controlnet" in kwargs: - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) + else: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: @@ -985,7 +998,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" if "controlnet" in kwargs: - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) + else: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: From aafed3f8dd042bfe786f6c3e902c5cdb5de1fb08 Mon Sep 17 00:00:00 2001 From: Kaiwen Sheng Date: Mon, 16 Dec 2024 15:25:16 -0800 Subject: [PATCH 17/88] fix downsample bug in MidResTemporalBlock1D (#10250) --- src/diffusers/models/unets/unet_1d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_1d_blocks.py b/src/diffusers/models/unets/unet_1d_blocks.py index 8fc27e94c474..f08e6070845e 100644 --- a/src/diffusers/models/unets/unet_1d_blocks.py +++ b/src/diffusers/models/unets/unet_1d_blocks.py @@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tens if self.upsample: hidden_states = self.upsample(hidden_states) if self.downsample: - self.downsample = self.downsample(hidden_states) + hidden_states = self.downsample(hidden_states) return hidden_states From 9f00c617a0bc50527c1498c36fde066f995a79dd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 05:05:40 +0530 Subject: [PATCH 18/88] [core] TorchAO Quantizer (#10009) * torchao quantizer --------- Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/quantization.md | 4 + docs/source/en/quantization/overview.md | 2 +- docs/source/en/quantization/torchao.md | 92 +++ src/diffusers/__init__.py | 4 +- src/diffusers/models/model_loading_utils.py | 6 +- src/diffusers/models/modeling_utils.py | 11 +- src/diffusers/quantizers/auto.py | 5 +- .../quantizers/quantization_config.py | 258 +++++++- src/diffusers/quantizers/torchao/__init__.py | 15 + .../quantizers/torchao/torchao_quantizer.py | 280 ++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 19 + src/diffusers/utils/testing_utils.py | 13 + tests/quantization/torchao/README.md | 53 ++ tests/quantization/torchao/test_torchao.py | 625 ++++++++++++++++++ 16 files changed, 1374 insertions(+), 16 deletions(-) create mode 100644 docs/source/en/quantization/torchao.md create mode 100644 src/diffusers/quantizers/torchao/__init__.py create mode 100644 src/diffusers/quantizers/torchao/torchao_quantizer.py create mode 100644 tests/quantization/torchao/README.md create mode 100644 tests/quantization/torchao/test_torchao.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d1404a1d6ea6..4edeb9fcb389 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -157,6 +157,8 @@ title: Getting Started - local: quantization/bitsandbytes title: bitsandbytes + - local: quantization/torchao + title: torchao title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2fbde9e707ea..18aadf3111bd 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui [[autodoc]] BitsAndBytesConfig +## TorchAoConfig + +[[autodoc]] TorchAoConfig + ## DiffusersQuantizer [[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index d8adbc85a259..151b22a607a4 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? -This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file +Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use. \ No newline at end of file diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md new file mode 100644 index 000000000000..bd5c7697a0f7 --- /dev/null +++ b/docs/source/en/quantization/torchao.md @@ -0,0 +1,92 @@ + + +# torchao + +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more. + +Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed. + +```bash +pip install -U torch torchao +``` + + +Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + +The example below only quantizes the weights to int8. + +```python +from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig + +model_id = "black-forest-labs/Flux.1-Dev" +dtype = torch.bfloat16 + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=dtype, +) +pipe = FluxPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=dtype, +) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0] +image.save("output.png") +``` + +TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code. + +```python +# In the above code, add the following after initializing the transformer +transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) +``` + +For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware. + +torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future. + +The `TorchAoConfig` class accepts three parameters: +- `quant_type`: A string value mentioning one of the quantization types below. +- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`. +- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`. + +## Supported quantization types + +torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7. + +Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation. + +Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly. + +The quantization methods supported are as follows: + +| **Category** | **Full Function Names** | **Shorthands** | +|--------------|-------------------------|----------------| +| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` | +| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` | +| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` | +| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` | + +Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations. + +Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. + +## Resources + +- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) +- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dfa7a4df2d08..fc7ada80a63b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -569,7 +569,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig + from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig try: if not is_onnx_available(): diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 751117f8f247..546c0eb4d840 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -25,7 +25,6 @@ import torch from huggingface_hub.utils import EntryNotFoundError -from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, @@ -182,7 +181,6 @@ def load_model_dict_into_meta( device = device or torch.device("cpu") dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None - is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) empty_state_dict = model.state_dict() @@ -215,12 +213,12 @@ def load_model_dict_into_meta( # bnb params are flattened. if empty_state_dict[param_name].shape != param.shape: if ( - is_quant_method_bnb + is_quantized and hf_quantizer.pre_quantized and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) ): hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) - elif not is_quant_method_bnb: + else: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4fe457706473..ce5289e3dbfd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -700,10 +700,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: - if device_map is not None: + is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" + if is_bnb_quantization_method and device_map is not None: raise NotImplementedError( - "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future." + "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." ) + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) @@ -858,13 +860,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None and not is_sharded: # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. # It would error out during the `validate_environment()` call above in the absence of cuda. - is_quant_method_bnb = ( - getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) if hf_quantizer is None: param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor - elif is_quant_method_bnb: + else: param_device = torch.device(torch.cuda.current_device()) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 97cbcdc0e53f..098308ae0bdc 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -19,17 +19,20 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod +from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig +from .torchao import TorchAoHfQuantizer AUTO_QUANTIZER_MAPPING = { "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, + "torchao": TorchAoHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, + "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index f521c5d717d6..4aeb75ab704c 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -22,15 +22,17 @@ import copy import importlib.metadata +import inspect import json import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Union +from functools import partial +from typing import Any, Dict, List, Optional, Union from packaging import version -from ..utils import is_torch_available, logging +from ..utils import is_torch_available, is_torchao_available, logging if is_torch_available(): @@ -41,6 +43,7 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" + TORCHAO = "torchao" @dataclass @@ -389,3 +392,254 @@ def to_diff_dict(self) -> Dict[str, Any]: serializable_config_dict[key] = value return serializable_config_dict + + +@dataclass +class TorchAoConfig(QuantizationConfigMixin): + """This is a config class for torchao quantization/sparsity techniques. + + Args: + quant_type (`str`): + The type of quantization we want to use, currently supporting: + - **Integer quantization:** + - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, + `int8_weight_only`, `int8_dynamic_activation_int8_weight` + - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` + + - **Floating point 8-bit quantization:** + - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, + `float8_static_activation_float8_weight` + - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, + `float8_e4m3_tensor`, `float8_e4m3_row`, + + - **Floating point X-bit quantization:** + - Full function names: `fpx_weight_only` + - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number + of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must + be satisfied for a given shorthand notation. + + - **Unsigned Integer quantization:** + - Full function names: `uintx_weight_only` + - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` + modules_to_not_convert (`List[str]`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have some + modules left in their original precision. + kwargs (`Dict[str, Any]`, *optional*): + The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization + supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and + documentation of arguments can be found in + https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + + Example: + ```python + from diffusers import FluxTransformer2DModel, TorchAoConfig + + quantization_config = TorchAoConfig("int8wo") + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + ``` + """ + + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None: + self.quant_method = QuantizationMethod.TORCHAO + self.quant_type = quant_type + self.modules_to_not_convert = modules_to_not_convert + + # When we load from serialized config, "quant_type_kwargs" will be the key + if "quant_type_kwargs" in kwargs: + self.quant_type_kwargs = kwargs["quant_type_kwargs"] + else: + self.quant_type_kwargs = kwargs + + TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() + if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): + raise ValueError( + f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the " + f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] + signature = inspect.signature(method) + all_kwargs = { + param.name + for param in signature.parameters.values() + if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] + } + unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) + + if len(unsupported_kwargs) > 0: + raise ValueError( + f'The quantization method "{quant_type}" does not support the following keyword arguments: ' + f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." + ) + + @classmethod + def _get_torchao_quant_type_to_method(cls): + r""" + Returns supported torchao quantization types with all commonly used notations. + """ + + if is_torchao_available(): + # TODO(aryan): Support autoquant and sparsify + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + fpx_weight_only, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + uintx_weight_only, + ) + + # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers + from torchao.quantization.observer import PerRow, PerTensor + + def generate_float8dq_types(dtype: torch.dtype): + name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" + types = {} + + for granularity_cls in [PerTensor, PerRow]: + # Note: Activation and Weights cannot have different granularities + granularity_name = "tensor" if granularity_cls is PerTensor else "row" + types[f"float8dq_{name}_{granularity_name}"] = partial( + float8_dynamic_activation_float8_weight, + activation_dtype=dtype, + weight_dtype=dtype, + granularity=(granularity_cls(), granularity_cls()), + ) + + return types + + def generate_fpx_quantization_types(bits: int): + types = {} + + for ebits in range(1, bits): + mbits = bits - ebits - 1 + types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) + + non_sign_bits = bits - 1 + default_ebits = (non_sign_bits + 1) // 2 + default_mbits = non_sign_bits - default_ebits + types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) + + return types + + INT4_QUANTIZATION_TYPES = { + # int4 weight + bfloat16/float16 activation + "int4wo": int4_weight_only, + "int4_weight_only": int4_weight_only, + # int4 weight + int8 activation + "int4dq": int8_dynamic_activation_int4_weight, + "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, + } + + INT8_QUANTIZATION_TYPES = { + # int8 weight + bfloat16/float16 activation + "int8wo": int8_weight_only, + "int8_weight_only": int8_weight_only, + # int8 weight + int8 activation + "int8dq": int8_dynamic_activation_int8_weight, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + } + + # TODO(aryan): handle torch 2.2/2.3 + FLOATX_QUANTIZATION_TYPES = { + # float8_e5m2 weight + bfloat16/float16 activation + "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float8_weight_only": float8_weight_only, + "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + # float8_e4m3 weight + bfloat16/float16 activation + "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), + # float8_e5m2 weight + float8 activation (dynamic) + "float8dq": float8_dynamic_activation_float8_weight, + "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, + # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out. + # However, changing activation_dtype=torch.float8_e4m3 might work here ===== + # "float8dq_e5m2": partial( + # float8_dynamic_activation_float8_weight, + # activation_dtype=torch.float8_e5m2, + # weight_dtype=torch.float8_e5m2, + # ), + # **generate_float8dq_types(torch.float8_e5m2), + # ===== ===== + # float8_e4m3 weight + float8 activation (dynamic) + "float8dq_e4m3": partial( + float8_dynamic_activation_float8_weight, + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + ), + **generate_float8dq_types(torch.float8_e4m3fn), + # float8 weight + float8 activation (static) + "float8_static_activation_float8_weight": float8_static_activation_float8_weight, + # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly + # fpx weight + bfloat16/float16 activation + **generate_fpx_quantization_types(3), + **generate_fpx_quantization_types(4), + **generate_fpx_quantization_types(5), + **generate_fpx_quantization_types(6), + **generate_fpx_quantization_types(7), + } + + UINTX_QUANTIZATION_DTYPES = { + "uintx_weight_only": uintx_weight_only, + "uint1wo": partial(uintx_weight_only, dtype=torch.uint1), + "uint2wo": partial(uintx_weight_only, dtype=torch.uint2), + "uint3wo": partial(uintx_weight_only, dtype=torch.uint3), + "uint4wo": partial(uintx_weight_only, dtype=torch.uint4), + "uint5wo": partial(uintx_weight_only, dtype=torch.uint5), + "uint6wo": partial(uintx_weight_only, dtype=torch.uint6), + "uint7wo": partial(uintx_weight_only, dtype=torch.uint7), + # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported + } + + QUANTIZATION_TYPES = {} + QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) + + if cls._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) + + return QUANTIZATION_TYPES + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + + @staticmethod + def _is_cuda_capability_atleast_8_9() -> bool: + if not torch.cuda.is_available(): + raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.") + + major, minor = torch.cuda.get_device_capability() + if major == 8: + return minor >= 9 + return major >= 9 + + def get_apply_tensor_subclass(self): + TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() + return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs) + + def __repr__(self): + r""" + Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`: + + ``` + TorchAoConfig { + "modules_to_not_convert": null, + "quant_method": "torchao", + "quant_type": "uint_a16w4", + "quant_type_kwargs": { + "group_size": 32 + } + } + ``` + """ + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" diff --git a/src/diffusers/quantizers/torchao/__init__.py b/src/diffusers/quantizers/torchao/__init__.py new file mode 100644 index 000000000000..09e6a19d4df0 --- /dev/null +++ b/src/diffusers/quantizers/torchao/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .torchao_quantizer import TorchAoHfQuantizer diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py new file mode 100644 index 000000000000..8b28a403e6f0 --- /dev/null +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -0,0 +1,280 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py +""" + +import importlib +import types +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from packaging import version + +from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) + +if is_torchao_available(): + from torchao.quantization import quantize_ + + +logger = logging.get_logger(__name__) + + +def _quantization_type(weight): + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + +def _linear_extra_repr(self): + weight = _quantization_type(self.weight) + if weight is None: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" + else: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}" + + +class TorchAoHfQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/. + """ + + requires_calibration = False + required_packages = ["torchao"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_torchao_available(): + raise ImportError( + "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`" + ) + + self.offload = False + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized torchao model " + "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." + ) + else: + self.offload = True + + if self.pre_quantized: + weights_only = kwargs.get("weights_only", None) + if weights_only: + torch_version = version.parse(importlib.metadata.version("torch")) + if torch_version < version.parse("2.5.0"): + # TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future + raise RuntimeError( + f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}." + ) + + def update_torch_dtype(self, torch_dtype): + quant_type = self.quantization_config.quant_type + + if quant_type.startswith("int"): + if torch_dtype is not None and torch_dtype != torch.bfloat16: + logger.warning( + f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " + f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." + ) + + if torch_dtype is None: + # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op + logger.warning( + "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " + "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " + "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." + ) + torch_dtype = torch.bfloat16 + + return torch_dtype + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + quant_type = self.quantization_config.quant_type + + if quant_type.startswith("int8") or quant_type.startswith("int4"): + # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 + return torch.int8 + elif quant_type == "uintx_weight_only": + return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) + elif quant_type.startswith("uint"): + return { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + }[int(quant_type[4])] + elif quant_type.startswith("float") or quant_type.startswith("fp"): + return torch.bfloat16 + + if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): + return target_dtype + + # We need one of the supported dtypes to be selected in order for accelerate to determine + # the total size of modules/parameters for auto device placement. + possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"] + raise ValueError( + f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype " + f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the " + f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.9 for key, val in max_memory.items()} + return max_memory + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + param_device = kwargs.pop("param_device", None) + # Check if the param_name is not in self.modules_to_not_convert + if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): + return False + elif param_device == "cpu" and self.offload: + # We don't quantize weights that we offload + return False + else: + # We only quantize the weight of nn.Linear + module, tensor_name = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: List[str], + ): + r""" + Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, + then we move it to the target device. Finally, we quantize the module. + """ + module, tensor_name = get_module_from_name(model, param_name) + + if self.pre_quantized: + # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info + # about AffineQuantizedTensor + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + else: + # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin"): + return model + + def is_serializable(self, safe_serialization=None): + # TODO(aryan): needs to be tested + if safe_serialization: + logger.warning( + "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + ) + return False + + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( + "0.25.0" + ) + + if not _is_torchao_serializable: + logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") + + if self.offload and self.quantization_config.modules_to_not_convert is None: + logger.warning( + "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them." + "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config." + ) + return False + + return _is_torchao_serializable + + @property + def is_trainable(self): + return self.quantization_config.quant_type.startswith("int8") diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f91cee8113f2..9860ac849834 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -87,6 +87,7 @@ is_torch_version, is_torch_xla_available, is_torch_xla_version, + is_torchao_available, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e3b7655737a8..f325f36bddd3 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -340,6 +340,15 @@ def is_timm_available(): _imageio_available = False +_is_torchao_available = importlib.util.find_spec("torchao") is not None +if _is_torchao_available: + try: + _torchao_version = importlib_metadata.version("torchao") + logger.debug(f"Successfully import torchao version {_torchao_version}") + except importlib_metadata.PackageNotFoundError: + _is_torchao_available = False + + def is_torch_available(): return _torch_available @@ -460,6 +469,10 @@ def is_imageio_available(): return _imageio_available +def is_torchao_available(): + return _is_torchao_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -593,6 +606,11 @@ def is_imageio_available(): {0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg` """ +# docstyle-ignore +TORCHAO_IMPORT_ERROR = """ +{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao` +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -618,6 +636,7 @@ def is_imageio_available(): ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index b3e381f7d3fb..b4d3415de50e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -39,6 +39,7 @@ is_timm_available, is_torch_available, is_torch_version, + is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -476,6 +477,18 @@ def decorator(test_case): return decorator +def require_torchao_version_greater(torchao_version): + def decorator(test_case): + correct_torchao_version = is_torchao_available() and version.parse( + version.parse(importlib.metadata.version("torchao")).base_version + ) > version.parse(torchao_version) + return unittest.skipUnless( + correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}." + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md new file mode 100644 index 000000000000..fadc529e12fc --- /dev/null +++ b/tests/quantization/torchao/README.md @@ -0,0 +1,53 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/). + +The benchmarks were run on a single H100. Below is `nvidia-smi`: + +```bash ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | +| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| +| No running processes found | ++---------------------------------------------------------------------------------------+ +``` + +The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR. + +The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent: + +```bash +HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests +``` + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.32.0.dev0 +- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31 +- Running on Google Colab?: No +- Python version: 3.10.14 +- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.26.2 +- Transformers version: 4.46.3 +- Accelerate version: 1.1.1 +- PEFT version: not installed +- Bitsandbytes version: not installed +- Safetensors version: 0.4.5 +- xFormers version: not installed +``` diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py new file mode 100644 index 000000000000..5c71fc4e0ae7 --- /dev/null +++ b/tests/quantization/torchao/test_torchao.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest +from typing import List + +import numpy as np +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, + TorchAoConfig, +) +from diffusers.models.attention_processor import Attention +from diffusers.utils.testing_utils import ( + enable_full_determinism, + is_torch_available, + is_torchao_available, + nightly, + require_torch, + require_torch_gpu, + require_torchao_version_greater, + slow, + torch_device, +) + + +enable_full_determinism() + + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +if is_torchao_available(): + from torchao.dtypes import AffineQuantizedTensor + from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = TorchAoConfig("int4_weight_only") + torchao_orig_config = quantization_config.to_dict() + + for key in torchao_orig_config: + self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) + + def test_post_init_check(self): + """ + Test kwargs validations in TorchAoConfig + """ + _ = TorchAoConfig("int4_weight_only") + with self.assertRaisesRegex(ValueError, "is not supported yet"): + _ = TorchAoConfig("uint8") + + with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): + _ = TorchAoConfig("int4_weight_only", group_size1=32) + + def test_repr(self): + """ + Check that there is no error in the repr + """ + quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) + expected_repr = """TorchAoConfig { + "modules_to_not_convert": [ + "conv" + ], + "quant_method": "torchao", + "quant_type": "int4_weight_only", + "quant_type_kwargs": { + "group_size": 8 + } + }""".replace(" ", "").replace("\n", "") + quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") + self.assertEqual(quantization_repr, expected_repr) + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoTest(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + model_id = "hf-internal-testing/tiny-flux-pipe" + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 32, + "width": 32, + "num_inference_steps": 2, + "output_type": "np", + "generator": generator, + } + + return inputs + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + + torch.manual_seed(seed) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + output_slice = output[-1, -1, -3:, -3:].flatten() + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint_a16w7"]: + # The dummy flux model that we use requires us to impose some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) + self._test_quant_type(quantization_config, expected_slice) + + def test_int4wo_quant_bfloat16_conversion(self): + """ + Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. + """ + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + weight = quantized_model.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertEqual(weight.quant_min, 0) + self.assertEqual(weight.quant_max, 15) + self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + + def test_offload(self): + """ + Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies + that the device map is correctly set (in the `hf_device_map` attribute of the model). + """ + + device_map_offload = { + "time_text_embed": torch_device, + "context_embedder": torch_device, + "x_embedder": torch_device, + "transformer_blocks.0": "cpu", + "single_transformer_blocks.0": "disk", + "norm_out": torch_device, + "proj_out": "cpu", + } + + inputs = self.get_dummy_tensor_inputs(torch_device) + + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map_offload, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + self.assertTrue(quantized_model.hf_device_map == device_map_offload) + + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_modules_to_not_convert(self): + quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2] + self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) + self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) + self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) + + quantized_layer = quantized_model.proj_out + self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) + self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) + + def test_training(self): + quantization_config = TorchAoConfig("int8_weight_only") + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + @nightly + def test_torch_compile(self): + r"""Test that verifies if torch.compile works with torchao quantization.""" + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] + + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] + + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + + @staticmethod + def _get_memory_footprint(module): + quantized_param_memory = 0.0 + unquantized_param_memory = 0.0 + + for param in module.parameters(): + if param.__class__.__name__ == "AffineQuantizedTensor": + data, scale, zero_point = param.layout_tensor.get_plain() + quantized_param_memory += data.numel() + data.element_size() + quantized_param_memory += scale.numel() + scale.element_size() + quantized_param_memory += zero_point.numel() + zero_point.element_size() + else: + unquantized_param_memory += param.data.numel() * param.data.element_size() + + total_memory = quantized_param_memory + unquantized_param_memory + return total_memory, quantized_param_memory, unquantized_param_memory + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] + transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] + transformer_bf16 = self.get_dummy_components(None)["transformer"] + + total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo) + total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint( + transformer_int4wo_gs32 + ) + total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo) + total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16) + + self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16) + # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32) + # int4 with default group size quantized very few linear layers compared to a smaller group size of 32 + self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32) + # int8 quantizes more layers compare to int4 with default group size + self.assertTrue(quantized_int8wo < quantized_int4wo) + + def test_wrong_config(self): + with self.assertRaises(ValueError): + self.get_dummy_components(TorchAoConfig("int42")) + + +# This class is not to be run as a test by itself. See the tests that follow this class +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoSerializationTest(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + quant_method, quant_method_kwargs = None, None + device = "cuda" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_model(self, device=None): + quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs) + quantized_model = FluxTransformer2DModel.from_pretrained( + self.model_name, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + return quantized_model.to(device) + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def test_original_model_expected_slice(self): + quantized_model = self.get_dummy_model(torch_device) + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3)) + + def check_serialization_expected_slice(self, expected_slice): + quantized_model = self.get_dummy_model(self.device) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantized_model.save_pretrained(tmp_dir, safe_serialization=False) + loaded_quantized_model = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False + ) + + inputs = self.get_dummy_tensor_inputs(torch_device) + output = loaded_quantized_model(**inputs)[0] + + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue( + isinstance( + loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor) + ) + ) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_serialization_expected_slice(self): + self.check_serialization_expected_slice(self.serialized_expected_slice) + + +class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + serialized_expected_slice = expected_slice + device = "cuda" + + +class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + serialized_expected_slice = expected_slice + device = "cuda" + + +class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + serialized_expected_slice = expected_slice + device = "cpu" + + +class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + serialized_expected_slice = expected_slice + device = "cpu" + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +@slow +@nightly +class SlowTorchAoTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + model_id = "black-forest-labs/FLUX.1-dev" + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def _test_quant_type(self, quantization_config, expected_slice): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components).to(dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), + ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), + ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) + self._test_quant_type(quantization_config, expected_slice) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() From 7667cfcb41dfeb8f217e4314dcf2d561b8ca41d2 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:36:26 -0800 Subject: [PATCH 19/88] [docs] Add missing AttnProcessors (#10246) * attnprocessors * lora * make style * fix * fix * sana * typo --- docs/source/en/api/attnprocessor.md | 115 ++++++++++++++++++-- src/diffusers/models/attention_processor.py | 16 +++ 2 files changed, 120 insertions(+), 11 deletions(-) diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index 5b1f0be72ae6..fee0d7e35764 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -15,40 +15,133 @@ specific language governing permissions and limitations under the License. An attention processor is a class for applying different types of attention mechanisms. ## AttnProcessor + [[autodoc]] models.attention_processor.AttnProcessor -## AttnProcessor2_0 [[autodoc]] models.attention_processor.AttnProcessor2_0 -## AttnAddedKVProcessor [[autodoc]] models.attention_processor.AttnAddedKVProcessor -## AttnAddedKVProcessor2_0 [[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0 +[[autodoc]] models.attention_processor.AttnProcessorNPU + +[[autodoc]] models.attention_processor.FusedAttnProcessor2_0 + +## Allegro + +[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0 + +## AuraFlow + +[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0 + +## CogVideoX + +[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0 + ## CrossFrameAttnProcessor + [[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor -## CustomDiffusionAttnProcessor +## Custom Diffusion + [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor -## CustomDiffusionAttnProcessor2_0 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0 -## CustomDiffusionXFormersAttnProcessor [[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor -## FusedAttnProcessor2_0 -[[autodoc]] models.attention_processor.FusedAttnProcessor2_0 +## Flux + +[[autodoc]] models.attention_processor.FluxAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0 + +## Hunyuan + +[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0 + +## IdentitySelfAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0 + +## IP-Adapter + +[[autodoc]] models.attention_processor.IPAdapterAttnProcessor + +[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0 + +## JointAttnProcessor2_0 + +[[autodoc]] models.attention_processor.JointAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0 + +## LoRA + +[[autodoc]] models.attention_processor.LoRAAttnProcessor + +[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0 + +[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor + +[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor + +## Lumina-T2X + +[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0 + +## Mochi + +[[autodoc]] models.attention_processor.MochiAttnProcessor2_0 + +[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0 + +## Sana + +[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0 + +[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0 + +[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0 + +## Stable Audio + +[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0 ## SlicedAttnProcessor + [[autodoc]] models.attention_processor.SlicedAttnProcessor -## SlicedAttnAddedKVProcessor [[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor ## XFormersAttnProcessor + [[autodoc]] models.attention_processor.XFormersAttnProcessor -## AttnProcessorNPU -[[autodoc]] models.attention_processor.AttnProcessorNPU +[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor + +## XLAFlashAttnProcessor2_0 + +[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0 diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ee6b010519e2..be8d654ca66a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5423,21 +5423,37 @@ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Ten class LoRAAttnProcessor: + r""" + Processor for implementing attention with LoRA. + """ + def __init__(self): pass class LoRAAttnProcessor2_0: + r""" + Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0). + """ + def __init__(self): pass class LoRAXFormersAttnProcessor: + r""" + Processor for implementing attention with LoRA using xFormers. + """ + def __init__(self): pass class LoRAAttnAddedKVProcessor: + r""" + Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder. + """ + def __init__(self): pass From 6fb94d51cb8757aa00a62f9827b5b55e2856b2d3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 17 Dec 2024 09:17:40 +0530 Subject: [PATCH 20/88] [chore] add contribution note for lawrence. (#10253) add contribution note for lawrence. --- docs/source/en/api/models/autoencoder_dc.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/api/models/autoencoder_dc.md b/docs/source/en/api/models/autoencoder_dc.md index 667f0de678f6..6f86150eb744 100644 --- a/docs/source/en/api/models/autoencoder_dc.md +++ b/docs/source/en/api/models/autoencoder_dc.md @@ -29,6 +29,8 @@ The following DCAE models are released and supported in Diffusers. | [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) | [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) +This model was contributed by [lawrence-cj](https://github.com/lawrence-cj). + Load a model in Diffusers format with [`~ModelMixin.from_pretrained`]. ```python From 0d96a894a766198ef2b2d5266e646dd958081cc0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 11:09:57 +0530 Subject: [PATCH 21/88] Fix copied from comment in Mochi lora loader (#10255) update --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 01040b06927b..b3dd200568e2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3104,7 +3104,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): @@ -3116,7 +3116,7 @@ def load_lora_into_transformer( A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - transformer (`CogVideoXTransformer3DModel`): + transformer (`MochiTransformer3DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use From ac863934870556505f6035127ed39466e57b6002 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 12:05:05 +0530 Subject: [PATCH 22/88] [LoRA] Support LTX Video (#10228) * add lora support for ltx * add tests * fix copied from comments * update --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 308 ++++++++++++++++++ src/diffusers/loaders/peft.py | 1 + .../models/transformers/transformer_ltx.py | 26 +- src/diffusers/pipelines/ltx/pipeline_ltx.py | 17 +- .../pipelines/ltx/pipeline_ltx_image2video.py | 17 +- tests/lora/test_lora_layers_ltx_video.py | 181 ++++++++++ 7 files changed, 543 insertions(+), 9 deletions(-) create mode 100644 tests/lora/test_lora_layers_ltx_video.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 007d3c95597a..d59830e614e9 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder): "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", + "LTXVideoLoraLoaderMixin", "LoraLoaderMixin", "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", @@ -89,6 +90,7 @@ def text_encoder_attn_modules(text_encoder): CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, LoraLoaderMixin, + LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b3dd200568e2..869a5cca24f5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3254,6 +3254,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) +class LTXVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`LTXVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3851ff32ddfa..3dddb94f30c1 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -53,6 +53,7 @@ "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, + "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 8aa3a1590fb9..2ed8520a5d75 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -21,8 +21,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin -from ...utils import is_torch_version, logging +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention @@ -267,7 +267,7 @@ def forward( @maybe_allow_in_graph -class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). @@ -374,8 +374,24 @@ def forward( height: int, width: int, rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) # convert encoder_attention_mask to a bias the same way we do for attention_mask @@ -436,6 +452,10 @@ def custom_forward(*inputs): hidden_states = hidden_states * (1 + scale) + shift output = self.proj_out(hidden_states) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 72b95fea1ce1..f88fcd3e7988 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -13,14 +13,14 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import FromSingleFileMixin +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTXVideo from ...models.transformers import LTXVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -140,7 +140,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LTXPipeline(DiffusionPipeline, FromSingleFileMixin): +class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -484,6 +484,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -510,6 +514,7 @@ def __call__( negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 128, @@ -564,6 +569,10 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -600,6 +609,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Define call parameters @@ -701,6 +711,7 @@ def __call__( height=latent_height, width=latent_width, rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 25ed635a3d17..5b36e993b012 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch @@ -21,7 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput -from ...loaders import FromSingleFileMixin +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTXVideo from ...models.transformers import LTXVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -159,7 +159,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin): +class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): r""" Pipeline for image-to-video generation. @@ -543,6 +543,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -570,6 +574,7 @@ def __call__( negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 128, @@ -626,6 +631,10 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -662,6 +671,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Define call parameters @@ -772,6 +782,7 @@ def __call__( height=latent_height, width=latent_width, rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py new file mode 100644 index 000000000000..c9c877b202ef --- /dev/null +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -0,0 +1,181 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + is_torch_version, + require_peft_backend, + skip_mps, + torch_device, +) + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = LTXPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 8, + "out_channels": 8, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 4, + "attention_head_dim": 8, + "cross_attention_dim": 32, + "num_layers": 1, + "caption_channels": 32, + } + transformer_cls = LTXVideoTransformer3DModel + vae_kwargs = { + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "spatio_temporal_scaling": (True, True, False, False), + "layers_per_block": (1, 1, 1, 1, 1), + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + vae_cls = AutoencoderKLLTXVideo + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + text_encoder_target_modules = ["q", "k", "v", "o"] + + @property + def output_shape(self): + return (1, 9, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 8 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + latent_height = 8 + latent_width = 8 + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width)) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "dance monkey", + "num_frames": num_frames, + "num_inference_steps": 4, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @skip_mps + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) + def test_lora_fuse_nan(self): + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + + out = pipe( + "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np" + )[0] + + self.assertTrue(np.isnan(out).all()) + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + @unittest.skip("Not supported in LTXVideo.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in LTXVideo.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in LTXVideo.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + def test_simple_inference_with_text_lora_save_load(self): + pass From f9d5a9324d77169d486a60f3b4b267c74149b982 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 13:43:24 +0530 Subject: [PATCH 23/88] [docs] Clarify dtypes for Sana (#10248) update Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/sana.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md index 1894aa55757e..64acb44962e6 100644 --- a/docs/source/en/api/pipelines/sana.md +++ b/docs/source/en/api/pipelines/sana.md @@ -42,6 +42,8 @@ Available models: Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information. +Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. + Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained). From e24941b2a71cc1e163ffda1731be22bcfcc70c60 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Dec 2024 16:09:37 +0530 Subject: [PATCH 24/88] [Single File] Add GGUF support (#9964) * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * Update src/diffusers/quantizers/gguf/utils.py Co-authored-by: Sayak Paul * update * update * update * update * update * update * update * update * update * update * Update docs/source/en/quantization/gguf.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update * update * update * update --------- Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .github/workflows/nightly_tests.yml | 2 + docs/source/en/_toctree.yml | 2 + docs/source/en/api/quantization.md | 3 + docs/source/en/quantization/gguf.md | 70 +++ docs/source/en/quantization/overview.md | 9 +- src/diffusers/__init__.py | 4 +- src/diffusers/loaders/single_file_model.py | 46 +- src/diffusers/loaders/single_file_utils.py | 25 +- src/diffusers/models/model_loading_utils.py | 84 +++- src/diffusers/models/modeling_utils.py | 8 +- .../models/transformers/transformer_flux.py | 1 - src/diffusers/quantizers/auto.py | 12 +- .../quantizers/bitsandbytes/bnb_quantizer.py | 5 +- src/diffusers/quantizers/gguf/__init__.py | 1 + .../quantizers/gguf/gguf_quantizer.py | 159 ++++++ src/diffusers/quantizers/gguf/utils.py | 456 ++++++++++++++++++ .../quantizers/quantization_config.py | 24 + src/diffusers/utils/__init__.py | 3 + src/diffusers/utils/constants.py | 1 + src/diffusers/utils/import_utils.py | 35 +- src/diffusers/utils/testing_utils.py | 13 + tests/quantization/gguf/test_gguf.py | 379 +++++++++++++++ 22 files changed, 1321 insertions(+), 21 deletions(-) create mode 100644 docs/source/en/quantization/gguf.md create mode 100644 src/diffusers/quantizers/gguf/__init__.py create mode 100644 src/diffusers/quantizers/gguf/gguf_quantizer.py create mode 100644 src/diffusers/quantizers/gguf/utils.py create mode 100644 tests/quantization/gguf/test_gguf.py diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index b8fbf8f54362..cc0abac6e4ab 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -357,6 +357,8 @@ jobs: config: - backend: "bitsandbytes" test_location: "bnb" + - backend: "gguf" + test_location: "gguf" runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4edeb9fcb389..ab733054fbd3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -157,6 +157,8 @@ title: Getting Started - local: quantization/bitsandbytes title: bitsandbytes + - local: quantization/gguf + title: gguf - local: quantization/torchao title: torchao title: Quantization Methods diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 18aadf3111bd..168a9a03473f 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui [[autodoc]] BitsAndBytesConfig +## GGUFQuantizationConfig + +[[autodoc]] GGUFQuantizationConfig ## TorchAoConfig [[autodoc]] TorchAoConfig diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md new file mode 100644 index 000000000000..dbcd1b1486b2 --- /dev/null +++ b/docs/source/en/quantization/gguf.md @@ -0,0 +1,70 @@ + + +# GGUF + +The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported. + +The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant. + +Before starting please install gguf in your environment + +```shell +pip install -U gguf +``` + +Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`]. + +When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.unint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`. + +The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original (`numpy`)[https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py] implementation by [compilade](https://github.com/compilade). + +```python +import torch + +from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig + +ckpt_path = ( + "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" +) +transformer = FluxTransformer2DModel.from_single_file( + ckpt_path, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16, +) +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + transformer=transformer, + generator=torch.manual_seed(0), + torch_dtype=torch.bfloat16, +) +pipe.enable_model_cpu_offload() +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt).images[0] +image.save("flux-gguf.png") +``` + +## Supported Quantization Types + +- BF16 +- Q4_0 +- Q4_1 +- Q5_0 +- Q5_1 +- Q8_0 +- Q2_K +- Q3_K +- Q4_K +- Q5_K +- Q6_K + diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 151b22a607a4..6c2df7514d5e 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a -Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. +Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. @@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? -Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use. \ No newline at end of file +Diffusers currently supports the following quantization methods. +- [BitsandBytes]() +- [TorchAO]() +- [GGUF]() + +[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fc7ada80a63b..e2351a0c53b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -569,7 +569,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig + from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig try: if not is_onnx_available(): diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 78ce47273d8f..9641435fa5a6 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -17,8 +17,10 @@ from contextlib import nullcontext from typing import Optional +import torch from huggingface_hub.utils import validate_hf_hub_args +from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, logging from .single_file_utils import ( SingleFileComponentError, @@ -214,6 +216,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder = kwargs.pop("subfolder", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + quantization_config = kwargs.pop("quantization_config", None) + device = kwargs.pop("device", None) if isinstance(pretrained_model_link_or_path_or_dict, dict): checkpoint = pretrained_model_link_or_path_or_dict @@ -227,6 +231,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = local_files_only=local_files_only, revision=revision, ) + if quantization_config is not None: + hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) + hf_quantizer.validate_environment() + + else: + hf_quantizer = None mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] @@ -309,8 +319,36 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = with ctx(): model = cls.from_config(diffusers_model_config) + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + else: + keep_in_fp32_modules = [] + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, + device_map=None, + state_dict=diffusers_format_checkpoint, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + param_device = torch.device(device) if device else torch.device("cpu") + unexpected_keys = load_model_dict_into_meta( + model, + diffusers_format_checkpoint, + dtype=torch_dtype, + device=param_device, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -324,7 +362,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) - if torch_dtype is not None: + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and hf_quantizer is None: model.to(torch_dtype) model.eval() diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 21ff2841700d..4e288737fe88 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -81,8 +81,14 @@ "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", "stable_cascade_stage_c": "clip_txt_mapper.weight", - "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", - "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", + "sd3": [ + "joint_blocks.0.context_block.adaLN_modulation.1.bias", + "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", + ], + "sd35_large": [ + "joint_blocks.37.x_block.mlp.fc1.weight", + "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", + ], "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", @@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint): ): model_type = "stable_cascade_stage_b" - elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216: - if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864: + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any( + checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"] + ): + if "model.diffusion_model.pos_embed" in checkpoint: + key = "model.diffusion_model.pos_embed" + else: + key = "pos_embed" + + if checkpoint[key].shape[1] == 36864: model_type = "sd3" - elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456: + elif checkpoint[key].shape[1] == 147456: model_type = "sd35_medium" - elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint: + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]): model_type = "sd35_large" elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 546c0eb4d840..af1a1a5250ff 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -17,6 +17,7 @@ import importlib import inspect import os +from array import array from collections import OrderedDict from pathlib import Path from typing import List, Optional, Union @@ -26,6 +27,7 @@ from huggingface_hub.utils import EntryNotFoundError from ..utils import ( + GGUF_FILE_EXTENSION, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, WEIGHTS_INDEX_NAME, @@ -33,6 +35,8 @@ _get_model_file, deprecate, is_accelerate_available, + is_gguf_available, + is_torch_available, is_torch_version, logging, ) @@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: return safetensors.torch.load_file(checkpoint_file, device="cpu") + elif file_extension == GGUF_FILE_EXTENSION: + return load_gguf_checkpoint(checkpoint_file) else: weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} return torch.load( @@ -211,13 +217,14 @@ def load_model_dict_into_meta( set_module_kwargs["dtype"] = dtype # bnb params are flattened. + # gguf quants have a different shape based on the type of quantization applied if empty_state_dict[param_name].shape != param.shape: if ( is_quantized and hf_quantizer.pre_quantized and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) ): - hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) + hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param) else: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( @@ -396,3 +403,78 @@ def _fetch_index_file_legacy( index_file = None return index_file + + +def _gguf_parse_value(_value, data_type): + if not isinstance(data_type, list): + data_type = [data_type] + if len(data_type) == 1: + data_type = data_type[0] + array_data_type = None + else: + if data_type[0] != 9: + raise ValueError("Received multiple types, therefore expected the first type to indicate an array.") + data_type, array_data_type = data_type + + if data_type in [0, 1, 2, 3, 4, 5, 10, 11]: + _value = int(_value[0]) + elif data_type in [6, 12]: + _value = float(_value[0]) + elif data_type in [7]: + _value = bool(_value[0]) + elif data_type in [8]: + _value = array("B", list(_value)).tobytes().decode() + elif data_type in [9]: + _value = _gguf_parse_value(_value, array_data_type) + return _value + + +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): + """ + Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config + attributes. + + Args: + gguf_checkpoint_path (`str`): + The path the to GGUF file to load + return_tensors (`bool`, defaults to `True`): + Whether to read the tensors from the file and return them. Not doing so is faster and only loads the + metadata in memory. + """ + + if is_gguf_available() and is_torch_available(): + import gguf + from gguf import GGUFReader + + from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter + else: + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." + ) + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") + + reader = GGUFReader(gguf_checkpoint_path) + + parsed_parameters = {} + for tensor in reader.tensors: + name = tensor.name + quant_type = tensor.tensor_type + + # if the tensor is a torch supported dtype do not use GGUFParameter + is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16] + if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES: + _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES]) + raise ValueError( + ( + f"{name} has a quantization type: {str(quant_type)} which is unsupported." + "\n\nCurrently the following quantization types are supported: \n\n" + f"{_supported_quants_str}" + "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers" + ) + ) + + weights = torch.from_numpy(tensor.data.copy()) + parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights + + return parsed_parameters diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ce5289e3dbfd..0f9c9203c926 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1038,14 +1038,14 @@ def to(self, *args, **kwargs): dtype_present_in_args = True break - # Checks if the model has been loaded in 4-bit or 8-bit with BNB - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_quantized", False): if dtype_present_in_args: raise ValueError( - "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" - " desired `dtype` by passing the correct `torch_dtype` argument." + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" ) + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if getattr(self, "is_loaded_in_8bit", False): raise ValueError( "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 18527e3c46c0..8dbe49b75076 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -524,7 +524,6 @@ def custom_forward(*inputs): ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 098308ae0bdc..41173ecb8f5e 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -15,23 +15,33 @@ Adapted from https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py """ + import warnings from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig +from .gguf import GGUFQuantizer +from .quantization_config import ( + BitsAndBytesConfig, + GGUFQuantizationConfig, + QuantizationConfigMixin, + QuantizationMethod, + TorchAoConfig, +) from .torchao import TorchAoHfQuantizer AUTO_QUANTIZER_MAPPING = { "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, + "gguf": GGUFQuantizer, "torchao": TorchAoHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, + "gguf": GGUFQuantizationConfig, "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index d5ac1611a571..f7780b66b12b 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -204,7 +204,10 @@ def create_quantized_param( module._parameters[tensor_name] = new_value - def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape): + def check_quantized_param_shape(self, param_name, current_param, loaded_param): + current_param_shape = current_param.shape + loaded_param_shape = loaded_param.shape + n = current_param_shape.numel() inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) if loaded_param_shape != inferred_shape: diff --git a/src/diffusers/quantizers/gguf/__init__.py b/src/diffusers/quantizers/gguf/__init__.py new file mode 100644 index 000000000000..b3d9082ac803 --- /dev/null +++ b/src/diffusers/quantizers/gguf/__init__.py @@ -0,0 +1 @@ +from .gguf_quantizer import GGUFQuantizer diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py new file mode 100644 index 000000000000..0c760e277ce4 --- /dev/null +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -0,0 +1,159 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_gguf_available, + is_gguf_version, + is_torch_available, + logging, +) + + +if is_torch_available() and is_gguf_available(): + import torch + + from .utils import ( + GGML_QUANT_SIZES, + GGUFParameter, + _dequantize_gguf_and_restore_linear, + _quant_shape_from_byte_shape, + _replace_with_gguf_linear, + ) + + +logger = logging.get_logger(__name__) + + +class GGUFQuantizer(DiffusersQuantizer): + use_keep_in_fp32_modules = True + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + self.compute_dtype = quantization_config.compute_dtype + self.pre_quantized = quantization_config.pre_quantized + self.modules_to_not_convert = quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`" + ) + if not is_gguf_available() or is_gguf_version("<", "0.10.0"): + raise ImportError( + "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`" + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.uint8: + logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization") + return torch.uint8 + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = self.compute_dtype + return torch_dtype + + def check_quantized_param_shape(self, param_name, current_param, loaded_param): + loaded_param_shape = loaded_param.shape + current_param_shape = current_param.shape + quant_type = loaded_param.quant_type + + block_size, type_size = GGML_QUANT_SIZES[quant_type] + + inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size) + if inferred_shape != current_param_shape: + raise ValueError( + f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}" + ) + + return True + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: Union["GGUFParameter", "torch.Tensor"], + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + if isinstance(param_value, GGUFParameter): + return True + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: Union["GGUFParameter", "torch.Tensor"], + param_name: str, + target_device: "torch.device", + state_dict: Optional[Dict[str, Any]] = None, + unexpected_keys: Optional[List[str]] = None, + ): + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + if tensor_name in module._parameters: + module._parameters[tensor_name] = param_value.to(target_device) + if tensor_name in module._buffers: + module._buffers[tensor_name] = param_value.to(target_device) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + state_dict = kwargs.get("state_dict", None) + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + _replace_with_gguf_linear( + model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert + ) + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self) -> bool: + return False + + def _dequantize(self, model): + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + model.to(torch.cuda.current_device()) + + model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert) + if is_model_on_cpu: + model.to("cpu") + return model diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py new file mode 100644 index 000000000000..35e5743fbcf0 --- /dev/null +++ b/src/diffusers/quantizers/gguf/utils.py @@ -0,0 +1,456 @@ +# Copyright 2024 The HuggingFace Team and City96. All rights reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + + +import inspect +from contextlib import nullcontext + +import gguf +import torch +import torch.nn as nn + +from ...utils import is_accelerate_available + + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + + +# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]): + def _should_convert_to_gguf(state_dict, prefix): + weight_key = prefix + "weight" + return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter) + + has_children = list(model.children()) + if not has_children: + return + + for name, module in model.named_children(): + module_prefix = prefix + name + "." + _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert) + + if ( + isinstance(module, nn.Linear) + and _should_convert_to_gguf(state_dict, module_prefix) + and name not in modules_to_not_convert + ): + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model._modules[name] = GGUFLinear( + module.in_features, + module.out_features, + module.bias is not None, + compute_dtype=compute_dtype, + ) + model._modules[name].source_cls = type(module) + # Force requires_grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + return model + + +def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]): + for name, module in model.named_children(): + if isinstance(module, GGUFLinear) and name not in modules_to_not_convert: + device = module.weight.device + bias = getattr(module, "bias", None) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + new_module = nn.Linear( + module.in_features, + module.out_features, + module.bias is not None, + device=device, + ) + new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight)) + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + + has_children = list(module.children()) + if has_children: + _dequantize_gguf_and_restore_linear(module, modules_to_not_convert) + + return model + + +# dequantize operations based on torch ports of GGUF dequantize_functions +# from City96 +# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py + + +QK_K = 256 +K_SCALE_SIZE = 12 + + +def to_uint32(x): + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def split_block_dims(blocks, *args): + n_max = blocks.shape[1] + dims = list(args) + [n_max - sum(args)] + return torch.split(blocks, dims, dim=1) + + +def get_scale_min(scales): + n_blocks = scales.shape[0] + scales = scales.view(torch.uint8) + scales = scales.reshape((n_blocks, 3, 4)) + + d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2) + + sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1) + min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + +def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None): + d, x = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + x = x.view(torch.int8) + return d * x + + +def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape((n_blocks, -1)) + + qs = ql | (qh << 4) + return (d * qs) + m + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return d * qs + + +def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return d * qs + + +def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + ( + ql, + qh, + scales, + d, + ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = ql | (qh << 4) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 2, 1) + ) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor( + [0, 2, 4, 6], device=d.device, dtype=torch.uint8 + ).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = scales.to(torch.int8) - 32 + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = ql.to(torch.int8) - (qh << 2).to(torch.int8) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 + qs = qs.reshape((n_blocks, QK_K // 16, 16)) + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): + return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) + + +GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES +dequantize_functions = { + gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, + gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K, + gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, +} +SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys()) + + +def _quant_shape_from_byte_shape(shape, type_size, block_size): + return (*shape[:-1], shape[-1] // type_size * block_size) + + +def dequantize_gguf_tensor(tensor): + if not hasattr(tensor, "quant_type"): + return tensor + + quant_type = tensor.quant_type + dequant_fn = dequantize_functions[quant_type] + + block_size, type_size = GGML_QUANT_SIZES[quant_type] + + tensor = tensor.view(torch.uint8) + shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size) + + n_blocks = tensor.numel() // type_size + blocks = tensor.reshape((n_blocks, type_size)) + + dequant = dequant_fn(blocks, block_size, type_size) + dequant = dequant.reshape(shape) + + return dequant.as_tensor() + + +class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.quant_type = quant_type + + return self + + def as_tensor(self): + return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + # When converting from original format checkpoints we often use splits, cats etc on tensors + # this method ensures that the returned tensor type from those operations remains GGUFParameter + # so that we preserve quant_type information + quant_type = None + for arg in args: + if isinstance(arg, list) and (arg[0], GGUFParameter): + quant_type = arg[0].quant_type + break + if isinstance(arg, GGUFParameter): + quant_type = arg.quant_type + break + if isinstance(result, torch.Tensor): + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif isinstance(result, (tuple, list)): + # Preserve the original type (tuple or list) + wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result] + return type(result)(wrapped) + else: + return result + + +class GGUFLinear(nn.Linear): + def __init__( + self, + in_features, + out_features, + bias=False, + compute_dtype=None, + device=None, + ) -> None: + super().__init__(in_features, out_features, bias, device) + self.compute_dtype = compute_dtype + + def forward(self, inputs): + weight = dequantize_gguf_tensor(self.weight) + weight = weight.to(self.compute_dtype) + bias = self.bias.to(self.compute_dtype) + + output = torch.nn.functional.linear(inputs, weight, bias) + return output diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 4aeb75ab704c..3078be310719 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" + GGUF = "gguf" TORCHAO = "torchao" @@ -394,6 +395,29 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict +@dataclass +class GGUFQuantizationConfig(QuantizationConfigMixin): + """This is a config class for GGUF Quantization techniques. + + Args: + compute_dtype: (`torch.dtype`, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + + """ + + def __init__(self, compute_dtype: Optional["torch.dtype"] = None): + self.quant_method = QuantizationMethod.GGUF + self.compute_dtype = compute_dtype + self.pre_quantized = True + + # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints. + self.modules_to_not_convert = None + + if self.compute_dtype is None: + self.compute_dtype = torch.float32 + + @dataclass class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 9860ac849834..f8de48ecfc78 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -23,6 +23,7 @@ DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, FLAX_WEIGHTS_NAME, + GGUF_FILE_EXTENSION, HF_MODULES_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, MIN_PEFT_VERSION, @@ -66,6 +67,8 @@ is_bs4_available, is_flax_available, is_ftfy_available, + is_gguf_available, + is_gguf_version, is_google_colab, is_inflect_available, is_invisible_watermark_available, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 553ac5d1bb27..93b0cd847d91 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -34,6 +34,7 @@ SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" SAFETENSORS_FILE_EXTENSION = "safetensors" +GGUF_FILE_EXTENSION = "gguf" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f325f36bddd3..3014efebc82e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -339,6 +339,14 @@ def is_timm_available(): except importlib_metadata.PackageNotFoundError: _imageio_available = False +_is_gguf_available = importlib.util.find_spec("gguf") is not None +if _is_gguf_available: + try: + _gguf_version = importlib_metadata.version("gguf") + logger.debug(f"Successfully import gguf version {_gguf_version}") + except importlib_metadata.PackageNotFoundError: + _is_gguf_available = False + _is_torchao_available = importlib.util.find_spec("torchao") is not None if _is_torchao_available: @@ -469,6 +477,10 @@ def is_imageio_available(): return _imageio_available +def is_gguf_available(): + return _is_gguf_available + + def is_torchao_available(): return _is_torchao_available @@ -607,8 +619,13 @@ def is_torchao_available(): """ # docstyle-ignore +GGUF_IMPORT_ERROR = """ +{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf` +""" + TORCHAO_IMPORT_ERROR = """ -{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao` +{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install +torchao` """ BACKENDS_MAPPING = OrderedDict( @@ -636,6 +653,7 @@ def is_torchao_available(): ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)), ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ] ) @@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str): return compare_versions(parse(_bitsandbytes_version), operation, version) +def is_gguf_version(operation: str, version: str): + """ + Compares the current Accelerate version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _is_gguf_available: + return False + return compare_versions(parse(_gguf_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index b4d3415de50e..3448b4d28d1f 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -32,6 +32,7 @@ is_bitsandbytes_available, is_compel_available, is_flax_available, + is_gguf_available, is_note_seq_available, is_onnx_available, is_opencv_available, @@ -477,6 +478,18 @@ def decorator(test_case): return decorator +def require_gguf_version_greater_or_equal(gguf_version): + def decorator(test_case): + correct_gguf_version = is_gguf_available() and version.parse( + version.parse(importlib.metadata.version("gguf")).base_version + ) >= version.parse(gguf_version) + return unittest.skipUnless( + correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}." + )(test_case) + + return decorator + + def require_torchao_version_greater(torchao_version): def decorator(test_case): correct_torchao_version = is_torchao_available() and version.parse( diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py new file mode 100644 index 000000000000..8ac4c9915c27 --- /dev/null +++ b/tests/quantization/gguf/test_gguf.py @@ -0,0 +1,379 @@ +import gc +import unittest + +import numpy as np +import torch +import torch.nn as nn + +from diffusers import ( + FluxPipeline, + FluxTransformer2DModel, + GGUFQuantizationConfig, + SD3Transformer2DModel, + StableDiffusion3Pipeline, +) +from diffusers.utils.testing_utils import ( + is_gguf_available, + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_gpu_with_torch_cuda, + require_gguf_version_greater_or_equal, + torch_device, +) + + +if is_gguf_available(): + from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter + + +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFSingleFileTesterMixin: + ckpt_path = None + model_cls = None + torch_dtype = torch.bfloat16 + expected_memory_use_in_gb = 5 + + def test_gguf_parameters(self): + quant_storage_type = torch.uint8 + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for param_name, param in model.named_parameters(): + if isinstance(param, GGUFParameter): + assert hasattr(param, "quant_type") + assert param.dtype == quant_storage_type + + def test_gguf_linear_layers(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"): + assert module.weight.dtype == torch.uint8 + assert module.bias.dtype == torch.float32 + + def test_gguf_memory_usage(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + + model = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + model.to("cuda") + assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb + inputs = self.get_dummy_inputs() + + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + with torch.no_grad(): + model(**inputs) + max_memory = torch.cuda.max_memory_allocated() + assert (max_memory / 1024**3) < self.expected_memory_use_in_gb + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. + """ + _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules + self.model_cls._keep_in_fp32_modules = ["proj_out"] + + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32 + self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_dtype_assignment(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + with self.assertRaises(ValueError): + # Tries with a `dtype` + model.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + model.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + model.float() + + with self.assertRaises(ValueError): + # Tries with a cast + model.half() + + # This should work + model.to("cuda") + + def test_dequantize_model(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + model.dequantize() + + def _check_for_gguf_linear(model): + has_children = list(model.children()) + if not has_children: + return + + for name, module in model.named_children(): + if isinstance(module, nn.Linear): + assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear" + assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter" + + for name, module in model.named_children(): + _check_for_gguf_linear(module) + + +class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + torch_dtype = torch.bfloat16 + model_cls = FluxTransformer2DModel + expected_memory_use_in_gb = 5 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 768), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.47265625, + 0.43359375, + 0.359375, + 0.47070312, + 0.421875, + 0.34375, + 0.46875, + 0.421875, + 0.34765625, + 0.46484375, + 0.421875, + 0.34179688, + 0.47070312, + 0.42578125, + 0.34570312, + 0.46875, + 0.42578125, + 0.3515625, + 0.45507812, + 0.4140625, + 0.33984375, + 0.4609375, + 0.41796875, + 0.34375, + 0.45898438, + 0.41796875, + 0.34375, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 + + +class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf" + torch_dtype = torch.bfloat16 + model_cls = SD3Transformer2DModel + expected_memory_use_in_gb = 5 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.17578125, + 0.27539062, + 0.27734375, + 0.11914062, + 0.26953125, + 0.25390625, + 0.109375, + 0.25390625, + 0.25, + 0.15039062, + 0.26171875, + 0.28515625, + 0.13671875, + 0.27734375, + 0.28515625, + 0.12109375, + 0.26757812, + 0.265625, + 0.16210938, + 0.29882812, + 0.28515625, + 0.15625, + 0.30664062, + 0.27734375, + 0.14648438, + 0.29296875, + 0.26953125, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 + + +class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf" + torch_dtype = torch.bfloat16 + model_cls = SD3Transformer2DModel + expected_memory_use_in_gb = 2 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.625, + 0.6171875, + 0.609375, + 0.65625, + 0.65234375, + 0.640625, + 0.6484375, + 0.640625, + 0.625, + 0.6484375, + 0.63671875, + 0.6484375, + 0.66796875, + 0.65625, + 0.65234375, + 0.6640625, + 0.6484375, + 0.6328125, + 0.6640625, + 0.6484375, + 0.640625, + 0.67578125, + 0.66015625, + 0.62109375, + 0.671875, + 0.65625, + 0.62109375, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 From 128b96f369d7433279cd49b051fd50c87d918507 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Dec 2024 19:40:00 +0530 Subject: [PATCH 25/88] Fix Mochi Quality Issues (#10033) * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * Update src/diffusers/models/transformers/transformer_mochi.py Co-authored-by: Aryan --------- Co-authored-by: Sayak Paul Co-authored-by: Aryan --- src/diffusers/models/attention_processor.py | 261 ++++++++++++------ src/diffusers/models/embeddings.py | 1 - src/diffusers/models/normalization.py | 57 ++-- .../models/transformers/transformer_mochi.py | 149 ++++++++-- src/diffusers/pipelines/ltx/pipeline_ltx.py | 1 - .../pipelines/ltx/pipeline_ltx_image2video.py | 1 - .../pipelines/mochi/pipeline_mochi.py | 26 +- 7 files changed, 337 insertions(+), 159 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index be8d654ca66a..05cbaa40e693 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.processor(self, hidden_states) +class MochiAttention(nn.Module): + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + processor: "MochiAttnProcessor2_0", + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_proj_bias: bool = True, + out_dim: Optional[int] = None, + out_context_dim: Optional[int] = None, + out_bias: bool = True, + context_pre_only: bool = False, + eps: float = 1e-5, + ): + super().__init__() + from .normalization import MochiRMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim else query_dim + self.context_pre_only = context_pre_only + + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.norm_q = MochiRMSNorm(dim_head, eps, True) + self.norm_k = MochiRMSNorm(dim_head, eps, True) + self.norm_added_q = MochiRMSNorm(dim_head, eps, True) + self.norm_added_k = MochiRMSNorm(dim_head, eps, True) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "MochiAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + total_length = sequence_length + encoder_sequence_length + + batch_size, heads, _, dim = query.shape + attn_outputs = [] + for idx in range(batch_size): + mask = attention_mask[idx][None, :] + valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :] + + valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2) + valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) + valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) + + attn_output = F.scaled_dot_product_attention( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False + ) + valid_sequence_length = attn_output.size(2) + attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) + attn_outputs.append(attn_output) + + hidden_states = torch.cat(attn_outputs, dim=0) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class AttnProcessor: r""" Default processor for performing attention-related computations. @@ -3868,94 +4039,6 @@ def __call__( return hidden_states -class MochiAttnProcessor2_0: - """Attention processor used in Mochi.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) - - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - if image_rotary_emb is not None: - - def apply_rotary_emb(x, freqs_cos, freqs_sin): - x_even = x[..., 0::2].float() - x_odd = x[..., 1::2].float() - - cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) - sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) - - return torch.stack([cos, sin], dim=-1).flatten(-2) - - query = apply_rotary_emb(query, *image_rotary_emb) - key = apply_rotary_emb(key, *image_rotary_emb) - - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - encoder_query, encoder_key, encoder_value = ( - encoder_query.transpose(1, 2), - encoder_key.transpose(1, 2), - encoder_value.transpose(1, 2), - ) - - sequence_length = query.size(2) - encoder_sequence_length = encoder_query.size(2) - - query = torch.cat([query, encoder_query], dim=2) - key = torch.cat([key, encoder_key], dim=2) - value = torch.cat([value, encoder_value], dim=2) - - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if getattr(attn, "to_add_out", None) is not None: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - - class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses @@ -5668,13 +5751,13 @@ def __call__( AttnProcessorNPU, AttnProcessor2_0, MochiVaeAttnProcessor2_0, + MochiAttnProcessor2_0, StableAudioAttnProcessor2_0, HunyuanAttnProcessor2_0, FusedHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - MochiAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b423c17c1246..0f4b555a2d71 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -542,7 +542,6 @@ def forward(self, latent): height, width = latent.shape[-2:] else: height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - latent = self.proj(latent) if self.flatten: latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 264de4d18d03..fe3823e32acf 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -234,33 +234,6 @@ def forward( return x, gate_msa, scale_mlp, gate_mlp -class MochiRMSNormZero(nn.Module): - r""" - Adaptive RMS Norm used in Mochi. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - """ - - def __init__( - self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False - ) -> None: - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward( - self, hidden_states: torch.Tensor, emb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - emb = self.linear(self.silu(emb)) - scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) - - return hidden_states, gate_msa, scale_mlp, gate_mlp - - class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). @@ -549,6 +522,36 @@ def forward(self, hidden_states): return hidden_states +# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported +# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013 +class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + hidden_states = hidden_states * self.weight + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + class GlobalResponseNorm(nn.Module): # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 def __init__(self, dim): diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index c74c25895cd3..fe72dc56883e 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -23,16 +23,96 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm +from ..normalization import AdaLayerNormContinuous, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class MochiModulatedRMSNorm(nn.Module): + def __init__(self, eps: float): + super().__init__() + + self.eps = eps + self.norm = RMSNorm(0, eps, False) + + def forward(self, hidden_states, scale=None): + hidden_states_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + hidden_states = self.norm(hidden_states) + + if scale is not None: + hidden_states = hidden_states * scale + + hidden_states = hidden_states.to(hidden_states_dtype) + + return hidden_states + + +class MochiLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + eps=1e-5, + bias=True, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + self.norm = MochiModulatedRMSNorm(eps=eps) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + input_dtype = x.dtype + + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32))) + + return x.to(input_dtype) + + +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = RMSNorm(0, eps, False) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states_dtype = hidden_states.dtype + + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32)) + hidden_states = hidden_states.to(hidden_states_dtype) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" @@ -77,38 +157,32 @@ def __init__( if not context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: - self.norm1_context = LuminaLayerNormContinuous( + self.norm1_context = MochiLayerNormContinuous( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, eps=eps, - elementwise_affine=False, - norm_type="rms_norm", - out_dim=None, ) - self.attn1 = Attention( + self.attn1 = MochiAttention( query_dim=dim, - cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, bias=False, - qk_norm=qk_norm, added_kv_proj_dim=pooled_projection_dim, added_proj_bias=False, out_dim=dim, out_context_dim=pooled_projection_dim, context_pre_only=context_pre_only, processor=MochiAttnProcessor2_0(), - eps=eps, - elementwise_affine=True, + eps=1e-5, ) # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True - self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm2 = MochiModulatedRMSNorm(eps=eps) + self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm3 = MochiModulatedRMSNorm(eps) + self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -120,14 +194,15 @@ def __init__( bias=False, ) - self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm4 = MochiModulatedRMSNorm(eps=eps) + self.norm4_context = MochiModulatedRMSNorm(eps=eps) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, + encoder_attention_mask: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) @@ -143,22 +218,25 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=encoder_attention_mask, ) - hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) - norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) + norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) + hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + self.norm2_context( - context_attn_hidden_states - ) * torch.tanh(enc_gate_msa).unsqueeze(1) - norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) + ) + norm_encoder_hidden_states = self.norm3_context( + encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)) + ) context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( - enc_gate_mlp - ).unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.norm4_context( + context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) + ) return hidden_states, encoder_hidden_states @@ -203,7 +281,10 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float()) + with torch.autocast(freqs.device.type, torch.float32): + # Always run ROPE freqs computation in FP32 + freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) + freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin @@ -309,7 +390,11 @@ def __init__( ) self.norm_out = AdaLayerNormContinuous( - inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + inner_dim, + inner_dim, + elementwise_affine=False, + eps=1e-6, + norm_type="layer_norm", ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) @@ -350,7 +435,10 @@ def forward( post_patch_width = width // p temb, encoder_hidden_states = self.time_embed( - timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, ) hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -381,6 +469,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + encoder_attention_mask, image_rotary_emb, **ckpt_kwargs, ) @@ -389,9 +478,9 @@ def custom_forward(*inputs): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, + encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) - hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index f88fcd3e7988..543af08f2e3c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -198,7 +198,6 @@ def __init__( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 ) - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 5b36e993b012..6d2afc56ed39 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -221,7 +221,6 @@ def __init__( self.default_width = 704 self.default_frames = 121 - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 8159c6e16bbb..dfc0a9be278d 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -210,7 +210,6 @@ def __init__( self.default_height = 480 self.default_width = 848 - # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -233,9 +232,13 @@ def _get_t5_prompt_embeds( add_special_tokens=True, return_tensors="pt", ) + text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) + if prompt == "" or prompt[-1] == "": + text_input_ids = torch.zeros_like(text_input_ids, device=device) + prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids @@ -246,7 +249,7 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -451,7 +454,8 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = latents.to(dtype) return latents @property @@ -483,7 +487,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_frames: int = 19, - num_inference_steps: int = 28, + num_inference_steps: int = 64, timesteps: List[int] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, @@ -605,7 +609,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # 3. Prepare text embeddings ( prompt_embeds, @@ -624,10 +627,6 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -642,6 +641,10 @@ def __call__( latents, ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + # 5. Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 threshold_noise = 0.025 @@ -676,6 +679,8 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] + # Mochi CFG + Sampling runs in FP32 + noise_pred = noise_pred.to(torch.float32) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -683,7 +688,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0] + latents = latents.to(latents_dtype) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 1524781b88ac1a082e755a030ba9d73cd6948e84 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 21:43:15 +0530 Subject: [PATCH 26/88] [tests] Remove/rename unsupported quantization torchao type (#10263) update --- tests/quantization/torchao/test_torchao.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 5c71fc4e0ae7..58c1d3613daf 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -228,8 +228,7 @@ def test_quantization(self): ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), - ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ] if TorchAoConfig._is_cuda_capability_atleast_8_9(): @@ -253,8 +252,8 @@ def test_quantization(self): for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: quant_kwargs = {} - if quantization_name in ["uint4wo", "uint_a16w7"]: - # The dummy flux model that we use requires us to impose some restrictions on group_size here + if quantization_name in ["uint4wo", "uint7wo"]: + # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here quant_kwargs.update({"group_size": 16}) quantization_config = TorchAoConfig( quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs From 2739241ad189aef9372394a185b864cbbb9ab5a8 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:26:45 -0800 Subject: [PATCH 27/88] [docs] delete_adapters() (#10245) delete_adapters Co-authored-by: Sayak Paul --- .../en/tutorials/using_peft_for_inference.md | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index 615af55ef5b5..838271360166 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -56,7 +56,7 @@ image With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`. -The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method: +The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method: ```python pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") @@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte You can also merge different adapter checkpoints for inference to blend their styles together. -Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged. +Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged. ```python pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0]) @@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte > [!TIP] > Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide! -To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter: +To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter: ```python pipe.set_adapters("toy") @@ -127,7 +127,7 @@ image = pipe( image ``` -Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model. +Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model. ```python pipe.disable_lora() @@ -140,7 +140,8 @@ image ![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png) ### Customize adapters strength -For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`]. + +For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`]. For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts: ```python @@ -195,7 +196,7 @@ image ![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png) -## Manage active adapters +## Manage adapters You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters: @@ -212,3 +213,11 @@ list_adapters_component_wise = pipe.get_list_adapters() list_adapters_component_wise {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]} ``` + +The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model. + +```py +pipe.delete_adapters("toy") +pipe.get_active_adapters() +["pixel"] +``` From 9c68c945e9527eccda88bdde5d6494c911b1aa47 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Wed, 18 Dec 2024 06:09:50 +0900 Subject: [PATCH 28/88] [Community Pipeline] Fix typo that cause error on regional prompting pipeline (#10251) fix: fix typo that cause error --- examples/community/regional_prompting_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 95f6cebb0190..9f09b4bd2bba 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -129,7 +129,7 @@ def __call__( self.power = int(rp_args["power"]) if "power" in rp_args else 1 prompts = prompt if isinstance(prompt, list) else [prompt] - n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt] + n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) if use_base: From ec1c7a793f9cdcb924d302f121348d9bb5256597 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 17 Dec 2024 21:40:09 +0000 Subject: [PATCH 29/88] Add `set_shift` to FlowMatchEulerDiscreteScheduler (#10269) --- .../scheduling_flow_match_euler_discrete.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 6ddd9ac23009..c7474d56c708 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -99,10 +99,19 @@ def __init__( self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -128,6 +137,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -236,7 +248,7 @@ def set_timesteps( if self.config.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) else: - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) if self.config.shift_terminal: sigmas = self.stretch_shift_to_terminal(sigmas) From 9408aa2dfc215c77ca40dd89fe4fc33f0d3826b5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 18 Dec 2024 08:22:31 +0530 Subject: [PATCH 30/88] [LoRA] feat: lora support for SANA. (#10234) * feat: lora support for SANA. * make fix-copies * rename test class. * attention_kwargs -> cross_attention_kwargs. * Revert "attention_kwargs -> cross_attention_kwargs." This reverts commit 23433bf9bccc12e0f2f55df26bae58a894e8b43b. * exhaust 119 max line limit * sana lora fine-tuning script. * readme * add a note about the supported models. * Apply suggestions from code review Co-authored-by: Aryan * style * docs for attention_kwargs. * remove lora_scale from pag pipeline. * copy fix --------- Co-authored-by: Aryan --- examples/dreambooth/REAMDE_sana.md | 127 ++ examples/dreambooth/requirements_sana.txt | 8 + .../dreambooth/train_dreambooth_lora_sana.py | 1552 +++++++++++++++++ src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 308 ++++ src/diffusers/loaders/peft.py | 1 + .../models/transformers/sana_transformer.py | 26 +- .../pipelines/pag/pipeline_pag_sana.py | 1 - src/diffusers/pipelines/sana/pipeline_sana.py | 36 +- tests/lora/test_lora_layers_sana.py | 138 ++ tests/lora/utils.py | 7 +- 11 files changed, 2200 insertions(+), 6 deletions(-) create mode 100644 examples/dreambooth/REAMDE_sana.md create mode 100644 examples/dreambooth/requirements_sana.txt create mode 100644 examples/dreambooth/train_dreambooth_lora_sana.py create mode 100644 tests/lora/test_lora_layers_sana.py diff --git a/examples/dreambooth/REAMDE_sana.md b/examples/dreambooth/REAMDE_sana.md new file mode 100644 index 000000000000..fe861d62472b --- /dev/null +++ b/examples/dreambooth/REAMDE_sana.md @@ -0,0 +1,127 @@ +# DreamBooth training example for SANA + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject. + +The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629). + + +This will also allow us to push the trained model parameters to the Hugging Face Hub platform. + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run +```bash +pip install -r requirements_sana.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment. + + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. + +Now, we can launch training using: + +```bash +export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-sana-lora" + +accelerate launch train_dreambooth_lora_sana.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --use_8bit_adam \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +For using `push_to_hub`, make you're logged into your Hugging Face account: + +```bash +huggingface-cli login +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +## Notes + +Additionally, we welcome you to explore the following CLI arguments: + +* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only. +* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55). +* `--max_sequence_length`: Maximum sequence length to use for text embeddings. + + +We provide several options for optimizing memory optimization: + +* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used. +* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. +* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. + +Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference. \ No newline at end of file diff --git a/examples/dreambooth/requirements_sana.txt b/examples/dreambooth/requirements_sana.txt new file mode 100644 index 000000000000..04b4bd6c29c0 --- /dev/null +++ b/examples/dreambooth/requirements_sana.txt @@ -0,0 +1,8 @@ +accelerate>=1.0.0 +torchvision +transformers>=4.47.0 +ftfy +tensorboard +Jinja2 +peft>=0.14.0 +sentencepiece \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py new file mode 100644 index 000000000000..4baa9f194feb --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -0,0 +1,1552 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import itertools +import logging +import math +import os +import random +import shutil +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, Gemma2Model + +import diffusers +from diffusers import ( + AutoencoderDC, + FlowMatchEulerDiscreteScheduler, + SanaPipeline, + SanaTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.32.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Sana DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Sana diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md). + + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +TODO +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +TODO +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "sana", + "sana-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=300, + help="Maximum sequence length to use with with the Gemma model", + ) + parser.add_argument( + "--complex_human_instruction", + type=str, + default=None, + help="Instructions for complex human attention: https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sana-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + pipeline = SanaPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch.float32, + revision=args.revision, + variant=args.variant, + ) + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) + pipeline.transformer = pipeline.transformer.to(torch.float16) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder = Gemma2Model.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderDC.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = SanaTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = SanaPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + # VAE should always be kept in fp32 for SANA (?) + vae.to(dtype=torch.float32) + transformer.to(accelerator.device, dtype=weight_dtype) + # because Gemma2 is particularly suited for bfloat16. + text_encoder.to(dtype=torch.bfloat16) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + SanaPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = SanaPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + with torch.no_grad(): + prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt( + prompt, + max_sequence_length=args.max_sequence_length, + complex_human_instruction=args.complex_human_instruction, + ) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + return prompt_embeds, prompt_attention_mask + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) + + # Clear the memory here + if not train_dataset.custom_instance_prompts: + del text_encoder, tokenizer + free_memory() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + prompt_attention_mask = instance_prompt_attention_mask + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0) + + vae_config_scaling_factor = vae.config.scaling_factor + if args.cache_latents: + latents_cache = [] + vae = vae.to("cuda") + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-sana-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + with accelerator.accumulate(models_to_accumulate): + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step] + else: + vae = vae.to(accelerator.device) + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent + if args.offload: + vae = vae.to("cpu") + model_input = model_input * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Predict the noise residual + model_pred = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = SanaPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=torch.float32, + ) + pipeline_args = { + "prompt": args.validation_prompt, + "complex_human_instruction": args.complex_human_instruction, + } + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + free_memory() + + images = None + del pipeline + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + SanaPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Final inference + # Load previous pipeline + pipeline = SanaPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=torch.float32, + ) + pipeline.transformer = pipeline.transformer.to(torch.float16) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = { + "prompt": args.validation_prompt, + "complex_human_instruction": args.complex_human_instruction, + } + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + images = None + del pipeline + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index d59830e614e9..b59150376599 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder): "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", "Mochi1LoraLoaderMixin", + "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -92,6 +93,7 @@ def text_encoder_attn_modules(text_encoder): LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, + SanaLoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 869a5cca24f5..b8c44e480093 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3562,6 +3562,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) +class SanaLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SanaTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3dddb94f30c1..a791a250af08 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -54,6 +54,7 @@ "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, + "SanaTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index dba67f45fce9..41224e42d2a5 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,8 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, @@ -180,7 +181,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. @@ -363,8 +364,24 @@ def forward( timestep: torch.LongTensor, encoder_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -460,6 +477,11 @@ def custom_forward(*inputs): hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index c6e7554e6b69..cf4d41fee487 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -170,7 +170,6 @@ def __init__( pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), ) - # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 80736d498e0f..2df6586d0bc4 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -16,21 +16,25 @@ import inspect import re import urllib.parse as ul -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import AutoModelForCausalLM, AutoTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin from ...models import AutoencoderDC, SanaTransformer2DModel from ...schedulers import DPMSolverMultistepScheduler from ...utils import ( BACKENDS_MAPPING, + USE_PEFT_BACKEND, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -130,7 +134,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class SanaPipeline(DiffusionPipeline): +class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): r""" Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). """ @@ -177,6 +181,7 @@ def encode_prompt( clean_caption: bool = False, max_sequence_length: int = 300, complex_human_instruction: Optional[List[str]] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -210,6 +215,15 @@ def encode_prompt( if device is None: device = self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -305,6 +319,11 @@ def encode_prompt( negative_prompt_embeds = None negative_prompt_attention_mask = None + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -554,6 +573,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def guidance_scale(self): return self._guidance_scale + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def do_classifier_free_guidance(self): return self._guidance_scale > 1.0 @@ -590,6 +613,7 @@ def __call__( return_dict: bool = True, clean_caption: bool = True, use_resolution_binning: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 300, @@ -662,6 +686,10 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw @@ -722,6 +750,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Default height and width to transformer @@ -733,6 +762,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None # 3. Encode input prompt ( @@ -753,6 +783,7 @@ def __call__( clean_caption=clean_caption, max_sequence_length=max_sequence_length, complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -801,6 +832,7 @@ def __call__( encoder_attention_mask=prompt_attention_mask, timestep=timestep, return_dict=False, + attention_kwargs=self.attention_kwargs, )[0] noise_pred = noise_pred.float() diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py new file mode 100644 index 000000000000..499ca89262a0 --- /dev/null +++ b/tests/lora/test_lora_layers_sana.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +import torch +from transformers import Gemma2ForCausalLM, GemmaTokenizer + +from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = SanaPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0) + scheduler_kwargs = {} + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_attention_heads": 2, + "attention_head_dim": 4, + "num_cross_attention_heads": 2, + "cross_attention_head_dim": 4, + "cross_attention_dim": 8, + "caption_channels": 8, + "sample_size": 32, + } + transformer_cls = SanaTransformer2DModel + vae_kwargs = { + "in_channels": 3, + "latent_channels": 4, + "attention_head_dim": 2, + "encoder_block_types": ( + "ResBlock", + "EfficientViTBlock", + ), + "decoder_block_types": ( + "ResBlock", + "EfficientViTBlock", + ), + "encoder_block_out_channels": (8, 8), + "decoder_block_out_channels": (8, 8), + "encoder_qkv_multiscales": ((), (5,)), + "decoder_qkv_multiscales": ((), (5,)), + "encoder_layers_per_block": (1, 1), + "decoder_layers_per_block": [1, 1], + "downsample_block_type": "conv", + "upsample_block_type": "interpolate", + "decoder_norm_types": "rms_norm", + "decoder_act_fns": "silu", + "scaling_factor": 0.41407, + } + vae_cls = AutoencoderDC + tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" + text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers" + + @property + def output_shape(self): + return (1, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "negative_prompt": "", + "num_inference_steps": 4, + "guidance_scale": 4.5, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + "complex_human_instruction": None, + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in Sana.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in Mochi.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Mochi.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Mochi.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Mochi.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Mochi.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Mochi.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Mochi.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 990cf71f298e..ac7a944cd026 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1545,7 +1545,12 @@ def test_lora_fuse_nan(self): "adapter-1" ].weight += float("inf") else: - pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + named_modules = [name for name, _ in pipe.transformer.named_modules()] + has_attn1 = any("attn1" in name for name in named_modules) + if has_attn1: + pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + else: + pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") # with `safe_fusing=True` we should see an Error with self.assertRaises(ValueError): From ba6fd6eb30de97370f06f5804d9cc0e10b5718b5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 18 Dec 2024 08:43:57 +0530 Subject: [PATCH 31/88] [chore] fix: licensing headers in mochi and ltx (#10275) fix: licensing header. --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 +- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 2 +- src/diffusers/pipelines/mochi/pipeline_mochi.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 543af08f2e3c..7180601dad41 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -1,4 +1,4 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 6d2afc56ed39..fbb30e304d65 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -1,4 +1,4 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index dfc0a9be278d..937575d26f98 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -1,4 +1,4 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 0ac52d6f0970d5d91a1c88d4bf2e297d9298c642 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 04:26:52 +0000 Subject: [PATCH 32/88] Use `torch` in `get_2d_rotary_pos_embed` (#10155) * Use `torch` in `get_2d_rotary_pos_embed` * Add deprecation --- ...ipeline_hunyuandit_differential_img2img.py | 2 + src/diffusers/models/embeddings.py | 52 ++++++++++++++++++- .../pipeline_hunyuandit_controlnet.py | 6 ++- .../hunyuandit/pipeline_hunyuandit.py | 6 ++- .../pipelines/pag/pipeline_pag_hunyuandit.py | 6 ++- 5 files changed, 68 insertions(+), 4 deletions(-) diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py index 3ece670e5bde..8cf2830f25ab 100644 --- a/examples/community/pipeline_hunyuandit_differential_img2img.py +++ b/examples/community/pipeline_hunyuandit_differential_img2img.py @@ -1008,6 +1008,8 @@ def __call__( self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0f4b555a2d71..f3c57103f9b8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro( return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w -def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): +def get_2d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np" +): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + device: (`torch.device`, **optional**): + The device used to create tensors. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return _get_2d_rotary_pos_embed_np( + embed_dim=embed_dim, + crops_coords=crops_coords, + grid_size=grid_size, + use_real=use_real, + ) + start, stop = crops_coords + # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False) + grid_h = torch.linspace( + start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32 + ) + grid_w = torch.linspace( + start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32 + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index 45e17f3de1e2..c8464f8108ea 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -925,7 +925,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index bda718cb197d..6f542cb59f46 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -798,7 +798,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 408992378538..dea1f12696b2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -818,7 +818,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) From 63cdf9c0ba20d11f30c07c6b73a3e80ae9eb99dd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 18 Dec 2024 10:56:08 +0530 Subject: [PATCH 33/88] [chore] fix: reamde -> readme (#10276) fix: reamde -> readme --- examples/dreambooth/{REAMDE_sana.md => README_sana.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/dreambooth/{REAMDE_sana.md => README_sana.md} (100%) diff --git a/examples/dreambooth/REAMDE_sana.md b/examples/dreambooth/README_sana.md similarity index 100% rename from examples/dreambooth/REAMDE_sana.md rename to examples/dreambooth/README_sana.md From 88b015dc9fdda01e0de44fcc2c1f719f6531c811 Mon Sep 17 00:00:00 2001 From: Xinyuan Zhao <22809191+Bichidian@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:55:18 +0800 Subject: [PATCH 34/88] Make `time_embed_dim` of `UNet2DModel` changeable (#10262) --- src/diffusers/models/unets/unet_2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 5972505f2897..d05af686dede 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -97,6 +97,7 @@ def __init__( out_channels: int = 3, center_input_sample: bool = False, time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), @@ -122,7 +123,7 @@ def __init__( super().__init__() self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 # Check inputs if len(down_block_types) != len(up_block_types): From 8eb73c872afbe59abab4580aaa591a9851a42e6d Mon Sep 17 00:00:00 2001 From: Qin Zhou <1079207272@qq.com> Date: Wed, 18 Dec 2024 15:58:33 +0800 Subject: [PATCH 35/88] Support pass kwargs to sd3 custom attention processor (#9818) * Support pass kwargs to sd3 custom attention processor --------- Co-authored-by: hlky Co-authored-by: YiYi Xu --- src/diffusers/models/attention.py | 13 ++++++++++--- .../models/transformers/transformer_sd3.py | 6 +++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6749c7f17254..4d1dae879f11 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_dim = dim def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): + joint_attention_kwargs = joint_attention_kwargs or {} if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( hidden_states, emb=temb @@ -206,7 +211,9 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + **joint_attention_kwargs, ) # Process attention outputs for the `hidden_states`. @@ -214,7 +221,7 @@ def forward( hidden_states = hidden_states + attn_output if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 hidden_states = hidden_states + attn_output2 diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79452bb85176..79c4069e9a37 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -411,11 +411,15 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + joint_attention_kwargs, **ckpt_kwargs, ) elif not is_skip: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 83709d5a06b48decee05e434c272d738c2248c16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Romero?= Date: Wed, 18 Dec 2024 04:14:16 -0500 Subject: [PATCH 36/88] Flux Control(Depth/Canny) + Inpaint (#10192) * flux_control_inpaint - failing test_flux_different_prompts * removing test_flux_different_prompts? * fix style * fix from PR comments * fix style * reducing guidance_scale in demo * Update src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py Co-authored-by: hlky * make * prepare_latents is not copied from * update docs * typos --------- Co-authored-by: affromero Co-authored-by: Sayak Paul Co-authored-by: hlky --- docs/source/en/_toctree.yml | 2 + .../en/api/pipelines/control_flux_inpaint.md | 89 ++ src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/flux/__init__.py | 2 + .../flux/pipeline_flux_control_inpaint.py | 1141 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + .../test_pipeline_flux_control_inpaint.py | 215 ++++ 8 files changed, 1468 insertions(+) create mode 100644 docs/source/en/api/pipelines/control_flux_inpaint.md create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_control_inpaint.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ab733054fbd3..27e9fe5e191b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -400,6 +400,8 @@ title: DiT - local: api/pipelines/flux title: Flux + - local: api/pipelines/control_flux_inpaint + title: FluxControlInpaint - local: api/pipelines/hunyuandit title: Hunyuan-DiT - local: api/pipelines/hunyuan_video diff --git a/docs/source/en/api/pipelines/control_flux_inpaint.md b/docs/source/en/api/pipelines/control_flux_inpaint.md new file mode 100644 index 000000000000..0cf4f4b4225e --- /dev/null +++ b/docs/source/en/api/pipelines/control_flux_inpaint.md @@ -0,0 +1,89 @@ + + +# FluxControlInpaint + +FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image. + +FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**. + +| Control type | Developer | Link | +| -------- | ---------- | ---- | +| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) | +| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) | + + + + +Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). + + + +```python +import torch +from diffusers import FluxControlInpaintPipeline +from diffusers.models.transformers import FluxTransformer2DModel +from transformers import T5EncoderModel +from diffusers.utils import load_image, make_image_grid +from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux +from PIL import Image +import numpy as np + +pipe = FluxControlInpaintPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Depth-dev", + torch_dtype=torch.bfloat16, +) +# use following lines if you have GPU constraints +# --------------------------------------------------------------- +transformer = FluxTransformer2DModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16 +) +text_encoder_2 = T5EncoderModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16 +) +pipe.transformer = transformer +pipe.text_encoder_2 = text_encoder_2 +pipe.enable_model_cpu_offload() +# --------------------------------------------------------------- +pipe.to("cuda") + +prompt = "a blue robot singing opera with human-like expressions" +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +head_mask = np.zeros_like(image) +head_mask[65:580,300:642] = 255 +mask_image = Image.fromarray(head_mask) + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(image)[0].convert("RGB") + +output = pipe( + prompt=prompt, + image=image, + control_image=control_image, + mask_image=mask_image, + num_inference_steps=30, + strength=0.9, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png") +``` + +## FluxControlInpaintPipeline +[[autodoc]] FluxControlInpaintPipeline + - all + - __call__ + + +## FluxPipelineOutput +[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e2351a0c53b8..91b297f8c007 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "CogView3PlusPipeline", "CycleDiffusionPipeline", "FluxControlImg2ImgPipeline", + "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", "FluxControlNetPipeline", @@ -765,6 +766,7 @@ CogView3PlusPipeline, CycleDiffusionPipeline, FluxControlImg2ImgPipeline, + FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e7fd7ec78bed..ce291e5ceb45 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,6 +128,7 @@ ] _import_structure["flux"] = [ "FluxControlPipeline", + "FluxControlInpaintPipeline", "FluxControlImg2ImgPipeline", "FluxControlNetPipeline", "FluxControlNetImg2ImgPipeline", @@ -539,6 +540,7 @@ ) from .flux import ( FluxControlImg2ImgPipeline, + FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 3570368a5ca1..72e1b578f2ca 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] + _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] @@ -44,6 +45,7 @@ from .pipeline_flux import FluxPipeline from .pipeline_flux_control import FluxControlPipeline from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline + from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py new file mode 100644 index 000000000000..a9ac1c72c6ed --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -0,0 +1,1141 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +) +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + import torch + from diffusers import FluxControlInpaintPipeline + from diffusers.models.transformers import FluxTransformer2DModel + from transformers import T5EncoderModel + from diffusers.utils import load_image, make_image_grid + from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux + from PIL import Image + import numpy as np + + pipe = FluxControlInpaintPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Depth-dev", + torch_dtype=torch.bfloat16, + ) + # use following lines if you have GPU constraints + # --------------------------------------------------------------- + transformer = FluxTransformer2DModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16 + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + ) + pipe.transformer = transformer + pipe.text_encoder_2 = text_encoder_2 + pipe.enable_model_cpu_offload() + # --------------------------------------------------------------- + pipe.to("cuda") + + prompt = "a blue robot singing opera with human-like expressions" + image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + + head_mask = np.zeros_like(image) + head_mask[65:580, 300:642] = 255 + mask_image = Image.fromarray(head_mask) + + processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") + control_image = processor(image)[0].convert("RGB") + + output = pipe( + prompt=prompt, + image=image, + control_image=control_image, + mask_image=mask_image, + num_inference_steps=30, + strength=0.9, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), + ).images[0] + make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save( + "output.png" + ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlInpaintPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for image inpainting using Flux-dev-Depth/Canny. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_mask_latents( + self, + image, + mask_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + image = self.image_processor.preprocess(image, height=height, width=width) + mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + + masked_image = image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=dtype) + + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width)) + mask_image = mask_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask_image.shape[0] < batch_size: + if not batch_size % mask_image.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask_image.shape[0]} mask_image were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask_image = self._pack_latents( + mask_image.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1) + + return mask_image, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + device = self._execution_device + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Preprocess mask and image + num_channels_latents = self.vae.config.latent_channels + if masked_image_latents is not None: + # pre computed masked_image_latents and mask_image + masked_image_latents = masked_image_latents.to(latents.device) + mask = mask_image.to(latents.device) + else: + mask, masked_image_latents = self.prepare_mask_latents( + image, + mask_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height_8 = 2 * (int(height) // (self.vae_scale_factor * 2)) + width_8 = 2 * (int(width) // (self.vae_scale_factor * 2)) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_mask = mask + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + image_latents, torch.tensor([noise_timestep]), noise + ) + else: + init_latents_proper = image_latents + init_latents_proper = self._pack_latents( + init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8 + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e148c025d191..9b36be9e0604 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -392,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py new file mode 100644 index 000000000000..c5ff02a525f2 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py @@ -0,0 +1,215 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlInpaintPipeline, + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + torch_device, +) + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + + +class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlInpaintPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = Image.new("RGB", (8, 8), 0) + control_image = Image.new("RGB", (8, 8), 0) + mask_image = Image.new("RGB", (8, 8), 255) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": control_image, + "generator": generator, + "image": image, + "mask_image": mask_image, + "strength": 0.8, + "num_inference_steps": 2, + "guidance_scale": 30.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + # def test_flux_different_prompts(self): + # pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + # inputs = self.get_dummy_inputs(torch_device) + # output_same_prompt = pipe(**inputs).images[0] + + # inputs = self.get_dummy_inputs(torch_device) + # inputs["prompt_2"] = "a different prompt" + # output_different_prompts = pipe(**inputs).images[0] + + # max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # # Outputs should be different here + # # For some reasons, they don't show large differences + # assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) From e222246b4e7b60db7fe5fd27dc187bce446b5b56 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 12:22:10 +0000 Subject: [PATCH 37/88] Fix sigma_last with use_flow_sigmas (#10267) --- src/diffusers/schedulers/scheduling_deis_multistep.py | 1 + .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 3 +++ src/diffusers/schedulers/scheduling_sasolver.py | 1 + src/diffusers/schedulers/scheduling_unipc_multistep.py | 9 +++++++++ 4 files changed, 14 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 3350c3373ecf..6a653f183bba 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -289,6 +289,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 19399a724a41..971817f7b777 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -291,14 +291,17 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) elif self.config.use_beta_sigmas: sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_max = ( diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 41a471275fa2..d45c93880bc5 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -318,6 +318,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index c6434c6f87c6..01500426305c 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -381,6 +381,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": From b389f339ec016cb83f0975c1c9cc0d7965e411f8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 18 Dec 2024 18:32:36 +0530 Subject: [PATCH 38/88] Fix Doc links in GGUF and Quantization overview docs (#10279) * update * Update docs/source/en/quantization/gguf.md Co-authored-by: Aryan --------- Co-authored-by: Aryan --- docs/source/en/quantization/gguf.md | 4 ++-- docs/source/en/quantization/overview.md | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md index dbcd1b1486b2..2ff2a9293130 100644 --- a/docs/source/en/quantization/gguf.md +++ b/docs/source/en/quantization/gguf.md @@ -25,9 +25,9 @@ pip install -U gguf Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`]. -When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.unint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`. +When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`. -The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original (`numpy`)[https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py] implementation by [compilade](https://github.com/compilade). +The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade). ```python import torch diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 6c2df7514d5e..3eef5238f1ce 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? Diffusers currently supports the following quantization methods. -- [BitsandBytes]() -- [TorchAO]() -- [GGUF]() +- [BitsandBytes](./bitsandbytes.md) +- [TorchAO](./torchao.md) +- [GGUF](./gguf.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. From 8304adce2aa171f0328c882001ba76891ee661d2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 18 Dec 2024 18:32:53 +0530 Subject: [PATCH 39/88] Make zeroing prompt embeds for Mochi Pipeline configurable (#10284) update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 937575d26f98..aac4e32e33f0 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -188,6 +188,7 @@ def __init__( text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, + force_zeros_for_empty_prompt: bool = False, ): super().__init__() @@ -205,10 +206,11 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256 ) self.default_height = 480 self.default_width = 848 + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) def _get_t5_prompt_embeds( self, @@ -236,7 +238,11 @@ def _get_t5_prompt_embeds( text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) - if prompt == "" or prompt[-1] == "": + + # The original Mochi implementation zeros out empty negative prompts + # but this can lead to overflow when placing the entire pipeline under the autocast context + # adding this here so that we can enable zeroing prompts if necessary + if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""): text_input_ids = torch.zeros_like(text_input_ids, device=device) prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device) From 862a7d5038c1c53641ffcab146a7eeb5ab683656 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 18 Dec 2024 19:19:47 +0530 Subject: [PATCH 40/88] [Single File] Add single file support for Flux Canny, Depth and Fill (#10288) update --- src/diffusers/loaders/single_file_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 4e288737fe88..ded466b35e9a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -151,6 +151,8 @@ "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, + "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, + "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, @@ -587,7 +589,13 @@ def infer_diffusers_model_type(checkpoint): if any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] ): - model_type = "flux-dev" + if checkpoint["img_in.weight"].shape[1] == 384: + model_type = "flux-fill" + + elif checkpoint["img_in.weight"].shape[1] == 128: + model_type = "flux-depth" + else: + model_type = "flux-dev" else: model_type = "flux-schnell" From c4c99c3907c9524dc15e86ddd69389a5ffcdc07d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 18 Dec 2024 22:36:08 +0530 Subject: [PATCH 41/88] [tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (#10270) fix joint pos embedding device --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f3c57103f9b8..69b3ee8466f4 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -691,7 +691,7 @@ def _get_positional_embeddings( output_type="pt", ) pos_embedding = pos_embedding.flatten(0, 1) - joint_pos_embedding = torch.zeros( + joint_pos_embedding = pos_embedding.new_zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) From f66bd3261c29c41202505673738c905119d1b066 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 18 Dec 2024 22:41:23 +0530 Subject: [PATCH 42/88] Rename Mochi integration test correctly (#10220) rename integration test --- tests/pipelines/mochi/test_mochi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index 2192c171aa22..bbcf6d210ce5 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -275,7 +275,7 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_cogvideox(self): + def test_mochi(self): generator = torch.Generator("cpu").manual_seed(0) pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16) From f35a38725b4d263330a591dc7bdb54b002b96675 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 01:19:08 +0530 Subject: [PATCH 43/88] [tests] remove nullop import checks from lora tests (#10273) remove nullop imports --- tests/lora/test_lora_layers_cogvideox.py | 4 ---- tests/lora/test_lora_layers_mochi.py | 4 ---- tests/lora/test_lora_layers_sd3.py | 4 ---- 3 files changed, 12 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index 15f8ebf4505c..aa7a1619a183 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -29,7 +29,6 @@ ) from diffusers.utils.testing_utils import ( floats_tensor, - is_peft_available, is_torch_version, require_peft_backend, skip_mps, @@ -37,9 +36,6 @@ ) -if is_peft_available(): - pass - sys.path.append(".") from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 0a07e3d096bb..4bfc5a824d43 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -23,7 +23,6 @@ from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel from diffusers.utils.testing_utils import ( floats_tensor, - is_peft_available, is_torch_version, require_peft_backend, skip_mps, @@ -31,9 +30,6 @@ ) -if is_peft_available(): - pass - sys.path.append(".") from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index b37a2a297e04..8c42f9c86ee9 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -29,7 +29,6 @@ from diffusers.utils import load_image from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( - is_peft_available, numpy_cosine_similarity_distance, require_peft_backend, require_torch_gpu, @@ -37,9 +36,6 @@ ) -if is_peft_available(): - pass - sys.path.append(".") from utils import PeftLoraLoaderMixinTests # noqa: E402 From 9c0e20de61a6e0adcec706564cee739520c1d2f4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Dec 2024 10:24:57 +0530 Subject: [PATCH 44/88] [chore] Update README_sana.md to update the default model (#10285) Update README_sana.md to update the default model --- examples/dreambooth/README_sana.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md index fe861d62472b..d82529c64de8 100644 --- a/examples/dreambooth/README_sana.md +++ b/examples/dreambooth/README_sana.md @@ -73,7 +73,7 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face Now, we can launch training using: ```bash -export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers" +export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers" export INSTANCE_DIR="dog" export OUTPUT_DIR="trained-sana-lora" @@ -124,4 +124,4 @@ We provide several options for optimizing memory optimization: * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. -Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference. \ No newline at end of file +Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference. From f781b8c30c4d70fbf0afcc9799c7f9e9693b2921 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 10:28:10 +0530 Subject: [PATCH 45/88] Hunyuan VAE tiling fixes and transformer docs (#10295) * update * udpate * fix test --- .../autoencoder_kl_hunyuan_video.py | 8 ++-- .../transformers/transformer_hunyuan_video.py | 40 +++++++++++++++++++ .../test_models_autoencoder_hunyuan_video.py | 25 ++++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index bded90a8bcff..5c1d94d4e18f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -792,12 +792,12 @@ def __init__( # The minimal tile height and width for spatial tiling to be used self.tile_sample_min_height = 256 self.tile_sample_min_width = 256 - self.tile_sample_min_num_frames = 64 + self.tile_sample_min_num_frames = 16 # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 - self.tile_sample_stride_num_frames = 48 + self.tile_sample_stride_num_frames = 12 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): @@ -1003,7 +1003,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: for i in range(0, height, self.tile_sample_stride_height): row = [] for j in range(0, width, self.tile_sample_stride_width): - tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) @@ -1020,7 +1020,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) - result_rows.append(torch.cat(result_row, dim=-1)) + result_rows.append(torch.cat(result_row, dim=4)) enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] return enc diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index d8f9834ea61c..737be99c5a10 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -497,6 +497,46 @@ def forward( class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_single_layers (`int`, defaults to `40`): + The number of layers of single-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index 826ac30d5f2f..7b7901a6fd94 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self): "down_block_types": ( "HunyuanVideoDownBlock3D", "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", ), "up_block_types": ( "HunyuanVideoUpBlock3D", "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", ), "block_out_channels": (8, 8, 8, 8), "layers_per_block": 1, @@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self): } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + # We need to overwrite this test because the base test does not account length of down_block_types + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 16, 16, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + @unittest.skip("Unsupported test.") def test_outputs_equivalence(self): pass From 4450d26b63b4f6e7736ca86f11d0c37827159bfa Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 08:28:56 +0000 Subject: [PATCH 46/88] Add Flux Control to AutoPipeline (#10292) --- src/diffusers/pipelines/auto_pipeline.py | 37 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index a0f95fe6cdc1..f3a05c2c661f 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -35,9 +35,12 @@ ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( + FluxControlImg2ImgPipeline, + FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, @@ -125,6 +128,7 @@ ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), ("flux", FluxPipeline), + ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), ("cogview3", CogView3PlusPipeline), @@ -150,6 +154,7 @@ ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), + ("flux-control", FluxControlImg2ImgPipeline), ] ) @@ -168,6 +173,7 @@ ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline), + ("flux-control", FluxControlInpaintPipeline), ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), ] ) @@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + if "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline") else: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline") if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # the `orig_class_name` can be: # `- *Pipeline` (for regular text-to-image checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # `- *Img2ImgPipeline` (for refiner checkpoint) - to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + if "Img2Img" in orig_class_name: + to_replace = "Img2ImgPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline") + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} @@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # The `orig_class_name`` can be: # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # - or *Pipeline (for regular text-to-image checkpoint) - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + if "Inpaint" in orig_class_name: + to_replace = "InpaintPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): enable_pag = kwargs.pop("enable_pag") if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline") inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} From 2f7a417d1fb11bd242ad7f9098bb9fdf77c54422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E4=B8=89=E7=9F=B3?= <49309820+zhaowendao30@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:07:50 +0800 Subject: [PATCH 47/88] Update lora_conversion_utils.py (#9980) x-flux single-blocks lora load Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/loaders/lora_conversion_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index aab87b8f4dba..07c2c2272422 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -643,7 +643,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): old_state_dict, new_state_dict, old_key, - [f"transformer.single_transformer_blocks.{block_num}.norm.linear"], + [ + f"transformer.single_transformer_blocks.{block_num}.attn.to_q", + f"transformer.single_transformer_blocks.{block_num}.attn.to_k", + f"transformer.single_transformer_blocks.{block_num}.attn.to_v", + ], ) if "down" in old_key: From 0ed09a17bbab784a78fb163b557b4827467b0468 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 09:24:52 +0000 Subject: [PATCH 48/88] Check correct model type is passed to `from_pretrained` (#10189) * Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- src/diffusers/pipelines/pipeline_utils.py | 22 ++++++++++++++++++++++ tests/pipelines/test_pipelines.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a504184ea2f2..c505c5a262a3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import fnmatch import importlib import inspect @@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + expected_types = pipeline_class._get_signature_types() passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -833,6 +835,26 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + for key in init_dict.keys(): + if key not in passed_class_obj: + continue + if "scheduler" in key: + continue + + class_obj = passed_class_obj[key] + _expected_class_types = [] + for expected_type in expected_types[key]: + if isinstance(expected_type, enum.EnumMeta): + _expected_class_types.extend(expected_type.__members__.keys()) + else: + _expected_class_types.append(expected_type.__name__) + + _is_valid_type = class_obj.__class__.__name__ in _expected_class_types + if not _is_valid_type: + logger.warning( + f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." + ) + # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 43b01c40f5bb..423c82e0602e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + def test_wrong_model(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + with self.assertRaises(ValueError) as error_context: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer + ) + + assert "is of type" in str(error_context.exception) + assert "but should be" in str(error_context.exception) + @slow @require_torch_gpu From 1826a1e7d31df48d345a20028b3ace48f09a4e60 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:52:20 +0800 Subject: [PATCH 49/88] [LoRA] Support HunyuanVideo (#10254) * 1217 * 1217 * 1217 * update * reverse * add test * update test * make style * update * make style --------- Co-authored-by: Aryan --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 308 ++++++++++++++++++ src/diffusers/loaders/peft.py | 1 + .../transformers/transformer_hunyuan_video.py | 28 +- .../hunyuan_video/pipeline_hunyuan_video.py | 14 +- tests/lora/test_lora_layers_hunyuanvideo.py | 228 +++++++++++++ tests/lora/utils.py | 34 +- 7 files changed, 600 insertions(+), 15 deletions(-) create mode 100644 tests/lora/test_lora_layers_hunyuanvideo.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b59150376599..6ea382d721de 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder): "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", "Mochi1LoraLoaderMixin", + "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] @@ -90,6 +91,7 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b8c44e480093..46d744233014 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) +class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HunyuanVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a791a250af08..9c00012ebc65 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -53,6 +53,7 @@ "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, + "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 737be99c5a10..089389b5f9ad 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -19,7 +19,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor from ..embeddings import ( @@ -32,6 +33,9 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class HunyuanVideoAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -496,7 +500,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -670,8 +674,24 @@ def forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t @@ -757,6 +777,10 @@ def custom_forward(*inputs): hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (hidden_states,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index bd3d3c1e8485..4423ccf97932 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,6 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ -132,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class HunyuanVideoPipeline(DiffusionPipeline): +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation using HunyuanVideo. @@ -447,6 +448,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -471,6 +476,7 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -525,6 +531,10 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -562,6 +572,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False device = self._execution_device @@ -640,6 +651,7 @@ def __call__( encoder_attention_mask=prompt_attention_mask, pooled_projections=pooled_prompt_embeds, guidance=guidance, + attention_kwargs=attention_kwargs, return_dict=False, )[0] diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py new file mode 100644 index 000000000000..59464c052684 --- /dev/null +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -0,0 +1,228 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import pytest +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + is_torch_version, + require_peft_backend, + skip_mps, + torch_device, +) + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +@skip_mps +class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = HunyuanVideoPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + transformer_cls = HunyuanVideoTransformer3DModel + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + vae_cls = AutoencoderKLHunyuanVideo + has_two_text_encoders = True + tokenizer_cls, tokenizer_id, tokenizer_subfolder = ( + LlamaTokenizerFast, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer", + ) + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = ( + CLIPTokenizer, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer_2", + ) + text_encoder_cls, text_encoder_id, text_encoder_subfolder = ( + LlamaModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder", + ) + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = ( + CLIPTextModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder_2", + ) + + @property + def output_shape(self): + return (1, 9, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_frames": num_frames, + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "prompt_template": {"template": "{}", "crop_start": 0}, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) + def test_lora_fuse_nan(self): + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + + out = pipe( + prompt=inputs["prompt"], + height=inputs["height"], + width=inputs["width"], + num_frames=inputs["num_frames"], + num_inference_steps=inputs["num_inference_steps"], + max_sequence_length=inputs["max_sequence_length"], + output_type="np", + )[0] + + self.assertTrue(np.isnan(out).all()) + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + # TODO(aryan): Fix the following test + @unittest.skip("This test fails with an error I haven't been able to debug yet.") + def test_simple_inference_save_pretrained(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ac7a944cd026..73ed17049c1b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests: has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id = None, None - text_encoder_2_cls, text_encoder_2_id = None, None - text_encoder_3_cls, text_encoder_3_id = None, None - tokenizer_cls, tokenizer_id = None, None - tokenizer_2_cls, tokenizer_2_id = None, None - tokenizer_3_cls, tokenizer_3_id = None, None + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None unet_kwargs = None transformer_cls = None @@ -124,16 +124,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) - text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) - tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + text_encoder = self.text_encoder_cls.from_pretrained( + self.text_encoder_id, subfolder=self.text_encoder_subfolder + ) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) if self.text_encoder_2_cls is not None: - text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id) - tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id) + text_encoder_2 = self.text_encoder_2_cls.from_pretrained( + self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder + ) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained( + self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder + ) if self.text_encoder_3_cls is not None: - text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id) - tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id) + text_encoder_3 = self.text_encoder_3_cls.from_pretrained( + self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder + ) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained( + self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder + ) text_lora_config = LoraConfig( r=rank, From 9764f229d4a8386b4602711d0da5a4b02d9aa791 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Dec 2024 22:20:40 +0530 Subject: [PATCH 50/88] [Single File] Add single file support for Mochi Transformer (#10268) update --- src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 109 ++++++++++++++++++ .../models/transformers/transformer_mochi.py | 3 +- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 9641435fa5a6..d102282025c7 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -32,6 +32,7 @@ convert_ldm_vae_checkpoint, convert_ltx_transformer_checkpoint_to_diffusers, convert_ltx_vae_checkpoint_to_diffusers, + convert_mochi_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -96,6 +97,10 @@ "default_subfolder": "vae", }, "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, + "MochiTransformer3DModel": { + "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ded466b35e9a..8b2bf12214cd 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -106,6 +106,7 @@ ], "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", + "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -159,6 +160,7 @@ "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, + "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, } # Use to configure model sample size when original config is provided @@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "autoencoder-dc-f128c512" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): + model_type = "mochi-1-preview" + else: model_type = "v1" @@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim): return new_weight +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + def get_attn2_layers(state_dict): attn2_layers = [] for key in state_dict.keys(): @@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict): handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + new_state_dict = {} + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight") + new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + + # Visual attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight") + new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias") + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_x.w1.weight") + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_y.w1.weight") + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight") + + # Output layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + + return new_state_dict diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index fe72dc56883e..41e5289f2d57 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward @@ -304,7 +305,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). From 3ee966950b636bcb9a78cc107da7887f195ac1a2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Dec 2024 22:34:44 +0530 Subject: [PATCH 51/88] Allow Mochi Transformer to be split across multiple GPUs (#10300) update --- src/diffusers/models/transformers/transformer_mochi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 41e5289f2d57..8763ea450253 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -335,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri """ _supports_gradient_checkpointing = True + _no_split_modules = ["MochiTransformerBlock"] @register_to_config def __init__( From 074798b2997a6f1a329924b400a0db924e8e6735 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 17:04:57 +0000 Subject: [PATCH 52/88] Fix `local_files_only` for checkpoints with shards (#10294) --- src/diffusers/utils/hub_utils.py | 67 ++++++++++++++------------------ 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ef4715ee0e1e..a6dfe18433e3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,48 +455,39 @@ def _get_checkpoint_shard_files( allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] - if not local_files_only: - # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) - for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) - if not shard_file_present: - raise EnvironmentError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) - - try: - # Load from URL - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - if subfolder is not None: - cached_folder = os.path.join(cached_folder, subfolder) - - # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so - # we don't have to catch them here. We have also dealt with EntryNotFoundError. - except HTTPError as e: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" - " again after checking your internet connection." - ) from e + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) - # If `local_files_only=True`, `cached_folder` may not contain all the shard files. - elif local_files_only: - _check_if_shards_exist_locally( - local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, ) if subfolder is not None: - cached_folder = os.path.join(cache_dir, subfolder) + cached_folder = os.path.join(cached_folder, subfolder) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e return cached_folder, sharded_metadata From d8825e7697d2ac982046f96652261a60596c4944 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Dec 2024 02:35:41 +0530 Subject: [PATCH 53/88] Fix failing lora tests after HunyuanVideo lora (#10307) fix --- tests/lora/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 73ed17049c1b..0a0366fd8d2b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests: has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None - text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None - text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None - tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None - tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None - tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" unet_kwargs = None transformer_cls = None From b756ec6e80b3d94c3ae7dc356bdbbdb426a05dca Mon Sep 17 00:00:00 2001 From: djm <92705171+Foundsheep@users.noreply.github.com> Date: Fri, 20 Dec 2024 07:24:18 +0900 Subject: [PATCH 54/88] unet's `sample_size` attribute is to accept tuple(h, w) in `StableDiffusionPipeline` (#10181) --- .../models/unets/unet_2d_condition.py | 2 +- .../pipeline_stable_diffusion.py | 21 ++++++++++++++++--- .../stable_diffusion/test_stable_diffusion.py | 8 +++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 4f55df32b738..e488f5897ebc 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -170,7 +170,7 @@ class conditioning with `class_embed_type` equal to `None`. @register_to_config def __init__( self, - sample_size: Optional[int] = None, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4fd6a43a955a..ac6c8253e432 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -255,7 +255,12 @@ def __init__( is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") + and self._is_unet_config_sample_size_int + and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -902,8 +907,18 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + if not height or not width: + height = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[0] + ) + width = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[1] + ) + height, width = height * self.vae_scale_factor, width * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index f37d598c8387..ccd5567106d2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -840,6 +840,14 @@ def callback_on_step_end(pipe, i, t, callback_kwargs): # they should be the same assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + def test_pipeline_accept_tuple_type_unet_sample_size(self): + # the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size + sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" + sample_size = [60, 80] + customised_unet = UNet2DConditionModel(sample_size=sample_size) + pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet) + assert pipe.unet.config.sample_size == sample_size + @slow @require_torch_gpu From 648d968cfc69074eaf51df3d337100f9805b030e Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:45:45 -0800 Subject: [PATCH 55/88] Enable Gradient Checkpointing for UNet2DModel (New) (#7201) * Port UNet2DModel gradient checkpointing code from #6718. --------- Co-authored-by: Sayak Paul Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen Co-authored-by: Dhruv Nair Co-authored-by: hlky --- src/diffusers/models/unets/unet_2d.py | 6 ++ src/diffusers/models/unets/unet_2d_blocks.py | 83 +++++++++++++++++-- .../versatile_diffusion/modeling_text_unet.py | 29 ++++++- .../test_models_autoencoder_kl.py | 2 +- ..._models_autoencoder_kl_temporal_decoder.py | 2 +- tests/models/test_modeling_common.py | 4 +- tests/models/unets/test_models_unet_2d.py | 42 ++++++++++ 7 files changed, 154 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index d05af686dede..bec62ce5cf45 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): conditioning with `class_embed_type` equal to `None`. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -241,6 +243,10 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index b9d186ac1aa6..b4e0cea7c71d 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -731,12 +731,35 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states, temb=temb) - hidden_states = resnet(hidden_states, temb) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1116,6 +1139,8 @@ def __init__( else: self.downsamplers = None + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.Tensor, @@ -1130,9 +1155,30 @@ def forward( output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, **cross_attention_kwargs) - output_states = output_states + (hidden_states,) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -2354,6 +2400,7 @@ def __init__( else: self.upsamplers = None + self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( @@ -2375,8 +2422,28 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn(hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 107a5a45bfa2..0fd8875a88a1 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2223,12 +2223,35 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states, temb=temb) - hidden_states = resnet(hidden_states, temb) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 52bf5aba204b..c584bdcf56a2 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -146,7 +146,7 @@ def test_enable_disable_slicing(self): ) def test_gradient_checkpointing_is_applied(self): - expected_set = {"Decoder", "Encoder"} + expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_from_pretrained_hub(self): diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py index 4308cb64896e..cf80ff50443e 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py @@ -65,7 +65,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_gradient_checkpointing_is_applied(self): - expected_set = {"Encoder", "TemporalDecoder"} + expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @unittest.skip("Test unsupported.") diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a7594f2ea13f..91a462d5878e 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -803,7 +803,7 @@ def test_enable_disable_gradient_checkpointing(self): self.assertFalse(model.is_gradient_checkpointing) @require_torch_accelerator_with_training - def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5): + def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}): if not self.model_class._supports_gradient_checkpointing: return # Skip test if model does not support gradient checkpointing @@ -850,6 +850,8 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_ for name, param in named_params.items(): if "post_quant_conv" in name: continue + if name in skip: + continue self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)) @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.") diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 5f827f274224..ddf5f53511f7 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -105,6 +105,23 @@ def test_mid_block_attn_groups(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "AttnUpBlock2D", + "AttnDownBlock2D", + "UNetMidBlock2D", + "UpBlock2D", + "DownBlock2D", + } + + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + attention_head_dim = 8 + block_out_channels = (16, 32) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) + class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -220,6 +237,17 @@ def test_output_pretrained(self): self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + attention_head_dim = 32 + block_out_channels = (32, 64) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) + class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -329,3 +357,17 @@ def test_output_pretrained_ve_large(self): def test_forward_with_norm_groups(self): # not required for this model pass + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "UNetMidBlock2D", + } + + block_out_channels = (32, 64, 64, 64) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, block_out_channels=block_out_channels + ) + + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) From 319124847216a57a6ae12b567689aa72b28f1c02 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:48:18 +0000 Subject: [PATCH 56/88] [WIP] SD3.5 IP-Adapter Pipeline Integration (#9987) * Added support for single IPAdapter on SD3.5 pipeline --------- Co-authored-by: hlky Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/attnprocessor.md | 2 + docs/source/en/api/loaders/ip_adapter.md | 6 + docs/source/en/api/loaders/transformer_sd3.md | 29 ++ .../stable_diffusion/stable_diffusion_3.md | 69 ++++- src/diffusers/loaders/__init__.py | 12 +- src/diffusers/loaders/ip_adapter.py | 251 +++++++++++++++++- src/diffusers/loaders/transformer_sd3.py | 89 +++++++ src/diffusers/models/attention_processor.py | 172 ++++++++++++ src/diffusers/models/embeddings.py | 181 +++++++++++++ .../models/transformers/transformer_sd3.py | 16 +- .../pipeline_stable_diffusion_3.py | 129 ++++++++- .../test_pipeline_stable_diffusion_3.py | 2 + 13 files changed, 935 insertions(+), 25 deletions(-) create mode 100644 docs/source/en/api/loaders/transformer_sd3.md create mode 100644 src/diffusers/loaders/transformer_sd3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 27e9fe5e191b..6ac66db73026 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -238,6 +238,8 @@ title: Textual Inversion - local: api/loaders/unet title: UNet + - local: api/loaders/transformer_sd3 + title: SD3Transformer2D - local: api/loaders/peft title: PEFT title: Loaders diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index fee0d7e35764..8bdffc330567 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech [[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0 +[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0 + ## JointAttnProcessor2_0 [[autodoc]] models.attention_processor.JointAttnProcessor2_0 diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index a10f30ef8e5b..946a8b1af875 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] [[autodoc]] loaders.ip_adapter.IPAdapterMixin +## SD3IPAdapterMixin + +[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin + - all + - is_ip_adapter_active + ## IPAdapterMaskProcessor [[autodoc]] image_processor.IPAdapterMaskProcessor \ No newline at end of file diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md new file mode 100644 index 000000000000..4fc9603054b4 --- /dev/null +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -0,0 +1,29 @@ + + +# SD3Transformer2D + +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. + +The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. + + + +To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide. + + + +## SD3Transformer2DLoadersMixin + +[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin + - all + - _load_ip_adapter_weights \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 8170c5280d38..eb67964ab0bd 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -59,9 +59,76 @@ image.save("sd3_hello_world.png") - [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) - [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) +## Image Prompting with IP-Adapters + +An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need: + +- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder. +- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`. +- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. + +IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. + +```python +import torch +from PIL import Image + +from diffusers import StableDiffusion3Pipeline +from transformers import SiglipVisionModel, SiglipImageProcessor + +image_encoder_id = "google/siglip-so400m-patch14-384" +ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter" + +feature_extractor = SiglipImageProcessor.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +) +image_encoder = SiglipVisionModel.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +).to( "cuda") + +pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", + torch_dtype=torch.float16, + feature_extractor=feature_extractor, + image_encoder=image_encoder, +).to("cuda") + +pipe.load_ip_adapter(ip_adapter_id) +pipe.set_ip_adapter_scale(0.6) + +ref_img = Image.open("image.jpg").convert('RGB') + +image = pipe( + width=1024, + height=1024, + prompt="a cat", + negative_prompt="lowres, low quality, worst quality", + num_inference_steps=24, + guidance_scale=5.0, + ip_adapter_image=ref_img +).images[0] + +image.save("result.jpg") +``` + +
+ +
IP-Adapter examples with prompt "a cat"
+
+ + + + +Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work. + + + + ## Memory Optimisations for SD3 -SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. +SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. ### Running Inference with Model Offloading diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6ea382d721de..c7ea0be55db2 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,6 +56,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): @@ -74,7 +75,10 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = ["IPAdapterMixin"] + _import_structure["ip_adapter"] = [ + "IPAdapterMixin", + "SD3IPAdapterMixin", + ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -82,11 +86,15 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): - from .ip_adapter import IPAdapterMixin + from .ip_adapter import ( + IPAdapterMixin, + SD3IPAdapterMixin, + ) from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ca460f948e6f..11ce4f1634d7 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,15 +33,18 @@ if is_transformers_available(): - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel + +from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + JointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, +) - from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, - ) logger = logging.get_logger(__name__) @@ -348,3 +351,235 @@ def unload_ip_adapter(self): else value.__class__() ) self.unet.set_attn_processor(attn_procs) + + +class SD3IPAdapterMixin: + """Mixin for handling StableDiffusion 3 IP Adapters.""" + + @property + def is_ip_adapter_active(self) -> bool: + """Checks if IP-Adapter is loaded and scale > 0. + + IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, + the image context is irrelevant. + + Returns: + `bool`: True when IP-Adapter is loaded and any layer has scale > 0. + """ + scales = [ + attn_proc.scale + for attn_proc in self.transformer.attn_processors.values() + if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) + ] + + return len(scales) > 0 and any(scale > 0 for scale in scales) + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str = "ip-adapter.safetensors", + subfolder: Optional[str] = None, + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ) -> None: + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + weight_name (`str`, defaults to "ip-adapter.safetensors"): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + # Commons args for loading image encoder and image processor + kwargs = { + "low_cpu_mem_usage": low_cpu_mem_usage, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + + self.register_modules( + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + ) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # Load IP-Adapter into transformer + self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: float) -> None: + """ + Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only + conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages + the model to produce more diverse images, but they may not be as aligned with the image prompt. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.set_ip_adapter_scale(0.6) + >>> ... + ``` + + Args: + scale (float): + IP-Adapter scale to be set. + + """ + for attn_processor in self.transformer.attn_processors.values(): + if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + def unload_ip_adapter(self) -> None: + """ + Unloads the IP Adapter weights. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # Remove image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=None) + + # Remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=None) + + # Remove image projection + self.transformer.image_proj = None + + # Restore original attention processors layers + attn_procs = { + name: ( + JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() + ) + for name, value in self.transformer.attn_processors.items() + } + self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py new file mode 100644 index 000000000000..435d1da06ca1 --- /dev/null +++ b/src/diffusers/loaders/transformer_sd3.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ..models.embeddings import IPAdapterTimeImageProjection +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta + + +class SD3Transformer2DLoadersMixin: + """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" + + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None: + """Sets IP-Adapter attention processors, image projection, and loads state_dict. + + Args: + state_dict (`Dict`): + State dict with keys "ip_adapter", which contains parameters for attention processors, and + "image_proj", which contains parameters for image projection net. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor + attn_procs = {} + for idx, name in enumerate(self.attn_processors.keys()): + attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + ) + + self.set_attn_processor(attn_procs) + + # Image projetion parameters + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0] + heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + # Image projection + self.image_proj = IPAdapterTimeImageProjection( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim, + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 05cbaa40e693..ed0dd4f71d27 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5243,6 +5243,177 @@ def __call__( return hidden_states +class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module): + """ + Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with + additional image-based information and timestep embeddings. + + Args: + hidden_size (`int`): + The number of hidden channels. + ip_hidden_states_dim (`int`): + The image feature dimension. + head_dim (`int`): + The number of head channels. + timesteps_emb_dim (`int`, defaults to 1280): + The number of input channels for timestep embedding. + scale (`float`, defaults to 0.5): + IP-Adapter scale. + """ + + def __init__( + self, + hidden_size: int, + ip_hidden_states_dim: int, + head_dim: int, + timesteps_emb_dim: int = 1280, + scale: float = 0.5, + ): + super().__init__() + + # To prevent circular import + from .normalization import AdaLayerNorm, RMSNorm + + self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1) + self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.norm_q = RMSNorm(head_dim, 1e-6) + self.norm_k = RMSNorm(head_dim, 1e-6) + self.norm_ip_k = RMSNorm(head_dim, 1e-6) + self.scale = scale + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + ip_hidden_states: torch.FloatTensor = None, + temb: torch.FloatTensor = None, + ) -> torch.FloatTensor: + """ + Perform the attention computation, integrating image features (if provided) and timestep embeddings. + + If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0. + + Args: + attn (`Attention`): + Attention instance. + hidden_states (`torch.FloatTensor`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + The encoder hidden states. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask. + ip_hidden_states (`torch.FloatTensor`, *optional*): + Image embeddings. + temb (`torch.FloatTensor`, *optional*): + Timestep embeddings. + + Returns: + `torch.FloatTensor`: Output hidden states. + """ + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_query = query + img_key = key + img_value = value + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP Adapter + if self.scale != 0 and ip_hidden_states is not None: + # Norm image features + norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) + + # To k and v + ip_key = self.to_k_ip(norm_ip_hidden_states) + ip_value = self.to_v_ip(norm_ip_hidden_states) + + # Reshape + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Norm + query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + ip_key = self.norm_ip_k(ip_key) + + # cat img + key = torch.cat([img_key, ip_key], dim=2) + value = torch.cat([img_value, ip_value], dim=2) + + ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + ip_hidden_states * self.scale + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -5772,6 +5943,7 @@ def __call__( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, + SD3IPAdapterJointAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, LoRAAttnProcessor, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 69b3ee8466f4..f1b339e6180b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2396,6 +2396,187 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: return out +class IPAdapterTimeImageProjectionBlock(nn.Module): + """Block for IPAdapterTimeImageProjection. + + Args: + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + """ + + def __init__( + self, + hidden_dim: int = 1280, + dim_head: int = 64, + heads: int = 20, + ffn_ratio: int = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.attn = Attention( + query_dim=hidden_dim, + cross_attention_dim=hidden_dim, + dim_head=dim_head, + heads=heads, + bias=False, + out_bias=False, + ) + self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False) + + # AdaLayerNorm + self.adaln_silu = nn.SiLU() + self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) + self.adaln_norm = nn.LayerNorm(hidden_dim) + + # Set attention scale and fuse KV + self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) + self.attn.fuse_projections() + self.attn.to_k = None + self.attn.to_v = None + + def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + latents (`torch.Tensor`): + Latent features. + timestep_emb (`torch.Tensor`): + Timestep embedding. + + Returns: + `torch.Tensor`: Output latent features. + """ + + # Shift and scale for AdaLayerNorm + emb = self.adaln_proj(self.adaln_silu(timestep_emb)) + shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) + + # Fused Attention + residual = latents + x = self.ln0(x) + latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] + + batch_size = latents.shape[0] + + query = self.attn.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.attn.heads + + query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + + weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + latents = weight @ value + + latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim) + latents = self.attn.to_out[0](latents) + latents = self.attn.to_out[1](latents) + latents = latents + residual + + ## FeedForward + residual = latents + latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return self.ff(latents) + residual + + +# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class IPAdapterTimeImageProjection(nn.Module): + """Resampler of SD3 IP-Adapter with timestep embedding. + + Args: + embed_dim (`int`, defaults to 1152): + The feature dimension. + output_dim (`int`, defaults to 2432): + The number of output channels. + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + depth (`int`, defaults to 4): + The number of blocks. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + num_queries (`int`, defaults to 64): + The number of queries. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + timestep_in_dim (`int`, defaults to 320): + The number of input channels for timestep embedding. + timestep_flip_sin_to_cos (`bool`, defaults to True): + Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False). + timestep_freq_shift (`int`, defaults to 0): + Controls the timestep delta between frequencies between dimensions. + """ + + def __init__( + self, + embed_dim: int = 1152, + output_dim: int = 2432, + hidden_dim: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 20, + num_queries: int = 64, + ffn_ratio: int = 4, + timestep_in_dim: int = 320, + timestep_flip_sin_to_cos: bool = True, + timestep_freq_shift: int = 0, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) + self.proj_in = nn.Linear(embed_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + self.layers = nn.ModuleList( + [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) + self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") + + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + timestep (`torch.Tensor`): + Timestep in denoising process. + Returns: + `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb). + """ + timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) + timestep_emb = self.time_embedding(timestep_emb) + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + x = x + timestep_emb[:, None] + + for block in self.layers: + latents = block(x, latents, timestep_emb) + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + return latents, timestep_emb + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79c4069e9a37..415540ef7f6a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, @@ -103,7 +103,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): return hidden_states -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin +): """ The Transformer model introduced in Stable Diffusion 3. @@ -349,8 +351,8 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states (`list` of `torch.Tensor`): @@ -390,6 +392,12 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) + + joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb) + for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers is_skip = True if skip_layers is not None and index_block in skip_layers else False diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 0a51dcbc1261..a53d786798ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,14 +17,16 @@ import torch from transformers import ( + BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + PreTrainedModel, T5EncoderModel, T5TokenizerFast, ) -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,7 +144,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): r""" Args: transformer ([`SD3Transformer2DModel`]): @@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`PreTrainedModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`BaseImageProcessor`, *optional*): + Image processor for IP Adapter. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( @@ -191,6 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, + image_encoder: PreTrainedModel = None, + feature_extractor: BaseImageProcessor = None, ): super().__init__() @@ -204,6 +212,8 @@ def __init__( tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -683,6 +693,83 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -705,6 +792,8 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -713,9 +802,9 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, skip_guidance_layers: List[int] = None, - skip_layer_guidance_scale: int = 2.8, - skip_layer_guidance_stop: int = 0.2, - skip_layer_guidance_start: int = 0.01, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_start: float = 0.01, mu: Optional[float] = None, ): r""" @@ -781,6 +870,11 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -938,7 +1032,22 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 6. Denoising loop + # 6. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 07ce5487f256..a6f718ae4fbb 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -103,6 +103,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): From 41ba8c0bf6b3dc3ebd0fa6b96ecf671fa4171566 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Dec 2024 07:12:20 +0530 Subject: [PATCH 57/88] Add support for sharded models when TorchAO quantization is enabled (#10256) * add sharded + device_map check --- src/diffusers/models/modeling_utils.py | 2 +- tests/quantization/torchao/test_torchao.py | 70 +++++++++++++++------- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0f9c9203c926..872d4d73d41f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -802,7 +802,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) - if hf_quantizer is not None: + if hf_quantizer is not None and is_bnb_quantization_method: model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 58c1d3613daf..6f9980c006ac 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self): self.assertEqual(weight.quant_max, 15) self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) - def test_offload(self): + def test_device_map(self): """ - Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies - that the device map is correctly set (in the `hf_device_map` attribute of the model). + Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps. + The custom device map performs cpu/disk offloading as well. Also verifies that the device map is + correctly set (in the `hf_device_map` attribute of the model). """ - device_map_offload = { + custom_device_map_dict = { "time_text_embed": torch_device, "context_embedder": torch_device, "x_embedder": torch_device, @@ -293,27 +294,50 @@ def test_offload(self): "norm_out": torch_device, "proj_out": "cpu", } + device_maps = ["auto", custom_device_map_dict] inputs = self.get_dummy_tensor_inputs(torch_device) - - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map_offload, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) - - self.assertTrue(quantized_model.hf_device_map == device_map_offload) - - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - - expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + + for device_map in device_maps: + device_map_to_compare = {"": 0} if device_map == "auto" else device_map + + # Test non-sharded model + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + # Test sharded model + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-sharded", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) From 151b74cd7758df590c523230a86230ba3bbc786f Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Dec 2024 11:45:37 +0530 Subject: [PATCH 58/88] Make tensors in ResNet contiguous for Hunyuan VAE (#10309) contiguous tensors in resnet Co-authored-by: YiYi Xu --- .../models/autoencoders/autoencoder_kl_hunyuan_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 5c1d94d4e18f..e2236a7f20ad 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -168,6 +168,7 @@ def __init__( self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.contiguous() residual = hidden_states hidden_states = self.norm1(hidden_states) From dbc1d505f018807089ea0da575f40ba22e8b4709 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Dec 2024 11:52:29 +0530 Subject: [PATCH 59/88] [Single File] Add GGUF support for LTX (#10298) * update * add docs. --------- Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/ltx_video.md | 39 ++++++++++++++++++++++ src/diffusers/loaders/single_file_utils.py | 15 ++++----- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index ac2b1c95b5b1..211cd3007d1e 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -61,6 +61,45 @@ pipe = LTXImageToVideoPipeline.from_single_file( ) ``` +Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported: + +```py +import torch +from diffusers.utils import export_to_video +from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig + +ckpt_path = ( + "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf" +) +transformer = LTXVideoTransformer3DModel.from_single_file( + ckpt_path, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16, +) +pipe = LTXPipeline.from_pretrained( + "Lightricks/LTX-Video", + transformer=transformer, + generator=torch.manual_seed(0), + torch_dtype=torch.bfloat16, +) +pipe.enable_model_cpu_offload() + +prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_frames=161, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output_gguf_ltx.mp4", fps=24) +``` + +Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support. + Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption. ## LTXPipeline diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 8b2bf12214cd..f1408c2c409b 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -99,10 +99,11 @@ "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", ], "ltx-video": [ - ( - "model.diffusion_model.patchify_proj.weight", - "model.diffusion_model.transformer_blocks.27.scale_shift_table", - ), + "model.diffusion_model.patchify_proj.weight", + "model.diffusion_model.transformer_blocks.27.scale_shift_table", + "patchify_proj.weight", + "transformer_blocks.27.scale_shift_table", + "vae.per_channel_statistics.mean-of-means", ], "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", @@ -601,7 +602,7 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "flux-schnell" - elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]): + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): model_type = "ltx-video" elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: @@ -2266,9 +2267,7 @@ def swap_scale_shift(weight): def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = { - key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key - } + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} TRANSFORMER_KEYS_RENAME_DICT = { "model.diffusion_model.": "", From 17128c42a4c7c0234f615b3e52b41ac0d1f70a58 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Dec 2024 14:30:32 +0530 Subject: [PATCH 60/88] [LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * lora expansion with dummy zeros. * updates * fix working 🥳 * working. * use torch.device meta for state dict expansion. * tests Co-authored-by: a-r-r-o-w * fixes * fixes * switch to debug * fix * Apply suggestions from code review Co-authored-by: Aryan * fix stuff * docs --------- Co-authored-by: a-r-r-o-w Co-authored-by: Aryan --- docs/source/en/api/pipelines/flux.md | 37 ++++++ src/diffusers/loaders/lora_pipeline.py | 137 ++++++++++++++++------- tests/lora/test_lora_layers_flux.py | 149 ++++++++++++++++++------- 3 files changed, 239 insertions(+), 84 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index af9c3639e047..080442efb0d1 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -268,6 +268,43 @@ images = pipe( images[0].save("flux-redux.png") ``` +## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux + +We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD). + +```py +from diffusers import FluxControlPipeline +from image_gen_aux import DepthPreprocessor +from diffusers.utils import load_image +from huggingface_hub import hf_hub_download +import torch + +control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) +control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth") +control_pipe.load_lora_weights( + hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" +) +control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125]) +control_pipe.enable_model_cpu_offload() + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = control_pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=8, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + ## Running FP16 inference Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 46d744233014..e69681611a4a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1863,6 +1863,9 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging." ) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict + ) if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( @@ -2309,16 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False - + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data module_bias = module.bias.data if module.bias is not None else None bias = module_bias is not None - lora_A_weight_name = f"{name}.lora_A.weight" - lora_B_weight_name = f"{name}.lora_B.weight" - if lora_A_weight_name not in state_dict.keys(): + lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name + lora_A_weight_name = f"{lora_base_name}.lora_A.weight" + lora_B_weight_name = f"{lora_base_name}.lora_B.weight" + if lora_A_weight_name not in state_dict: continue in_features = state_dict[lora_A_weight_name].shape[1] @@ -2329,56 +2333,105 @@ def _maybe_expand_transformer_param_shape_or_error_( continue module_out_features, module_in_features = module_weight.shape - if out_features < module_out_features or in_features < module_in_features: - raise NotImplementedError( - f"Only LoRAs with input/output features higher than the current module's input/output features " - f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " - f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " - f"this please open an issue at https://github.com/huggingface/diffusers/issues." + debug_message = "" + if in_features > module_in_features: + debug_message += ( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}" ) - - debug_message = ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if module_out_features != out_features: + if out_features > module_out_features: debug_message += ( ", and the number of output features will be " f"expanded from {module_out_features} to {out_features}." ) else: debug_message += "." - logger.debug(debug_message) + if debug_message: + logger.debug(debug_message) + + if out_features > module_out_features or in_features > module_in_features: + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + with torch.device("meta"): + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. This is because only the input dimensions + # are changed while the output dimensions remain the same. The shape of the weight tensor + # is (out_features, in_features), while the shape of bias tensor is (out_features,), which + # explains the reason why only weights are expanded. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + tmp_state_dict = {"weight": new_weight} + if module_bias is not None: + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) + + setattr(parent_module, current_module_name, expanded_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) + return has_param_with_shape_update - # TODO: consider initializing this under meta device for optims. - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. - new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + @classmethod + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + for k in lora_module_names: + if k in unexpected_modules: + continue + + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight" + ) + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] + + if base_weight_param.shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + elif base_weight_param.shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." ) - slices = tuple(slice(0, dim) for dim in module_weight.shape) - new_weight[slices] = module_weight - expanded_module.weight.data.copy_(new_weight) - if module_bias is not None: - expanded_module.bias.data.copy_(module_bias) - - setattr(parent_module, current_module_name, expanded_module) - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + if expanded_module_names: + logger.info( + f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + ) - return has_param_with_shape_update + return lora_state_dict # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b28fdde91574..1378c048b868 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -340,21 +340,6 @@ def test_lora_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - dummy_lora_A = torch.nn.Linear(1, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - # We should error out because lora input features is less than original. We only - # support expanding the module, not shrinking it - with self.assertRaises(NotImplementedError): - pipe.load_lora_weights(lora_state_dict, "adapter-1") - @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -430,10 +415,10 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - def test_lora_expanding_shape_with_normal_lora_raises_error(self): - # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but - # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error. - # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180 + def test_lora_expanding_shape_with_normal_lora(self): + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. The opposite direction isn't supported and is + # tested with it. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. @@ -478,27 +463,18 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct - # input features before expansion. This should raise an error about the weight shapes being incompatible. - self.assertRaisesRegex( - RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", - pipe.load_lora_weights, - lora_state_dict, - "adapter-2", - ) - # We should have `adapter-1` as the only adapter. - self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-2") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) - # Check if the output is the same after lora loading error - lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3)) + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. - # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the - # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora - # weight is compatible with the current model inadequate. This should be addressed when attempting support for - # https://github.com/huggingface/diffusers/issues/10180 (TODO) + # This should raise a runtime error on input shapes being incompatible. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. num_channels_without_control = 4 @@ -521,14 +497,11 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } + pipe.load_lora_weights(lora_state_dict, "adapter-1") - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) self.assertTrue(pipe.transformer.config.in_channels == in_features) - self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, @@ -546,6 +519,98 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "adapter-2", ) + def test_fuse_expanded_lora_with_regular_lora(self): + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. The opposite direction isn't supported and is + # tested with it. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + pipe.load_lora_weights(lora_state_dict, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) + lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) + + pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) + lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) + + def test_load_regular_lora(self): + # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded + # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those + # transformers include Flux Fill, Flux Control, etc. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) + self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From bf6eaa8aec3f70d398015d5a2d43ea4984c78555 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Dec 2024 16:14:58 +0530 Subject: [PATCH 61/88] [Tests] add integration tests for lora expansion stuff in Flux. (#10318) add integration tests for lora expansion stuff in Flux. --- tests/lora/test_lora_layers_flux.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 1378c048b868..10ea2de5ef88 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -825,3 +825,40 @@ def test_lora(self, lora_ckpt_id): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 1e-3 + + @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) + def test_lora_with_turbo(self, lora_ckpt_id): + self.pipeline.load_lora_weights(lora_ckpt_id) + self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) + else: + expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 From e12d610faacfe69f2de28b5a6e67fcd1501367b2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Dec 2024 16:27:38 +0530 Subject: [PATCH 62/88] Mochi docs (#9934) * update * update * update * update * update --------- Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/mochi.md | 197 +++++++++++++++++++++++++- 1 file changed, 196 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md index f29297e5901c..4da53a53662e 100644 --- a/docs/source/en/api/pipelines/mochi.md +++ b/docs/source/en/api/pipelines/mochi.md @@ -13,7 +13,7 @@ # limitations under the License. --> -# Mochi +# Mochi 1 Preview [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo. @@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
+## Generating videos with Mochi-1 Preview + +The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run. + +```python +import torch +from diffusers import MochiPipeline +from diffusers.utils import export_to_video + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") + +# Enable memory savings +pipe.enable_model_cpu_offload() +pipe.enable_vae_tiling() + +prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." + +with torch.autocast("cuda", torch.bfloat16, cache_enabled=False): + frames = pipe(prompt, num_frames=85).frames[0] + +export_to_video(frames, "mochi.mp4", fps=30) +``` + +## Using a lower precision variant to save memory + +The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result. + +```python +import torch +from diffusers import MochiPipeline +from diffusers.utils import export_to_video + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16) + +# Enable memory savings +pipe.enable_model_cpu_offload() +pipe.enable_vae_tiling() + +prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." +frames = pipe(prompt, num_frames=85).frames[0] + +export_to_video(frames, "mochi.mp4", fps=30) +``` + +## Reproducing the results from the Genmo Mochi repo + +The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example. + + +The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder. + +When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision. + + + +Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`. + + +```python +import torch +from torch.nn.attention import SDPBackend, sdpa_kernel + +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +from diffusers.video_processor import VideoProcessor + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True) +pipe.enable_vae_tiling() +pipe.enable_model_cpu_offload() + +prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape." + +with torch.no_grad(): + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + pipe.encode_prompt(prompt=prompt) + ) + +with torch.autocast("cuda", torch.bfloat16): + with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): + frames = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + guidance_scale=4.5, + num_inference_steps=64, + height=480, + width=848, + num_frames=163, + generator=torch.Generator("cuda").manual_seed(0), + output_type="latent", + return_dict=False, + )[0] + +video_processor = VideoProcessor(vae_scale_factor=8) +has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None +has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None +if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) + ) + latents_std = ( + torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) + ) + frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean +else: + frames = frames / pipe.vae.config.scaling_factor + +with torch.no_grad(): + video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0] + +video = video_processor.postprocess_video(video)[0] +export_to_video(video, "mochi.mp4", fps=30) +``` + +## Running inference with multiple GPUs + +It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM. + +```python +import torch +from diffusers import MochiPipeline, MochiTransformer3DModel +from diffusers.utils import export_to_video + +model_id = "genmo/mochi-1-preview" +transformer = MochiTransformer3DModel.from_pretrained( + model_id, + subfolder="transformer", + device_map="auto", + max_memory={0: "24GB", 1: "24GB"} +) + +pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer) +pipe.enable_model_cpu_offload() +pipe.enable_vae_tiling() + +with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False): + frames = pipe( + prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.", + negative_prompt="", + height=480, + width=848, + num_frames=85, + num_inference_steps=50, + guidance_scale=4.5, + num_videos_per_prompt=1, + generator=torch.Generator(device="cuda").manual_seed(0), + max_sequence_length=256, + output_type="pil", + ).frames[0] + +export_to_video(frames, "output.mp4", fps=30) +``` + +## Using single file loading with the Mochi Transformer + +You can use `from_single_file` to load the Mochi transformer in its original format. + + +Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints. + + +```python +import torch +from diffusers import MochiPipeline, MochiTransformer3DModel +from diffusers.utils import export_to_video + +model_id = "genmo/mochi-1-preview" + +ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors" + +transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16) + +pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer) +pipe.enable_model_cpu_offload() +pipe.enable_vae_tiling() + +with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False): + frames = pipe( + prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.", + negative_prompt="", + height=480, + width=848, + num_frames=85, + num_inference_steps=50, + guidance_scale=4.5, + num_videos_per_prompt=1, + generator=torch.Generator(device="cuda").manual_seed(0), + max_sequence_length=256, + output_type="pil", + ).frames[0] + +export_to_video(frames, "output.mp4", fps=30) +``` + ## MochiPipeline [[autodoc]] MochiPipeline From b64ca6c11cbc1644c22f1dae441c8124d588bb14 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Dec 2024 18:32:22 +0530 Subject: [PATCH 63/88] [Docs] Update ltx_video.md to remove generator from `from_pretrained()` (#10316) Update ltx_video.md to remove generator from `from_pretrained()` --- docs/source/en/api/pipelines/ltx_video.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 211cd3007d1e..a925b848706e 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -79,7 +79,6 @@ transformer = LTXVideoTransformer3DModel.from_single_file( pipe = LTXPipeline.from_pretrained( "Lightricks/LTX-Video", transformer=transformer, - generator=torch.manual_seed(0), torch_dtype=torch.bfloat16, ) pipe.enable_model_cpu_offload() From c8ee4af22843faa4fe79f24747012c8f133894e4 Mon Sep 17 00:00:00 2001 From: Leojc Date: Fri, 20 Dec 2024 23:22:32 +0800 Subject: [PATCH 64/88] docs: fix a mistake in docstring (#10319) Update pipeline_hunyuan_video.py docs: fix a mistake --- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 4423ccf97932..6e0541e938ba 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -143,7 +143,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): Args: text_encoder ([`LlamaModel`]): [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). - tokenizer_2 (`LlamaTokenizer`): + tokenizer (`LlamaTokenizer`): Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). transformer ([`HunyuanVideoTransformer3DModel`]): Conditional Transformer to denoise the encoded image latents. From 902008608ad5ab687056b38d5b4c35284228fd88 Mon Sep 17 00:00:00 2001 From: Aditya Raj Date: Fri, 20 Dec 2024 20:59:58 +0530 Subject: [PATCH 65/88] [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float" torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor. in function prepare_latents: audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) ... audio = initial_audio_waveforms.new_zeros(audio_shape) audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float Co-authored-by: hlky --- src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index cef63cf7e63d..5d773b614a5c 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -446,7 +446,7 @@ def prepare_latents( f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" ) - audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) # check num_channels From 7d4db57037b9504c240078768ce95ff6588a92bd Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Fri, 20 Dec 2024 08:30:21 -0800 Subject: [PATCH 66/88] [docs] Fix quantization links (#10323) Update overview.md --- docs/source/en/quantization/overview.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 3eef5238f1ce..794098e210a6 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? Diffusers currently supports the following quantization methods. -- [BitsandBytes](./bitsandbytes.md) -- [TorchAO](./torchao.md) -- [GGUF](./gguf.md) +- [BitsandBytes](./bitsandbytes) +- [TorchAO](./torchao) +- [GGUF](./gguf) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. From a6288a5571dbc63a03dc761a4d5300fcec61a04b Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Sat, 21 Dec 2024 01:21:34 +0800 Subject: [PATCH 67/88] [Sana]add 2K related model for Sana (#10322) add 2K related model for Sana --- scripts/convert_sana_to_diffusers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index c1045a98a51a..dc553681678b 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -25,6 +25,7 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext ckpt_ids = [ + "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth", "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth", @@ -265,9 +266,9 @@ def main(args): "--image_size", default=1024, type=int, - choices=[512, 1024], + choices=[512, 1024, 2048], required=False, - help="Image size of pretrained model, 512 or 1024.", + help="Image size of pretrained model, 512, 1024 or 2048.", ) parser.add_argument( "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] From d41388145e7fa7fac5e75047bcbd19eb9276cb64 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 21 Dec 2024 07:15:03 +0530 Subject: [PATCH 68/88] [Docs] Update gguf.md to remove generator from the pipeline from_pretrained (#10299) Update gguf.md to remove generator from the pipeline from_pretrained --- docs/source/en/quantization/gguf.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md index 2ff2a9293130..f7537d7e7882 100644 --- a/docs/source/en/quantization/gguf.md +++ b/docs/source/en/quantization/gguf.md @@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file( pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", transformer=transformer, - generator=torch.manual_seed(0), torch_dtype=torch.bfloat16, ) pipe.enable_model_cpu_offload() prompt = "A cat holding a sign that says hello world" -image = pipe(prompt).images[0] +image = pipe(prompt, generator=torch.manual_seed(0)).images[0] image.save("flux-gguf.png") ``` From a756694bf0f4d2a1bba770586bcb7670235d296a Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 14:10:32 +0000 Subject: [PATCH 69/88] Fix push_tests_mps.yml (#10326) --- .github/workflows/push_tests_mps.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 8d521074a08f..5fd3b78be7df 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -46,7 +46,7 @@ jobs: shell: arch -arch arm64 bash {0} run: | ${CONDA_RUN} python -m pip install --upgrade pip uv - ${CONDA_RUN} python -m uv pip install -e [quality,test] + ${CONDA_RUN} python -m uv pip install -e ".[quality,test]" ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git ${CONDA_RUN} python -m uv pip install transformers --upgrade From bf9a641f1a51368af5f3ae99cc460107d4fa2103 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 14:10:44 +0000 Subject: [PATCH 70/88] Fix EMAModel test_from_pretrained (#10325) --- tests/others/test_ema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 3443e6366f01..7cf8f30ecc44 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -67,6 +67,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): @@ -221,6 +222,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): From be2070991f1b916977020c45ecdfec225de21862 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 17:49:58 +0000 Subject: [PATCH 71/88] Support Flux IP Adapter (#10261) * Flux IP-Adapter * test cfg * make style * temp remove copied from * fix test * fix test * v2 * fix * make style * temp remove copied from * Apply suggestions from code review Co-authored-by: YiYi Xu * Move encoder_hid_proj to inside FluxTransformer2DModel * merge * separate encode_prompt, add copied from, image_encoder offload * make * fix test * fix * Update src/diffusers/pipelines/flux/pipeline_flux.py * test_flux_prompt_embeds change not needed * true_cfg -> true_cfg_scale * fix merge conflict * test_flux_ip_adapter_inference * add fast test * FluxIPAdapterMixin not test mixin * Update pipeline_flux.py Co-authored-by: YiYi Xu --------- Co-authored-by: YiYi Xu --- ...nvert_flux_xlabs_ipadapter_to_diffusers.py | 97 ++++++ src/diffusers/loaders/__init__.py | 5 +- src/diffusers/loaders/ip_adapter.py | 286 ++++++++++++++++++ src/diffusers/loaders/transformer_flux.py | 179 +++++++++++ src/diffusers/models/attention_processor.py | 146 ++++++++- src/diffusers/models/embeddings.py | 2 +- .../models/transformers/transformer_flux.py | 20 +- src/diffusers/pipelines/flux/pipeline_flux.py | 178 ++++++++++- .../pipelines/flux/pipeline_flux_control.py | 1 - .../test_models_transformer_flux.py | 52 ++++ tests/pipelines/flux/test_pipeline_flux.py | 114 ++++++- tests/pipelines/test_pipelines_common.py | 91 +++++- 12 files changed, 1157 insertions(+), 14 deletions(-) create mode 100644 scripts/convert_flux_xlabs_ipadapter_to_diffusers.py create mode 100644 src/diffusers/loaders/transformer_flux.py diff --git a/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py new file mode 100644 index 000000000000..b701b7fb40b1 --- /dev/null +++ b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py @@ -0,0 +1,97 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available + + +if is_transformers_available(): + from transformers import CLIPVisionModelWithProjection + + vision = True +else: + vision = False + +""" +python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \ +--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \ +--filename "flux-ip-adapter.safetensors" +--output_path "flux-ip-adapter-hf/" +""" + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) +parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str) + +args = parser.parse_args() + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers): + converted_state_dict = {} + + # image_proj + ## norm + converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight") + converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias") + ## proj + converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight") + converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"ip_adapter.{i}." + # to_k_ip + converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias" + ) + converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight" + ) + # to_v_ip + converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias" + ) + converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight" + ) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + num_layers = 19 + converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers) + + print("Saving Flux IP-Adapter in Diffusers format.") + safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors") + + if vision: + model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path) + model.save_pretrained(f"{args.output_path}/image_encoder") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index c7ea0be55db2..2db8b53db498 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] - + _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ "IPAdapterMixin", + "FluxIPAdapterMixin", "SD3IPAdapterMixin", ] @@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_flux import FluxTransformer2DLoadersMixin from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): from .ip_adapter import ( + FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin, ) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 11ce4f1634d7..7b691d1fe16e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -38,6 +38,8 @@ from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, + FluxAttnProcessor2_0, + FluxIPAdapterJointAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, @@ -353,6 +355,290 @@ def unload_ip_adapter(self): self.unet.set_attn_processor(attn_procs) +class FluxIPAdapterMixin: + """Mixin for handling Flux IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + weight_name: Union[str, List[str]], + subfolder: Optional[Union[str, List[str]]] = "", + image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder", + image_encoder_subfolder: Optional[str] = "", + image_encoder_dtype: torch.dtype = torch.float16, + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. If a list is passed, it should have the same length as + `weight_name`. + image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`): + Can be either: + + - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + image_proj_keys = ["ip_adapter_proj_model.", "image_proj."] + ip_adapter_keys = ["double_blocks.", "ip_adapter."] + for key in f.keys(): + if any(key.startswith(prefix) for prefix in image_proj_keys): + diffusers_name = ".".join(key.split(".")[1:]) + state_dict["image_proj"][diffusers_name] = f.get_tensor(key) + elif any(key.startswith(prefix) for prefix in ip_adapter_keys): + diffusers_name = ( + ".".join(key.split(".")[1:]) + .replace("ip_adapter_double_stream_k_proj", "to_k_ip") + .replace("ip_adapter_double_stream_v_proj", "to_v_ip") + .replace("processor.", "") + ) + state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + state_dicts.append(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_pretrained_model_name_or_path is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}") + image_encoder = ( + CLIPVisionModelWithProjection.from_pretrained( + image_encoder_pretrained_model_name_or_path, + subfolder=image_encoder_subfolder, + low_cpu_mem_usage=low_cpu_mem_usage, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + .to(self.device, dtype=image_encoder_dtype) + .eval() + ) + self.register_modules(image_encoder=image_encoder) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 + default_clip_size = 224 + clip_image_size = ( + self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size + ) + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) + self.register_modules(feature_extractor=feature_extractor) + + # load ip-adapter into transformer + self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a list. + + `float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]` + length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the + number of IP adapters and each must match the number of blocks. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + + def LinearStrengthModel(start, finish, size): + return [(start + (finish - start) * (i / (size - 1))) for i in range(size)] + + + ip_strengths = LinearStrengthModel(0.3, 0.92, 19) + pipeline.set_ip_adapter_scale(ip_strengths) + ``` + """ + transformer = self.transformer + if not isinstance(scale, list): + scale = [[scale] * transformer.config.num_layers] + elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float): + if len(scale) != transformer.config.num_layers: + raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.") + scale = [scale] + + scale_configs = scale + + key_id = 0 + for attn_name, attn_processor in transformer.attn_processors.items(): + if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." + ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + attn_processor.scale[i] = scale_config[key_id] + key_id += 1 + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # remove CLIP image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=[None, None]) + + # remove feature extractor only when safety_checker is None as safety_checker uses + # the feature_extractor later + if not hasattr(self, "safety_checker"): + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=[None, None]) + + # remove hidden encoder + self.transformer.encoder_hid_proj = None + self.transformer.config.encoder_hid_dim_type = None + + # restore original Transformer attention processors layers + attn_procs = {} + for name, value in self.transformer.attn_processors.items(): + attn_processor_class = FluxAttnProcessor2_0() + attn_procs[name] = ( + attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() + ) + self.transformer.set_attn_processor(attn_procs) + + class SD3IPAdapterMixin: """Mixin for handling StableDiffusion 3 IP Adapters.""" diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py new file mode 100644 index 000000000000..52a48e56e748 --- /dev/null +++ b/src/diffusers/loaders/transformer_flux.py @@ -0,0 +1,179 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext + +from ..models.embeddings import ( + ImageProjection, + MultiIPAdapterImageProjection, +) +from ..models.modeling_utils import load_model_dict_into_meta +from ..utils import ( + is_accelerate_available, + is_torch_version, + logging, +) + + +if is_accelerate_available(): + pass + +logger = logging.get_logger(__name__) + + +class FluxTransformer2DLoadersMixin: + """ + Load layers into a [`FluxTransformer2DModel`]. + """ + + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False): + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + updated_state_dict = {} + image_projection = None + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + + if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + if state_dict["proj.weight"].shape[0] == 65536: + num_image_text_embeds = 16 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + if not low_cpu_mem_usage: + image_projection.load_state_dict(updated_state_dict, strict=True) + else: + load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + + return image_projection + + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): + from ..models.attention_processor import ( + FluxIPAdapterJointAttnProcessor2_0, + ) + + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 0 + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + for name in self.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + attn_processor_class = self.attn_processors[name].__class__ + attn_procs[name] = attn_processor_class() + else: + cross_attention_dim = self.config.joint_attention_dim + hidden_size = self.inner_dim + attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 + num_image_text_embeds = [] + for state_dict in state_dicts: + if "proj.weight" in state_dict["image_proj"]: + num_image_text_embed = 4 + if state_dict["image_proj"]["proj.weight"].shape[0] == 65536: + num_image_text_embed = 16 + # IP-Adapter + num_image_text_embeds += [num_image_text_embed] + + with init_context(): + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + dtype=self.dtype, + device=self.device, + ) + + value_dict = {} + for i, state_dict in enumerate(state_dicts): + value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) + value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]}) + value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]}) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(value_dict) + else: + device = self.device + dtype = self.dtype + load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + + key_id += 1 + + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ed0dd4f71d27..6e1dc1037c20 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -575,7 +575,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] @@ -2653,6 +2653,149 @@ def __call__( return hidden_states +class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + hidden_states_query_proj = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP-adapter + ip_query = hidden_states_query_proj + ip_attn_output = None + # for ip-adapter + # TODO: support for multiple adapters + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_attn_output = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_attn_output = scale * ip_attn_output + ip_attn_output = ip_attn_output.to(ip_query.dtype) + + return hidden_states, encoder_hidden_states, ip_attn_output + else: + return hidden_states + + class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -5896,6 +6039,7 @@ def __call__( SlicedAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + FluxIPAdapterJointAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f1b339e6180b..4558d48edad9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1535,7 +1535,7 @@ def forward(self, image_embeds: torch.Tensor): batch_size = image_embeds.shape[0] # image - image_embeds = self.image_embeds(image_embeds) + image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype)) image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) image_embeds = self.norm(image_embeds) return image_embeds diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8dbe49b75076..dc2eb26f9d30 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward from ...models.attention_processor import ( Attention, @@ -177,13 +177,18 @@ def forward( ) joint_attention_kwargs = joint_attention_kwargs or {} # Attention. - attn_output, context_attn_output = self.attn( + attention_outputs = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + # Process attention outputs for the `hidden_states`. attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output @@ -195,6 +200,8 @@ def forward( ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output # Process attention outputs for the `encoder_hidden_states`. @@ -212,7 +219,9 @@ def forward( return encoder_hidden_states, hidden_states -class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class FluxTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin +): """ The Transformer model introduced in Flux. @@ -482,6 +491,11 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ec2801625552..181f0269ce3e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -17,10 +17,17 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) -from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,6 +149,7 @@ class FluxPipeline( FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, + FluxIPAdapterMixin, ): r""" The Flux pipeline for text-to-image generation. @@ -169,8 +177,8 @@ class FluxPipeline( [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -182,6 +190,8 @@ def __init__( text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, ): super().__init__() @@ -193,6 +203,8 @@ def __init__( tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -377,14 +389,60 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + def check_inputs( self, prompt, prompt_2, height, width, + negative_prompt=None, + negative_prompt_2=None, prompt_embeds=None, + negative_prompt_embeds=None, pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -419,10 +477,33 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @@ -551,6 +632,9 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, @@ -561,6 +645,12 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -610,6 +700,17 @@ def __call__( pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -647,8 +748,12 @@ def __call__( prompt_2, height, width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -670,6 +775,7 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None ( prompt_embeds, pooled_prompt_embeds, @@ -684,6 +790,21 @@ def __call__( max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -725,12 +846,43 @@ def __call__( else: guidance = None + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -746,6 +898,22 @@ def __call__( return_dict=False, )[0] + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index dc3ca8cf7b09..ac8474becb78 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -403,7 +403,6 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs def check_inputs( self, prompt, diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 4a784eee4732..c88b3dac8216 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -18,6 +18,8 @@ import torch from diffusers import FluxTransformer2DModel +from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 +from diffusers.models.embeddings import ImageProjection from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,6 +28,56 @@ enable_full_determinism() +def create_flux_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + # "image_proj" (ImageProjection layer weights) + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=model.config["pooled_projection_dim"], + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel main_input_name = "hidden_states" diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index df9021ee0adb..7981e6c2a93b 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -16,13 +16,14 @@ ) from ..test_pipelines_common import ( + FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): +class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -91,6 +92,8 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): @@ -296,3 +299,112 @@ def test_flux_inference(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4 + + +@slow +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxIPAdapterPipelineSlowTests(unittest.TestCase): + pipeline_class = FluxPipeline + repo_id = "black-forest-labs/FLUX.1-dev" + image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14" + weight_name = "ip_adapter.safetensors" + ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + prompt_embeds = torch.load( + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ) + pooled_prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" + ) + ) + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8) + return { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds, + "ip_adapter_image": ip_adapter_image, + "num_inference_steps": 2, + "guidance_scale": 3.5, + "true_cfg_scale": 4.0, + "max_sequence_length": 256, + "output_type": "np", + "generator": generator, + } + + def test_flux_ip_adapter_inference(self): + pipe = self.pipeline_class.from_pretrained( + self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None + ) + pipe.load_ip_adapter( + self.ip_adapter_repo_id, + weight_name=self.weight_name, + image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path, + ) + pipe.set_ip_adapter_scale(1.0) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + + expected_slice = np.array( + [ + 0.1855, + 0.1680, + 0.1406, + 0.1953, + 0.1699, + 0.1465, + 0.2012, + 0.1738, + 0.1484, + 0.2051, + 0.1797, + 0.1523, + 0.2012, + 0.1719, + 0.1445, + 0.2070, + 0.1777, + 0.1465, + 0.2090, + 0.1836, + 0.1484, + 0.2129, + 0.1875, + 0.1523, + 0.2090, + 0.1816, + 0.1484, + 0.2110, + 0.1836, + 0.1543, + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4, f"{image_slice} != {expected_slice}" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 4d2b534c9a28..764be1890cc5 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -29,7 +29,7 @@ UNet2DConditionModel, ) from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import IPAdapterMixin +from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel @@ -54,6 +54,7 @@ get_autoencoder_tiny_config, get_consistency_vae_config, ) +from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict from ..models.unets.test_models_unet_2d_condition import ( create_ip_adapter_faceid_state_dict, create_ip_adapter_state_dict, @@ -483,6 +484,94 @@ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4): ) +class FluxIPAdapterTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for pipelines that support IP Adapters. + """ + + def test_pipeline_signature(self): + parameters = inspect.signature(self.pipeline_class.__call__).parameters + + assert issubclass(self.pipeline_class, FluxIPAdapterMixin) + self.assertIn( + "ip_adapter_image", + parameters, + "`ip_adapter_image` argument must be supported by the `__call__` method", + ) + self.assertIn( + "ip_adapter_image_embeds", + parameters, + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method", + ) + + def _get_dummy_image_embeds(self, image_embed_dim: int = 768): + return torch.randn((1, 1, image_embed_dim), device=torch_device) + + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): + inputs["negative_prompt"] = "" + inputs["true_cfg_scale"] = 4.0 + inputs["output_type"] = "np" + inputs["return_dict"] = False + return inputs + + def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for IP-Adapter. + + The following scenarios are tested: + - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + pipe.set_progress_bar_config(disable=None) + image_embed_dim = pipe.transformer.config.pooled_projection_dim + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + if expected_pipe_slice is None: + output_without_adapter = pipe(**inputs)[0] + else: + output_without_adapter = expected_pipe_slice + + adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer) + pipe.transformer._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" + ) + + class PipelineLatentTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. From 233dffdc3f56b26abaaba8363a5dd30dab7f0e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mehmet=20Yi=C4=9Fit=20=C3=96zgen=C3=A7?= <47952284+yigitozgenc@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:44:43 +0300 Subject: [PATCH 72/88] flux controlnet inpaint config bug (#10291) * flux controlnet inpaint config bug * Update src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py --------- Co-authored-by: yigitozgenc Co-authored-by: hlky --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index c557cf134b05..85943b278dc6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1095,7 +1095,11 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) # predict the noise residual - if self.controlnet.config.guidance_embeds: + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + if use_guidance: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) else: From 6aaa0518e3d1e8de2b1dc1368e0daa4d1044db94 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 06:56:28 +0530 Subject: [PATCH 73/88] Community hosted weights for diffusers format HunyuanVideo weights (#10344) update docs and example to use community weights --- docs/source/en/api/models/autoencoder_kl_hunyuan_video.md | 2 +- docs/source/en/api/models/hunyuan_video_transformer_3d.md | 2 +- docs/source/en/api/pipelines/hunyuan_video.md | 2 +- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md index f69c14814d3d..33dff5b903cd 100644 --- a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import AutoencoderKLHunyuanVideo -vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16) +vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16) ``` ## AutoencoderKLHunyuanVideo diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md index 73aea9832fc0..522d0eb0479d 100644 --- a/docs/source/en/api/models/hunyuan_video_transformer_3d.md +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import HunyuanVideoTransformer3DModel -transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16) +transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16) ``` ## HunyuanVideoTransformer3DModel diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md index 86ef816fcd4d..0519340075cf 100644 --- a/docs/source/en/api/pipelines/hunyuan_video.md +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -29,7 +29,7 @@ Recommendations for inference: - Transformer should be in `torch.bfloat16`. - VAE should be in `torch.float16`. - `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. -- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. +- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. - For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/). ## HunyuanVideoPipeline diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 6e0541e938ba..3b0956a32da3 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -39,7 +39,7 @@ >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel >>> from diffusers.utils import export_to_video - >>> model_id = "tencent/HunyuanVideo" + >>> model_id = "hunyuanvideo-community/HunyuanVideo" >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ... ) From f615f00f58b73a216f9b31ea5247367d8f588ceb Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 23 Dec 2024 01:28:28 +0000 Subject: [PATCH 74/88] Fix enable_sequential_cpu_offload in test_kandinsky_combined (#10324) Co-authored-by: Sayak Paul --- .../kandinsky/pipeline_kandinsky_combined.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py index fe9909770376..e653b8266f19 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -193,15 +193,15 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id=0): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. """ - self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) - self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) @@ -411,7 +411,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id=0): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -419,8 +419,8 @@ def enable_sequential_cpu_offload(self, gpu_id=0): Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ - self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) - self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) @@ -652,7 +652,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id=0): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -660,8 +660,8 @@ def enable_sequential_cpu_offload(self, gpu_id=0): Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ - self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) - self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) From 7c2f0afb1c0ff4dbfb8daeed8cef65074651c92a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 22 Dec 2024 16:44:13 -1000 Subject: [PATCH 75/88] update `get_parameter_dtype` (#10342) add: q --- src/diffusers/models/modeling_utils.py | 48 ++++++++++++++++++-------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 872d4d73d41f..d236ebb83983 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: - try: - return next(parameter.parameters()).dtype - except StopIteration: - try: - return next(parameter.buffers()).dtype - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for param in parameter.parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + for buffer in parameter.buffers(): + last_dtype = buffer.dtype + if buffer.is_floating_point(): + return buffer.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype class ModelMixin(torch.nn.Module, PushToHubMixin): From da21d590b51a7e71d7a70a349300e09179b52e75 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 23 Dec 2024 08:44:58 +0530 Subject: [PATCH 76/88] [Single File] Add Single File support for HunYuan video (#10320) * update * Update src/diffusers/loaders/single_file_utils.py Co-authored-by: Aryan --------- Co-authored-by: Aryan --- src/diffusers/loaders/single_file_model.py | 8 +- src/diffusers/loaders/single_file_utils.py | 135 ++++++++++++++++++ .../transformers/transformer_hunyuan_video.py | 4 +- 3 files changed, 145 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index d102282025c7..79dc2691b9e4 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -28,6 +28,7 @@ convert_autoencoder_dc_checkpoint_to_diffusers, convert_controlnet_checkpoint, convert_flux_transformer_checkpoint_to_diffusers, + convert_hunyuan_video_transformer_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ltx_transformer_checkpoint_to_diffusers, @@ -101,6 +102,10 @@ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "HunyuanVideoTransformer3DModel": { + "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers, + "default_subfolder": "transformer", + }, } @@ -220,6 +225,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = local_files_only = kwargs.pop("local_files_only", None) subfolder = kwargs.pop("subfolder", None) revision = kwargs.pop("revision", None) + config_revision = kwargs.pop("config_revision", None) torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) device = kwargs.pop("device", None) @@ -297,7 +303,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder=subfolder, local_files_only=local_files_only, token=token, - revision=revision, + revision=config_revision, ) expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index f1408c2c409b..5933c634f4cc 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -108,6 +108,7 @@ "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], + "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -162,6 +163,7 @@ "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, + "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, } # Use to configure model sample size when original config is provided @@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): model_type = "mochi-1-preview" + if CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: + model_type = "hunyuan-video" + else: model_type = "v1" @@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") return new_state_dict + + +def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): + def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + def update_state_dict_(state_dict, old_key, new_key): + state_dict[new_key] = state_dict.pop(old_key) + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(checkpoint, key, new_key) + + for key in list(checkpoint.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, checkpoint) + + return checkpoint diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 089389b5f9ad..e3f24d97f3fa 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -18,6 +18,8 @@ import torch.nn as nn import torch.nn.functional as F +from diffusers.loaders import FromOriginalModelMixin + from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers @@ -500,7 +502,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). From b58868e6f4781dc3b2c2b7ad6617d430e7e41a87 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Mon, 23 Dec 2024 11:26:25 +0800 Subject: [PATCH 77/88] [Sana bug] bug fix for 2K model config (#10340) * fix the Positinoal Embedding bug in 2K model; * Change the default model to the BF16 one for more stable training and output * make style * substract buffer size * add compute_module_persistent_sizes --------- Co-authored-by: yiyixuxu --- .../en/api/models/sana_transformer2d.md | 2 +- docs/source/en/api/pipelines/sana.md | 2 +- scripts/convert_sana_to_diffusers.py | 6 ++ .../models/transformers/sana_transformer.py | 5 +- .../pipelines/pag/pipeline_pag_sana.py | 4 +- src/diffusers/pipelines/sana/pipeline_sana.py | 4 +- tests/models/test_modeling_common.py | 88 ++++++++++++++++--- 7 files changed, 93 insertions(+), 18 deletions(-) diff --git a/docs/source/en/api/models/sana_transformer2d.md b/docs/source/en/api/models/sana_transformer2d.md index fd56d028818f..269aefd7ff69 100644 --- a/docs/source/en/api/models/sana_transformer2d.md +++ b/docs/source/en/api/models/sana_transformer2d.md @@ -22,7 +22,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import SanaTransformer2DModel -transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16) +transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) ``` ## SanaTransformer2DModel diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md index 64acb44962e6..d027a6cbf1f5 100644 --- a/docs/source/en/api/pipelines/sana.md +++ b/docs/source/en/api/pipelines/sana.md @@ -32,9 +32,9 @@ Available models: | Model | Recommended dtype | |:-----:|:-----------------:| +| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` | | [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` | | [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` | -| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` | | [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` | | [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` | | [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` | diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index dc553681678b..2f1732817be3 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -88,13 +88,18 @@ def main(args): # y norm converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") + # scheduler flow_shift = 3.0 + + # model config if args.model_type == "SanaMS_1600M_P1_D20": layer_num = 20 elif args.model_type == "SanaMS_600M_P1_D28": layer_num = 28 else: raise ValueError(f"{args.model_type} is not supported.") + # Positional embedding interpolation scale. + interpolation_scale = {512: None, 1024: None, 2048: 1.0} for depth in range(layer_num): # Transformer blocks. @@ -176,6 +181,7 @@ def main(args): patch_size=1, norm_elementwise_affine=False, norm_eps=1e-6, + interpolation_scale=interpolation_scale[args.image_size], ) if is_accelerate_available(): diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 41224e42d2a5..027ab5fecefd 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -242,6 +242,7 @@ def __init__( patch_size: int = 1, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, ) -> None: super().__init__() @@ -249,14 +250,14 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim # 1. Patch Embedding + interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1) self.patch_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, - interpolation_scale=None, - pos_embed_type=None, + interpolation_scale=interpolation_scale, ) # 2. Additional condition embeddings diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index cf4d41fee487..03662bb37158 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -59,13 +59,13 @@ >>> from diffusers import SanaPAGPipeline >>> pipe = SanaPAGPipeline.from_pretrained( - ... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", ... pag_applied_layers=["transformer_blocks.8"], ... torch_dtype=torch.float32, ... ) >>> pipe.to("cuda") >>> pipe.text_encoder.to(torch.bfloat16) - >>> pipe.transformer = pipe.transformer.to(torch.float16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] >>> image[0].save("output.png") diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 2df6586d0bc4..fe3c9e13aa31 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -62,11 +62,11 @@ >>> from diffusers import SanaPipeline >>> pipe = SanaPipeline.from_pretrained( - ... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32 + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 ... ) >>> pipe.to("cuda") >>> pipe.text_encoder.to(torch.bfloat16) - >>> pipe.transformer = pipe.transformer.to(torch.float16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] >>> image[0].save("output.png") diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 91a462d5878e..4fc14804475a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -22,12 +22,14 @@ import unittest import unittest.mock as mock import uuid -from typing import Dict, List, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union import numpy as np import requests_mock import torch -from accelerate.utils import compute_module_sizes +import torch.nn as nn +from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size from huggingface_hub import ModelCard, delete_repo, snapshot_download from huggingface_hub.utils import is_jinja_available from parameterized import parameterized @@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): out_queue.join() +def named_persistent_module_tensors( + module: nn.Module, + recurse: bool = False, +): + """ + A helper function that gathers all the tensors (parameters + persistent buffers) of a given module. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + """ + yield from module.named_parameters(recurse=recurse) + + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + # Get parent by splitting on dots and traversing the model + parent = module + if "." in name: + parent_name = name.rsplit(".", 1)[0] + for part in parent_name.split("."): + parent = getattr(parent, part) + name = name.split(".")[-1] + if name not in parent._non_persistent_buffers_set: + yield named_buffer + + +def compute_module_persistent_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the size of each submodule of a given model (parameters + persistent buffers). + """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + module_list = named_persistent_module_tensors(model, recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1012,7 +1080,7 @@ def test_cpu_offload(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1042,7 +1110,7 @@ def test_disk_offload_without_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, safe_serialization=False) @@ -1076,7 +1144,7 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) @@ -1104,7 +1172,7 @@ def test_model_parallelism(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1132,7 +1200,7 @@ def test_sharded_checkpoints(self): base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") @@ -1164,7 +1232,7 @@ def test_sharded_checkpoints_with_variant(self): base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. variant = "fp16" with tempfile.TemporaryDirectory() as tmp_dir: @@ -1204,7 +1272,7 @@ def test_sharded_checkpoints_device_map(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") @@ -1233,7 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self): config, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() - model_size = compute_module_sizes(model)[""] + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. variant = "fp16" with tempfile.TemporaryDirectory() as tmp_dir: From 3c2e2aa8a902ebaf57ea36e48a64b52dc9b2e7df Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 23 Dec 2024 11:27:25 +0800 Subject: [PATCH 78/88] `.from_single_file()` - Add missing `.shape` (#10332) Add missing `.shape` --- src/diffusers/models/model_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index af1a1a5250ff..5f5ea2351709 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -228,7 +228,7 @@ def load_model_dict_into_meta( else: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( - f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) if is_quantized and ( From ffc0eaab6d8ae7176a34ebfff3f225c2e37ba187 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 11:03:04 +0530 Subject: [PATCH 79/88] Bump minimum TorchAO version to 0.7.0 (#10293) * bump min torchao version to 0.7.0 * update --- .../quantizers/torchao/torchao_quantizer.py | 5 + src/diffusers/utils/testing_utils.py | 4 +- tests/quantization/torchao/test_torchao.py | 94 +++++++++---------- 3 files changed, 52 insertions(+), 51 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 8b28a403e6f0..25cd4ad448e7 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -93,6 +93,11 @@ def validate_environment(self, *args, **kwargs): raise ImportError( "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`" ) + torchao_version = version.parse(importlib.metadata.version("torch")) + if torchao_version < version.parse("0.7.0"): + raise RuntimeError( + f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`." + ) self.offload = False diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 3448b4d28d1f..3ae74cddcbbf 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -490,11 +490,11 @@ def decorator(test_case): return decorator -def require_torchao_version_greater(torchao_version): +def require_torchao_version_greater_or_equal(torchao_version): def decorator(test_case): correct_torchao_version = is_torchao_available() and version.parse( version.parse(importlib.metadata.version("torchao")).base_version - ) > version.parse(torchao_version) + ) >= version.parse(torchao_version) return unittest.skipUnless( correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}." )(test_case) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 6f9980c006ac..418fc997a215 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -36,7 +36,7 @@ nightly, require_torch, require_torch_gpu, - require_torchao_version_greater, + require_torchao_version_greater_or_equal, slow, torch_device, ) @@ -74,13 +74,13 @@ def forward(self, input, *args, **kwargs): if is_torchao_available(): from torchao.dtypes import AffineQuantizedTensor - from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + from torchao.utils import get_model_size_in_bytes @require_torch @require_torch_gpu -@require_torchao_version_greater("0.6.0") +@require_torchao_version_greater_or_equal("0.7.0") class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -125,7 +125,7 @@ def test_repr(self): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_gpu -@require_torchao_version_greater("0.6.0") +@require_torchao_version_greater_or_equal("0.7.0") class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() @@ -139,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig): quantization_config=quantization_config, torch_dtype=torch.bfloat16, ) - text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") - text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) + text_encoder_2 = T5EncoderModel.from_pretrained( + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + ) tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -212,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): components = self.get_dummy_components(quantization_config) pipe = FluxPipeline(**components) - pipe.to(device=torch_device, dtype=torch.bfloat16) + pipe.to(device=torch_device) inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0] @@ -276,7 +278,6 @@ def test_int4wo_quant_bfloat16_conversion(self): self.assertTrue(isinstance(weight, AffineQuantizedTensor)) self.assertEqual(weight.quant_min, 0) self.assertEqual(weight.quant_max, 15) - self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) def test_device_map(self): """ @@ -341,21 +342,33 @@ def test_device_map(self): def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) - quantized_model = FluxTransformer2DModel.from_pretrained( + quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, ) - unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2] + unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2] self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) - quantized_layer = quantized_model.proj_out + quantized_layer = quantized_model_with_not_convert.proj_out self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) - self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) + + quantization_config = TorchAoConfig("int8_weight_only") + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert) + size_quantized = get_model_size_in_bytes(quantized_model) + + self.assertTrue(size_quantized < size_quantized_with_not_convert) def test_training(self): quantization_config = TorchAoConfig("int8_weight_only") @@ -406,23 +419,6 @@ def test_torch_compile(self): # Note: Seems to require higher tolerance self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) - @staticmethod - def _get_memory_footprint(module): - quantized_param_memory = 0.0 - unquantized_param_memory = 0.0 - - for param in module.parameters(): - if param.__class__.__name__ == "AffineQuantizedTensor": - data, scale, zero_point = param.layout_tensor.get_plain() - quantized_param_memory += data.numel() + data.element_size() - quantized_param_memory += scale.numel() + scale.element_size() - quantized_param_memory += zero_point.numel() + zero_point.element_size() - else: - unquantized_param_memory += param.data.numel() * param.data.element_size() - - total_memory = quantized_param_memory + unquantized_param_memory - return total_memory, quantized_param_memory, unquantized_param_memory - def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the @@ -433,20 +429,18 @@ def test_memory_footprint(self): transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] transformer_bf16 = self.get_dummy_components(None)["transformer"] - total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo) - total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint( - transformer_int4wo_gs32 - ) - total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo) - total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16) - - self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16) - # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points - self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32) - # int4 with default group size quantized very few linear layers compared to a smaller group size of 32 - self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32) + total_int4wo = get_model_size_in_bytes(transformer_int4wo) + total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) + total_int8wo = get_model_size_in_bytes(transformer_int8wo) + total_bf16 = get_model_size_in_bytes(transformer_bf16) + + # Latter has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int4wo < total_int4wo_gs32) # int8 quantizes more layers compare to int4 with default group size - self.assertTrue(quantized_int8wo < quantized_int4wo) + self.assertTrue(total_int8wo < total_int4wo) + # int4wo does not quantize too many layers because of default group size, but for the layers it does + # there is additional overhead of scales and zero points + self.assertTrue(total_bf16 < total_int4wo) def test_wrong_config(self): with self.assertRaises(ValueError): @@ -456,7 +450,7 @@ def test_wrong_config(self): # This class is not to be run as a test by itself. See the tests that follow this class @require_torch @require_torch_gpu -@require_torchao_version_greater("0.6.0") +@require_torchao_version_greater_or_equal("0.7.0") class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" quant_method, quant_method_kwargs = None, None @@ -565,7 +559,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_gpu -@require_torchao_version_greater("0.6.0") +@require_torchao_version_greater_or_equal("0.7.0") @slow @nightly class SlowTorchAoTests(unittest.TestCase): @@ -581,11 +575,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig): quantization_config=quantization_config, torch_dtype=torch.bfloat16, ) - text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") - text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) + text_encoder_2 = T5EncoderModel.from_pretrained( + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + ) tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -617,7 +613,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0): def _test_quant_type(self, quantization_config, expected_slice): components = self.get_dummy_components(quantization_config) - pipe = FluxPipeline(**components).to(dtype=torch.bfloat16) + pipe = FluxPipeline(**components) pipe.enable_model_cpu_offload() inputs = self.get_dummy_inputs(torch_device) From 6a970a45c5382f7153d81b924e06b736581a6c3f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Dec 2024 11:03:50 +0530 Subject: [PATCH 80/88] [docs] fix: torchao example. (#10278) fix: torchao example. --- docs/source/en/quantization/torchao.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index bd5c7697a0f7..1f9f99a79a3b 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -27,7 +27,7 @@ The example below only quantizes the weights to int8. ```python from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig -model_id = "black-forest-labs/Flux.1-Dev" +model_id = "black-forest-labs/FLUX.1-dev" dtype = torch.bfloat16 quantization_config = TorchAoConfig("int8wo") @@ -45,7 +45,9 @@ pipe = FluxPipeline.from_pretrained( pipe.to("cuda") prompt = "A cat holding a sign that says hello world" -image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0] +image = pipe( + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 +).images[0] image.save("output.png") ``` From 02c777c065c851720654ed2e69173aaf43d8600a Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 11:04:57 +0530 Subject: [PATCH 81/88] [tests] Refactor TorchAO serialization fast tests (#10271) refactor --- tests/quantization/torchao/test_torchao.py | 75 ++++++++++------------ 1 file changed, 35 insertions(+), 40 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 418fc997a215..0fa9182a3314 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -447,21 +447,19 @@ def test_wrong_config(self): self.get_dummy_components(TorchAoConfig("int42")) -# This class is not to be run as a test by itself. See the tests that follow this class +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_gpu @require_torchao_version_greater_or_equal("0.7.0") class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" - quant_method, quant_method_kwargs = None, None - device = "cuda" def tearDown(self): gc.collect() torch.cuda.empty_cache() - def get_dummy_model(self, device=None): - quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs) + def get_dummy_model(self, quant_method, quant_method_kwargs, device=None): + quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs) quantized_model = FluxTransformer2DModel.from_pretrained( self.model_name, subfolder="transformer", @@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): "timestep": timestep, } - def test_original_model_expected_slice(self): - quantized_model = self.get_dummy_model(torch_device) + def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice): + quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3)) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - def check_serialization_expected_slice(self, expected_slice): - quantized_model = self.get_dummy_model(self.device) + def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): + quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device) with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) @@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice): ) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - def test_serialization_expected_slice(self): - self.check_serialization_expected_slice(self.serialized_expected_slice) - - -class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - serialized_expected_slice = expected_slice - device = "cuda" - - -class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) - serialized_expected_slice = expected_slice - device = "cuda" - - -class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - serialized_expected_slice = expected_slice - device = "cpu" - - -class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) - serialized_expected_slice = expected_slice - device = "cpu" + def test_int_a8w8_cuda(self): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cuda" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a16w8_cuda(self): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + device = "cuda" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a8w8_cpu(self): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cpu" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a16w8_cpu(self): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + device = "cpu" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners From 76e2727b5c630fdad3b054c717e7ae4bdd5e2d8e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Dec 2024 12:35:13 +0530 Subject: [PATCH 82/88] [SANA LoRA] sana lora training tests and misc. (#10296) * sana lora training tests and misc. * remove push to hub * Update examples/dreambooth/train_dreambooth_lora_sana.py Co-authored-by: Aryan --------- Co-authored-by: Aryan --- .../dreambooth/test_dreambooth_lora_sana.py | 206 ++++++++++++++++++ .../dreambooth/train_dreambooth_lora_sana.py | 23 +- tests/lora/test_lora_layers_sana.py | 20 +- tests/pipelines/sana/test_sana.py | 6 +- 4 files changed, 231 insertions(+), 24 deletions(-) create mode 100644 examples/dreambooth/test_dreambooth_lora_sana.py diff --git a/examples/dreambooth/test_dreambooth_lora_sana.py b/examples/dreambooth/test_dreambooth_lora_sana.py new file mode 100644 index 000000000000..dfceb09a9736 --- /dev/null +++ b/examples/dreambooth/test_dreambooth_lora_sana.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +import tempfile + +import safetensors + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRASANA(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe" + script_path = "examples/dreambooth/train_dreambooth_lora_sana.py" + transformer_layer_type = "transformer_blocks.0.attn1.to_k" + + def test_dreambooth_lora_sana(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. In this test, we only params of + # `self.transformer_layer_type` should be in the state dict. + starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --max_sequence_length 166 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + --max_sequence_length 16 + """.split() + + resume_run_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 4baa9f194feb..49c790ba04d7 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -943,7 +943,7 @@ def main(args): # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" + args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) text_encoder = Gemma2Model.from_pretrained( @@ -964,15 +964,6 @@ def main(args): vae.requires_grad_(False) text_encoder.requires_grad_(False) - # Initialize a text encoding pipeline and keep it to CPU for now. - text_encoding_pipeline = SanaPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=None, - transformer=None, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 @@ -993,6 +984,15 @@ def main(args): # because Gemma2 is particularly suited for bfloat16. text_encoder.to(dtype=torch.bfloat16) + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = SanaPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -1182,6 +1182,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): ) if args.offload: text_encoding_pipeline = text_encoding_pipeline.to("cpu") + prompt_embeds = prompt_embeds.to(transformer.dtype) return prompt_embeds, prompt_attention_mask # If no type of tuning is done on the text_encoder and custom instance prompts are NOT @@ -1216,7 +1217,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): vae_config_scaling_factor = vae.config.scaling_factor if args.cache_latents: latents_cache = [] - vae = vae.to("cuda") + vae = vae.to(accelerator.device) for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index 499ca89262a0..78f71527cb7e 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import Gemma2ForCausalLM, GemmaTokenizer +from transformers import Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel from diffusers.utils.testing_utils import floats_tensor, require_peft_backend @@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } vae_cls = AutoencoderDC tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" - text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers" + text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" @property def output_shape(self): @@ -105,34 +105,34 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @unittest.skip("Not supported in Sana.") + @unittest.skip("Not supported in SANA.") def test_modify_padding_mode(self): pass - @unittest.skip("Not supported in Mochi.") + @unittest.skip("Not supported in SANA.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Mochi.") + @unittest.skip("Not supported in SANA.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index f8551fff8447..21de4e04437a 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -18,7 +18,7 @@ import numpy as np import torch -from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer +from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel from diffusers.utils.testing_utils import ( @@ -101,7 +101,7 @@ def get_dummy_components(self): torch.manual_seed(0) text_encoder_config = Gemma2Config( head_dim=16, - hidden_size=32, + hidden_size=8, initializer_range=0.02, intermediate_size=64, max_position_embeddings=8192, @@ -112,7 +112,7 @@ def get_dummy_components(self): vocab_size=8, attn_implementation="eager", ) - text_encoder = Gemma2ForCausalLM(text_encoder_config) + text_encoder = Gemma2Model(text_encoder_config) tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") components = { From 5fcee4a4471d32d3a5959e55805303a7ec7a801e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 23 Dec 2024 13:12:23 +0530 Subject: [PATCH 83/88] [Single File] Fix loading (#10349) update --- src/diffusers/loaders/single_file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 5933c634f4cc..6de9f0e9e638 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -626,7 +626,7 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): model_type = "mochi-1-preview" - if CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: + elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: model_type = "hunyuan-video" else: From c34fc3456387da14fdb4a2ae8eea714f72fcd429 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Dec 2024 13:59:55 +0530 Subject: [PATCH 84/88] [Tests] QoL improvements to the LoRA test suite (#10304) * misc lora test improvements. * updates * fixes to tests --- tests/lora/test_lora_layers_flux.py | 93 +++++-------------- tests/lora/test_lora_layers_ltx_video.py | 47 +--------- tests/lora/utils.py | 110 +++++++++++++++++++++++ 3 files changed, 132 insertions(+), 118 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 10ea2de5ef88..b22fbaaed69b 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -36,7 +36,6 @@ numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, require_peft_backend, - require_peft_version_greater, require_torch_gpu, slow, torch_device, @@ -331,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -340,85 +340,32 @@ def test_lora_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - @require_peft_version_greater("0.13.2") - def test_lora_B_bias(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # keep track of the bias values of the base layers to perform checks later. - bias_values = {} - for name, module in pipe.transformer.named_modules(): - if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]): - if module.bias is not None: - bias_values[name] = module.bias.data.clone() - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.INFO) - - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - denoiser_lora_config.lora_bias = False - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.delete_adapters("adapter-1") - - denoiser_lora_config.lora_bias = True - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) - - # for now this is flux control lora specific but can be generalized later and added to ./utils.py - def test_correct_lora_configs_with_different_ranks(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + # Testing opposite direction where the LoRA params are zero-padded. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.transformer.delete_adapters("adapter-1") - - # change the rank_pattern - updated_rank = denoiser_lora_config.r * 2 - denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank} - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - assert pipe.transformer.peft_config["adapter-1"].rank_pattern == { - "single_transformer_blocks.0.attn.to_k": updated_rank + dummy_lora_A = torch.nn.Linear(1, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") - lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - pipe.transformer.delete_adapters("adapter-1") - - # similarly change the alpha_pattern - updated_alpha = denoiser_lora_config.lora_alpha * 2 - denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha} - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == { - "single_transformer_blocks.0.attn.to_k": updated_alpha - } + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) - def test_lora_expanding_shape_with_normal_lora(self): - # This test checks if it works when a lora with expanded shapes (like control loras) but - # another lora with correct shapes is loaded. The opposite direction isn't supported and is - # tested with it. + def test_normal_lora_with_expanded_lora_raises_error(self): + # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then + # load shape expanded LoRA (such as Control LoRA). components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index c9c877b202ef..1ed426f6e8dd 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -15,8 +15,6 @@ import sys import unittest -import numpy as np -import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -26,18 +24,12 @@ LTXPipeline, LTXVideoTransformer3DModel, ) -from diffusers.utils.testing_utils import ( - floats_tensor, - is_torch_version, - require_peft_backend, - skip_mps, - torch_device, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend sys.path.append(".") -from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 +from utils import PeftLoraLoaderMixinTests # noqa: E402 @require_peft_backend @@ -107,41 +99,6 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @skip_mps - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), - reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", - strict=True, - ) - def test_lora_fuse_nan(self): - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - - # corrupt one LoRA weight with `inf` values - with torch.no_grad(): - pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") - - # with `safe_fusing=True` we should see an Error - with self.assertRaises(ValueError): - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) - - # without we should not see an error, but every image will be black - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) - - out = pipe( - "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np" - )[0] - - self.assertTrue(np.isnan(out).all()) - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 0a0366fd8d2b..567b79677ffd 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1988,3 +1988,113 @@ def test_set_adapters_match_attention_kwargs(self): np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), "Loading from saved checkpoints should give same results as set_adapters().", ) + + @require_peft_version_greater("0.13.2") + def test_lora_B_bias(self): + # Currently, this test is only relevant for Flux Control LoRA as we are not + # aware of any other LoRA checkpoint that has its `lora_B` biases trained. + components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # keep track of the bias values of the base layers to perform checks later. + bias_values = {} + denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer + for name, module in denoiser.named_modules(): + if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]): + if module.bias is not None: + bias_values[name] = module.bias.data.clone() + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + denoiser_lora_config.lora_bias = False + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.delete_adapters("adapter-1") + + denoiser_lora_config.lora_bias = True + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + + def test_correct_lora_configs_with_different_ranks(self): + components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + + if self.unet_kwargs is not None: + pipe.unet.delete_adapters("adapter-1") + else: + pipe.transformer.delete_adapters("adapter-1") + + denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer + for name, _ in denoiser.named_modules(): + if "to_k" in name and "attn" in name and "lora" not in name: + module_name_to_rank_update = name.replace(".base_layer.", ".") + break + + # change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern + + self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + if self.unet_kwargs is not None: + pipe.unet.delete_adapters("adapter-1") + else: + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue( + pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue( + pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) From 71cc2013fe9a1cf3bbd9fdcdff5dbf7b2f8d9ee9 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 23 Dec 2024 08:50:06 +0000 Subject: [PATCH 85/88] Fix FluxIPAdapterTesterMixin (#10354) --- src/diffusers/loaders/transformer_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 52a48e56e748..9fe712bb12e9 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -177,3 +177,5 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) self.config.encoder_hid_dim_type = "ip_image_proj" + + self.to(dtype=self.dtype, device=self.device) From 055d95543a41a47901195c47462c2976e3de6de7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 14:22:09 +0530 Subject: [PATCH 86/88] Fix failing CogVideoX LoRA fuse test (#10352) fix --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4558d48edad9..1768c81ce039 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -748,10 +748,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): pos_embedding = self._get_positional_embeddings( height, width, pre_time_compression_frames, device=embeds.device ) - pos_embedding = pos_embedding.to(dtype=embeds.dtype) else: pos_embedding = self.pos_embedding + pos_embedding = pos_embedding.to(dtype=embeds.dtype) embeds = embeds + pos_embedding return embeds From 9d27df8071bb39d117755200ace81a3669b4134c Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 15:29:10 +0530 Subject: [PATCH 87/88] Rename LTX blocks and docs title (#10213) * rename blocks and docs * fix docs --------- Co-authored-by: Dhruv Nair --- docs/source/en/_toctree.yml | 2 +- .../en/api/models/autoencoderkl_ltx_video.md | 2 +- .../en/api/models/ltx_video_transformer3d.md | 2 +- .../models/autoencoders/autoencoder_kl_ltx.py | 75 ++++++++++--------- .../models/transformers/transformer_ltx.py | 16 ++-- 5 files changed, 49 insertions(+), 48 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6ac66db73026..134a127d4320 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -429,7 +429,7 @@ - local: api/pipelines/ledits_pp title: LEDITS++ - local: api/pipelines/ltx_video - title: LTX + title: LTXVideo - local: api/pipelines/lumina title: Lumina-T2X - local: api/pipelines/marigold diff --git a/docs/source/en/api/models/autoencoderkl_ltx_video.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md index 694b5ace6fdf..fbdb11e29cdd 100644 --- a/docs/source/en/api/models/autoencoderkl_ltx_video.md +++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import AutoencoderKLLTXVideo -vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") +vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda") ``` ## AutoencoderKLLTXVideo diff --git a/docs/source/en/api/models/ltx_video_transformer3d.md b/docs/source/en/api/models/ltx_video_transformer3d.md index 8a60bc0432c6..fe2664cf685c 100644 --- a/docs/source/en/api/models/ltx_video_transformer3d.md +++ b/docs/source/en/api/models/ltx_video_transformer3d.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import LTXVideoTransformer3DModel -transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") ``` ## LTXVideoTransformer3DModel diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index ff202b980b95..a6cb943e09cc 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -28,7 +28,7 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution -class LTXCausalConv3d(nn.Module): +class LTXVideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXResnetBlock3d(nn.Module): +class LTXVideoResnetBlock3d(nn.Module): r""" - A 3D ResNet block used in the LTX model. + A 3D ResNet block used in the LTXVideo model. Args: in_channels (`int`): @@ -117,13 +117,13 @@ def __init__( self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXCausalConv3d( + self.conv1 = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXCausalConv3d( + self.conv2 = LTXVideoCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal ) @@ -131,7 +131,7 @@ def __init__( self.conv_shortcut = None if in_channels != out_channels: self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) - self.conv_shortcut = LTXCausalConv3d( + self.conv_shortcut = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) @@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXUpsampler3d(nn.Module): +class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -170,7 +170,7 @@ def __init__( out_channels = in_channels * stride[0] * stride[1] * stride[2] - self.conv = LTXCausalConv3d( + self.conv = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, @@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXDownBlock3D(nn.Module): +class LTXVideoDownBlock3D(nn.Module): r""" - Down block used in the LTX model. + Down block used in the LTXVideo model. Args: in_channels (`int`): @@ -235,7 +235,7 @@ def __init__( resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=in_channels, dropout=dropout, @@ -250,7 +250,7 @@ def __init__( if spatio_temporal_scale: self.downsamplers = nn.ModuleList( [ - LTXCausalConv3d( + LTXVideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, @@ -262,7 +262,7 @@ def __init__( self.conv_out = None if in_channels != out_channels: - self.conv_out = LTXResnetBlock3d( + self.conv_out = LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, @@ -300,9 +300,9 @@ def create_forward(*inputs): # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d -class LTXMidBlock3d(nn.Module): +class LTXVideoMidBlock3d(nn.Module): r""" - A middle block used in the LTX model. + A middle block used in the LTXVideo model. Args: in_channels (`int`): @@ -335,7 +335,7 @@ def __init__( resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=in_channels, dropout=dropout, @@ -367,9 +367,9 @@ def create_forward(*inputs): return hidden_states -class LTXUpBlock3d(nn.Module): +class LTXVideoUpBlock3d(nn.Module): r""" - Up block used in the LTX model. + Up block used in the LTXVideo model. Args: in_channels (`int`): @@ -410,7 +410,7 @@ def __init__( self.conv_in = None if in_channels != out_channels: - self.conv_in = LTXResnetBlock3d( + self.conv_in = LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, @@ -421,12 +421,12 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) + self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=out_channels, out_channels=out_channels, dropout=dropout, @@ -463,9 +463,9 @@ def create_forward(*inputs): return hidden_states -class LTXEncoder3d(nn.Module): +class LTXVideoEncoder3d(nn.Module): r""" - The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent representation. Args: @@ -509,7 +509,7 @@ def __init__( output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d( + self.conv_in = LTXVideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, @@ -524,7 +524,7 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - down_block = LTXDownBlock3D( + down_block = LTXVideoDownBlock3D( in_channels=input_channel, out_channels=output_channel, num_layers=layers_per_block[i], @@ -536,7 +536,7 @@ def __init__( self.down_blocks.append(down_block) # mid block - self.mid_block = LTXMidBlock3d( + self.mid_block = LTXVideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, @@ -546,14 +546,14 @@ def __init__( # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXCausalConv3d( + self.conv_out = LTXVideoCausalConv3d( in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal ) self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `LTXEncoder3D` class.""" + r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size p_t = self.patch_size_t @@ -599,9 +599,10 @@ def create_forward(*inputs): return hidden_states -class LTXDecoder3d(nn.Module): +class LTXVideoDecoder3d(nn.Module): r""" - The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. Args: in_channels (`int`, defaults to 128): @@ -647,11 +648,11 @@ def __init__( layers_per_block = tuple(reversed(layers_per_block)) output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d( + self.conv_in = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal ) - self.mid_block = LTXMidBlock3d( + self.mid_block = LTXVideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal ) @@ -662,7 +663,7 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i] - up_block = LTXUpBlock3d( + up_block = LTXVideoUpBlock3d( in_channels=input_channel, out_channels=output_channel, num_layers=layers_per_block[i + 1], @@ -676,7 +677,7 @@ def __init__( # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXCausalConv3d( + self.conv_out = LTXVideoCausalConv3d( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal ) @@ -777,7 +778,7 @@ def __init__( ) -> None: super().__init__() - self.encoder = LTXEncoder3d( + self.encoder = LTXVideoEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -788,7 +789,7 @@ def __init__( resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, ) - self.decoder = LTXDecoder3d( + self.decoder = LTXVideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, @@ -837,7 +838,7 @@ def __init__( self.tile_sample_stride_width = 448 def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LTXEncoder3d, LTXDecoder3d)): + if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)): module.gradient_checkpointing = value def enable_tiling( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2ed8520a5d75..a895340bd124 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class LTXAttentionProcessor2_0: +class LTXVideoAttentionProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. @@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( @@ -92,7 +92,7 @@ def __call__( return hidden_states -class LTXRotaryPosEmbed(nn.Module): +class LTXVideoRotaryPosEmbed(nn.Module): def __init__( self, dim: int, @@ -164,7 +164,7 @@ def forward( @maybe_allow_in_graph -class LTXTransformerBlock(nn.Module): +class LTXVideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). @@ -208,7 +208,7 @@ def __init__( cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXAttentionProcessor2_0(), + processor=LTXVideoAttentionProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) @@ -221,7 +221,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXAttentionProcessor2_0(), + processor=LTXVideoAttentionProcessor2_0(), ) self.ff = FeedForward(dim, activation_fn=activation_fn) @@ -327,7 +327,7 @@ def __init__( self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.rope = LTXRotaryPosEmbed( + self.rope = LTXVideoRotaryPosEmbed( dim=inner_dim, base_num_frames=20, base_height=2048, @@ -339,7 +339,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - LTXTransformerBlock( + LTXVideoTransformerBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, From ea1ba0ba53bdd6569547e26e518f094745ed9d03 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Dec 2024 15:45:45 +0530 Subject: [PATCH 88/88] [LoRA] test fix (#10351) updates --- tests/lora/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 567b79677ffd..07563a84b5a6 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1568,7 +1568,7 @@ def test_lora_fuse_nan(self): # without we should not see an error, but every image will be black pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) - out = pipe("test", num_inference_steps=2, output_type="np")[0] + out = pipe(**inputs)[0] self.assertTrue(np.isnan(out).all())