From f099b2fbd5c715f59baccdc0dd51ea29b95209a3 Mon Sep 17 00:00:00 2001
From: Maxim Kan
Date: Thu, 26 Dec 2024 10:02:30 +0000
Subject: [PATCH] check for base_layer key in transformer state dict

---
 src/diffusers/loaders/lora_pipeline.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 351295e938ff..7e26c397a077 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2460,13 +2460,17 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        transformer_base_layer_keys = {
+            k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
+        }
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue

             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if k in transformer_base_layer_keys
+                else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
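
For context, the change replaces the blanket is_peft_loaded check with a per-module lookup: a module's base weight is read from "<module>.base_layer.weight" only when the transformer state dict actually contains such a key, and from "<module>.weight" otherwise. Below is a minimal, self-contained sketch (not part of the patch) of that selection logic; the example state dicts, the "transformer." prefix, and the helper name pick_base_param_names are illustrative assumptions, not names from the diffusers codebase.

    # Sketch of the key-selection idea behind the patch; names and data are illustrative.
    def pick_base_param_names(transformer_state_dict, lora_state_dict, prefix="transformer."):
        # Modules wrapped by PEFT expose ".base_layer.weight" instead of ".weight";
        # collect those module names once up front.
        base_layer_keys = {
            k[: -len(".base_layer.weight")]
            for k in transformer_state_dict
            if ".base_layer.weight" in k
        }

        resolved = {}
        lora_module_names = {key.split(".lora_A.")[0] for key in lora_state_dict if "lora_A" in key}
        for k in sorted(lora_module_names):
            module = k.replace(prefix, "")
            # Use the wrapped name only when this particular module is wrapped.
            suffix = ".base_layer.weight" if module in base_layer_keys else ".weight"
            resolved[k] = f"{module}{suffix}"
        return resolved


    if __name__ == "__main__":
        transformer_sd = {
            "attn.to_q.base_layer.weight": None,  # wrapped by a previously loaded LoRA
            "attn.to_k.weight": None,             # plain, unwrapped module
        }
        lora_sd = {
            "transformer.attn.to_q.lora_A.weight": None,
            "transformer.attn.to_k.lora_A.weight": None,
        }
        print(pick_base_param_names(transformer_sd, lora_sd))
        # {'transformer.attn.to_k': 'attn.to_k.weight',
        #  'transformer.attn.to_q': 'attn.to_q.base_layer.weight'}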