From a756694bf0f4d2a1bba770586bcb7670235d296a Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 14:10:32 +0000 Subject: [PATCH 1/4] Fix push_tests_mps.yml (#10326) --- .github/workflows/push_tests_mps.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 8d521074a08f..5fd3b78be7df 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -46,7 +46,7 @@ jobs: shell: arch -arch arm64 bash {0} run: | ${CONDA_RUN} python -m pip install --upgrade pip uv - ${CONDA_RUN} python -m uv pip install -e [quality,test] + ${CONDA_RUN} python -m uv pip install -e ".[quality,test]" ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git ${CONDA_RUN} python -m uv pip install transformers --upgrade From bf9a641f1a51368af5f3ae99cc460107d4fa2103 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 14:10:44 +0000 Subject: [PATCH 2/4] Fix EMAModel test_from_pretrained (#10325) --- tests/others/test_ema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 3443e6366f01..7cf8f30ecc44 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -67,6 +67,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): @@ -221,6 +222,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): From be2070991f1b916977020c45ecdfec225de21862 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 17:49:58 +0000 Subject: [PATCH 3/4] Support Flux IP Adapter (#10261) * Flux IP-Adapter * test cfg * make style * temp remove copied from * fix test * fix test * v2 * fix * make style * temp remove copied from * Apply suggestions from code review Co-authored-by: YiYi Xu * Move encoder_hid_proj to inside FluxTransformer2DModel * merge * separate encode_prompt, add copied from, image_encoder offload * make * fix test * fix * Update src/diffusers/pipelines/flux/pipeline_flux.py * test_flux_prompt_embeds change not needed * true_cfg -> true_cfg_scale * fix merge conflict * test_flux_ip_adapter_inference * add fast test * FluxIPAdapterMixin not test mixin * Update pipeline_flux.py Co-authored-by: YiYi Xu --------- Co-authored-by: YiYi Xu --- ...nvert_flux_xlabs_ipadapter_to_diffusers.py | 97 ++++++ src/diffusers/loaders/__init__.py | 5 +- src/diffusers/loaders/ip_adapter.py | 286 ++++++++++++++++++ src/diffusers/loaders/transformer_flux.py | 179 +++++++++++ src/diffusers/models/attention_processor.py | 146 ++++++++- src/diffusers/models/embeddings.py | 2 +- .../models/transformers/transformer_flux.py | 20 +- src/diffusers/pipelines/flux/pipeline_flux.py | 178 ++++++++++- .../pipelines/flux/pipeline_flux_control.py | 1 - .../test_models_transformer_flux.py | 52 ++++ tests/pipelines/flux/test_pipeline_flux.py | 114 ++++++- tests/pipelines/test_pipelines_common.py | 91 +++++- 12 files changed, 1157 insertions(+), 14 deletions(-) create mode 100644 scripts/convert_flux_xlabs_ipadapter_to_diffusers.py create mode 100644 src/diffusers/loaders/transformer_flux.py diff --git a/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py new file mode 100644 index 000000000000..b701b7fb40b1 --- /dev/null +++ b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py @@ -0,0 +1,97 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available + + +if is_transformers_available(): + from transformers import CLIPVisionModelWithProjection + + vision = True +else: + vision = False + +""" +python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \ +--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \ +--filename "flux-ip-adapter.safetensors" +--output_path "flux-ip-adapter-hf/" +""" + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) +parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str) + +args = parser.parse_args() + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers): + converted_state_dict = {} + + # image_proj + ## norm + converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight") + converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias") + ## proj + converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight") + converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"ip_adapter.{i}." + # to_k_ip + converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias" + ) + converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight" + ) + # to_v_ip + converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias" + ) + converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop( + f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight" + ) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + num_layers = 19 + converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers) + + print("Saving Flux IP-Adapter in Diffusers format.") + safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors") + + if vision: + model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path) + model.save_pretrained(f"{args.output_path}/image_encoder") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index c7ea0be55db2..2db8b53db498 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] - + _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ "IPAdapterMixin", + "FluxIPAdapterMixin", "SD3IPAdapterMixin", ] @@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_flux import FluxTransformer2DLoadersMixin from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): from .ip_adapter import ( + FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin, ) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 11ce4f1634d7..7b691d1fe16e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -38,6 +38,8 @@ from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, + FluxAttnProcessor2_0, + FluxIPAdapterJointAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, @@ -353,6 +355,290 @@ def unload_ip_adapter(self): self.unet.set_attn_processor(attn_procs) +class FluxIPAdapterMixin: + """Mixin for handling Flux IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + weight_name: Union[str, List[str]], + subfolder: Optional[Union[str, List[str]]] = "", + image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder", + image_encoder_subfolder: Optional[str] = "", + image_encoder_dtype: torch.dtype = torch.float16, + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. If a list is passed, it should have the same length as + `weight_name`. + image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`): + Can be either: + + - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + image_proj_keys = ["ip_adapter_proj_model.", "image_proj."] + ip_adapter_keys = ["double_blocks.", "ip_adapter."] + for key in f.keys(): + if any(key.startswith(prefix) for prefix in image_proj_keys): + diffusers_name = ".".join(key.split(".")[1:]) + state_dict["image_proj"][diffusers_name] = f.get_tensor(key) + elif any(key.startswith(prefix) for prefix in ip_adapter_keys): + diffusers_name = ( + ".".join(key.split(".")[1:]) + .replace("ip_adapter_double_stream_k_proj", "to_k_ip") + .replace("ip_adapter_double_stream_v_proj", "to_v_ip") + .replace("processor.", "") + ) + state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + state_dicts.append(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_pretrained_model_name_or_path is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}") + image_encoder = ( + CLIPVisionModelWithProjection.from_pretrained( + image_encoder_pretrained_model_name_or_path, + subfolder=image_encoder_subfolder, + low_cpu_mem_usage=low_cpu_mem_usage, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + .to(self.device, dtype=image_encoder_dtype) + .eval() + ) + self.register_modules(image_encoder=image_encoder) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 + default_clip_size = 224 + clip_image_size = ( + self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size + ) + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) + self.register_modules(feature_extractor=feature_extractor) + + # load ip-adapter into transformer + self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a list. + + `float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]` + length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the + number of IP adapters and each must match the number of blocks. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + + def LinearStrengthModel(start, finish, size): + return [(start + (finish - start) * (i / (size - 1))) for i in range(size)] + + + ip_strengths = LinearStrengthModel(0.3, 0.92, 19) + pipeline.set_ip_adapter_scale(ip_strengths) + ``` + """ + transformer = self.transformer + if not isinstance(scale, list): + scale = [[scale] * transformer.config.num_layers] + elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float): + if len(scale) != transformer.config.num_layers: + raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.") + scale = [scale] + + scale_configs = scale + + key_id = 0 + for attn_name, attn_processor in transformer.attn_processors.items(): + if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." + ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + attn_processor.scale[i] = scale_config[key_id] + key_id += 1 + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # remove CLIP image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=[None, None]) + + # remove feature extractor only when safety_checker is None as safety_checker uses + # the feature_extractor later + if not hasattr(self, "safety_checker"): + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=[None, None]) + + # remove hidden encoder + self.transformer.encoder_hid_proj = None + self.transformer.config.encoder_hid_dim_type = None + + # restore original Transformer attention processors layers + attn_procs = {} + for name, value in self.transformer.attn_processors.items(): + attn_processor_class = FluxAttnProcessor2_0() + attn_procs[name] = ( + attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() + ) + self.transformer.set_attn_processor(attn_procs) + + class SD3IPAdapterMixin: """Mixin for handling StableDiffusion 3 IP Adapters.""" diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py new file mode 100644 index 000000000000..52a48e56e748 --- /dev/null +++ b/src/diffusers/loaders/transformer_flux.py @@ -0,0 +1,179 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext + +from ..models.embeddings import ( + ImageProjection, + MultiIPAdapterImageProjection, +) +from ..models.modeling_utils import load_model_dict_into_meta +from ..utils import ( + is_accelerate_available, + is_torch_version, + logging, +) + + +if is_accelerate_available(): + pass + +logger = logging.get_logger(__name__) + + +class FluxTransformer2DLoadersMixin: + """ + Load layers into a [`FluxTransformer2DModel`]. + """ + + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False): + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + updated_state_dict = {} + image_projection = None + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + + if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + if state_dict["proj.weight"].shape[0] == 65536: + num_image_text_embeds = 16 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + if not low_cpu_mem_usage: + image_projection.load_state_dict(updated_state_dict, strict=True) + else: + load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + + return image_projection + + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): + from ..models.attention_processor import ( + FluxIPAdapterJointAttnProcessor2_0, + ) + + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 0 + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + for name in self.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + attn_processor_class = self.attn_processors[name].__class__ + attn_procs[name] = attn_processor_class() + else: + cross_attention_dim = self.config.joint_attention_dim + hidden_size = self.inner_dim + attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 + num_image_text_embeds = [] + for state_dict in state_dicts: + if "proj.weight" in state_dict["image_proj"]: + num_image_text_embed = 4 + if state_dict["image_proj"]["proj.weight"].shape[0] == 65536: + num_image_text_embed = 16 + # IP-Adapter + num_image_text_embeds += [num_image_text_embed] + + with init_context(): + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + dtype=self.dtype, + device=self.device, + ) + + value_dict = {} + for i, state_dict in enumerate(state_dicts): + value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) + value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]}) + value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]}) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(value_dict) + else: + device = self.device + dtype = self.dtype + load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + + key_id += 1 + + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ed0dd4f71d27..6e1dc1037c20 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -575,7 +575,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] @@ -2653,6 +2653,149 @@ def __call__( return hidden_states +class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + hidden_states_query_proj = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP-adapter + ip_query = hidden_states_query_proj + ip_attn_output = None + # for ip-adapter + # TODO: support for multiple adapters + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_attn_output = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_attn_output = scale * ip_attn_output + ip_attn_output = ip_attn_output.to(ip_query.dtype) + + return hidden_states, encoder_hidden_states, ip_attn_output + else: + return hidden_states + + class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -5896,6 +6039,7 @@ def __call__( SlicedAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + FluxIPAdapterJointAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f1b339e6180b..4558d48edad9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1535,7 +1535,7 @@ def forward(self, image_embeds: torch.Tensor): batch_size = image_embeds.shape[0] # image - image_embeds = self.image_embeds(image_embeds) + image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype)) image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) image_embeds = self.norm(image_embeds) return image_embeds diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8dbe49b75076..dc2eb26f9d30 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward from ...models.attention_processor import ( Attention, @@ -177,13 +177,18 @@ def forward( ) joint_attention_kwargs = joint_attention_kwargs or {} # Attention. - attn_output, context_attn_output = self.attn( + attention_outputs = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + # Process attention outputs for the `hidden_states`. attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output @@ -195,6 +200,8 @@ def forward( ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output # Process attention outputs for the `encoder_hidden_states`. @@ -212,7 +219,9 @@ def forward( return encoder_hidden_states, hidden_states -class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class FluxTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin +): """ The Transformer model introduced in Flux. @@ -482,6 +491,11 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ec2801625552..181f0269ce3e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -17,10 +17,17 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) -from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,6 +149,7 @@ class FluxPipeline( FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, + FluxIPAdapterMixin, ): r""" The Flux pipeline for text-to-image generation. @@ -169,8 +177,8 @@ class FluxPipeline( [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -182,6 +190,8 @@ def __init__( text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, ): super().__init__() @@ -193,6 +203,8 @@ def __init__( tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -377,14 +389,60 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + def check_inputs( self, prompt, prompt_2, height, width, + negative_prompt=None, + negative_prompt_2=None, prompt_embeds=None, + negative_prompt_embeds=None, pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -419,10 +477,33 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @@ -551,6 +632,9 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, @@ -561,6 +645,12 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -610,6 +700,17 @@ def __call__( pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -647,8 +748,12 @@ def __call__( prompt_2, height, width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -670,6 +775,7 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None ( prompt_embeds, pooled_prompt_embeds, @@ -684,6 +790,21 @@ def __call__( max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -725,12 +846,43 @@ def __call__( else: guidance = None + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -746,6 +898,22 @@ def __call__( return_dict=False, )[0] + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index dc3ca8cf7b09..ac8474becb78 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -403,7 +403,6 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs def check_inputs( self, prompt, diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 4a784eee4732..c88b3dac8216 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -18,6 +18,8 @@ import torch from diffusers import FluxTransformer2DModel +from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 +from diffusers.models.embeddings import ImageProjection from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,6 +28,56 @@ enable_full_determinism() +def create_flux_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + # "image_proj" (ImageProjection layer weights) + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=model.config["pooled_projection_dim"], + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel main_input_name = "hidden_states" diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index df9021ee0adb..7981e6c2a93b 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -16,13 +16,14 @@ ) from ..test_pipelines_common import ( + FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): +class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -91,6 +92,8 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): @@ -296,3 +299,112 @@ def test_flux_inference(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4 + + +@slow +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxIPAdapterPipelineSlowTests(unittest.TestCase): + pipeline_class = FluxPipeline + repo_id = "black-forest-labs/FLUX.1-dev" + image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14" + weight_name = "ip_adapter.safetensors" + ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + prompt_embeds = torch.load( + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ) + pooled_prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" + ) + ) + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8) + return { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds, + "ip_adapter_image": ip_adapter_image, + "num_inference_steps": 2, + "guidance_scale": 3.5, + "true_cfg_scale": 4.0, + "max_sequence_length": 256, + "output_type": "np", + "generator": generator, + } + + def test_flux_ip_adapter_inference(self): + pipe = self.pipeline_class.from_pretrained( + self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None + ) + pipe.load_ip_adapter( + self.ip_adapter_repo_id, + weight_name=self.weight_name, + image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path, + ) + pipe.set_ip_adapter_scale(1.0) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + + expected_slice = np.array( + [ + 0.1855, + 0.1680, + 0.1406, + 0.1953, + 0.1699, + 0.1465, + 0.2012, + 0.1738, + 0.1484, + 0.2051, + 0.1797, + 0.1523, + 0.2012, + 0.1719, + 0.1445, + 0.2070, + 0.1777, + 0.1465, + 0.2090, + 0.1836, + 0.1484, + 0.2129, + 0.1875, + 0.1523, + 0.2090, + 0.1816, + 0.1484, + 0.2110, + 0.1836, + 0.1543, + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4, f"{image_slice} != {expected_slice}" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 4d2b534c9a28..764be1890cc5 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -29,7 +29,7 @@ UNet2DConditionModel, ) from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import IPAdapterMixin +from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel @@ -54,6 +54,7 @@ get_autoencoder_tiny_config, get_consistency_vae_config, ) +from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict from ..models.unets.test_models_unet_2d_condition import ( create_ip_adapter_faceid_state_dict, create_ip_adapter_state_dict, @@ -483,6 +484,94 @@ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4): ) +class FluxIPAdapterTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for pipelines that support IP Adapters. + """ + + def test_pipeline_signature(self): + parameters = inspect.signature(self.pipeline_class.__call__).parameters + + assert issubclass(self.pipeline_class, FluxIPAdapterMixin) + self.assertIn( + "ip_adapter_image", + parameters, + "`ip_adapter_image` argument must be supported by the `__call__` method", + ) + self.assertIn( + "ip_adapter_image_embeds", + parameters, + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method", + ) + + def _get_dummy_image_embeds(self, image_embed_dim: int = 768): + return torch.randn((1, 1, image_embed_dim), device=torch_device) + + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): + inputs["negative_prompt"] = "" + inputs["true_cfg_scale"] = 4.0 + inputs["output_type"] = "np" + inputs["return_dict"] = False + return inputs + + def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for IP-Adapter. + + The following scenarios are tested: + - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + pipe.set_progress_bar_config(disable=None) + image_embed_dim = pipe.transformer.config.pooled_projection_dim + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + if expected_pipe_slice is None: + output_without_adapter = pipe(**inputs)[0] + else: + output_without_adapter = expected_pipe_slice + + adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer) + pipe.transformer._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" + ) + + class PipelineLatentTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. From 233dffdc3f56b26abaaba8363a5dd30dab7f0e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mehmet=20Yi=C4=9Fit=20=C3=96zgen=C3=A7?= <47952284+yigitozgenc@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:44:43 +0300 Subject: [PATCH 4/4] flux controlnet inpaint config bug (#10291) * flux controlnet inpaint config bug * Update src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py --------- Co-authored-by: yigitozgenc Co-authored-by: hlky --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index c557cf134b05..85943b278dc6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1095,7 +1095,11 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) # predict the noise residual - if self.controlnet.config.guidance_embeds: + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + if use_guidance: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) else: