Fix Flux multiple Lora loading bug #10388

Open · wants to merge 2 commits into base: main
8 changes: 6 additions & 2 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2460,13 +2460,17 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        transformer_base_layer_keys = {
+            k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
+        }
Comment on lines +2463 to +2465 (Member):
Note that the base_layer substring can only be present when the underlying pipeline has at least one LoRA loaded that affects the layer under consideration. So, perhaps it's better to have an is_peft_loaded check?

Member:
In your PR description you mention:

> If the first loaded LoRA model does not have weights for layer n, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have the key n.base_layer.weight.

Note that we may also have the opposite situation, i.e., the first LoRA ckpt may have the params while the second LoRA may not. This is what I considered in #10388.

         for k in lora_module_names:
             if k in unexpected_modules:
                 continue

             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if k in transformer_base_layer_keys
+                else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
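To make the scenario discussed in the review comments concrete, here is a minimal, self-contained sketch of the per-key lookup this diff introduces. The plain dicts and the layer_a/layer_b names are illustrative stand-ins for the real transformer and LoRA state dicts, not part of the diffusers code:

```python
# Suppose a first LoRA was already loaded for layer_a only, so PEFT wrapped that
# module and its weight now lives under "layer_a.base_layer.weight", while
# layer_b is still an ordinary module whose weight lives under "layer_b.weight".
transformer_state_dict = {
    "layer_a.base_layer.weight": "wrapped weight",
    "layer_b.weight": "plain weight",
}

# Modules that already have a LoRA attached: the ".base_layer.weight" suffix is
# only present for those. This mirrors transformer_base_layer_keys in the PR.
base_layer_keys = {
    k[: -len(".base_layer.weight")]
    for k in transformer_state_dict
    if ".base_layer.weight" in k
}

# A second LoRA targets both modules. A single global is_peft_loaded flag would
# look up "layer_b.base_layer.weight" and raise a KeyError; the per-key check
# picks the correct parameter name for each module instead.
for k in ["layer_a", "layer_b"]:
    base_param_name = f"{k}.base_layer.weight" if k in base_layer_keys else f"{k}.weight"
    print(k, "->", base_param_name, "->", transformer_state_dict[base_param_name])
```

With the previous global flag, the lookup for layer_b would have failed exactly as described in the PR; the per-key check handles modules that only some of the loaded LoRAs touch.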