Skip to content

Commit

Permalink
Support Flux IP Adapter (#10261)
Browse files Browse the repository at this point in the history
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
  • Loading branch information
hlky and yiyixuxu authored Dec 21, 2024
1 parent bf9a641 commit be20709
Show file tree
Hide file tree
Showing 12 changed files with 1,157 additions and 14 deletions.
97 changes: 97 additions & 0 deletions scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


if is_transformers_available():
from transformers import CLIPVisionModelWithProjection

vision = True
else:
vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors"
--output_path "flux-ip-adapter-hf/"
"""


CTX = init_empty_weights if is_accelerate_available else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()


def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict


def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}

# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")

# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)

return converted_state_dict


def main(args):
original_ckpt = load_original_checkpoint(args)

num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
main(args)
5 changes: 4 additions & 1 deletion src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]

_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
Expand All @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]

Expand All @@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
Expand Down
Loading

0 comments on commit be20709

Please sign in to comment.