
Commit da96621

fix

sayakpaul committed Dec 25, 2024
1 parent a01cb45
Showing 2 changed files with 12 additions and 17 deletions.
src/diffusers/loaders/lora_pipeline.py (9 additions, 13 deletions)
@@ -1263,17 +1263,13 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=None,
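Note on this hunk: the caller-side pre-filter and the `if len(transformer_state_dict) > 0` guard are dropped, so `load_lora_into_transformer` is now called unconditionally; presumably key filtering and the empty case are handled inside that method rather than by the caller. A minimal sketch of the two control flows, with hypothetical names, not diffusers' actual internals:

    # Sketch only, with hypothetical names; not diffusers' actual internals.
    state_dict = {"text_encoder.q_proj.lora_A.weight": 0}  # no "transformer." keys

    # Before: the caller pre-filtered and skipped the call when the subset was empty.
    transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
    if len(transformer_state_dict) > 0:
        print("load transformer LoRA")  # not reached for this state dict

    # After: the call is unconditional; the callee is assumed to cope with an
    # empty transformer subset instead of relying on a caller-side guard.
    def load_lora_into_transformer_sketch(sd):
        subset = {k: v for k, v in sd.items() if k.startswith("transformer.")}
        if not subset:
            return  # hypothetical empty-case handling inside the callee
        print("load transformer LoRA")

    load_lora_into_transformer_sketch(state_dict)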
@@ -1809,12 +1805,12 @@ def load_lora_weights(
         transformer_lora_state_dict = {
             k: state_dict.get(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name) and "lora" in k
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name)
+            if k.startswith(f"{self.transformer_name}.")
             and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
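This hunk tightens the key prefix from `self.transformer_name` to `f"{self.transformer_name}."`. A bare `startswith(self.transformer_name)` also matches any key that merely begins with the same characters; appending the dot restricts the match to keys actually nested under that module name. A minimal sketch of the difference (the key names below are illustrative, not taken from a real checkpoint):

    # Illustrative keys, not from a real checkpoint.
    prefix = "transformer"
    keys = [
        "transformer.x_embedder.lora_A.weight",     # nested under the module: should match
        "transformer_blocks.0.attn.lora_A.weight",  # merely shares leading characters: should not
    ]

    print([k for k in keys if k.startswith(prefix)])        # both keys match (the bug)
    print([k for k in keys if k.startswith(f"{prefix}.")])  # only the first key matches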

tests/lora/test_lora_layers_flux.py (3 additions, 4 deletions)
@@ -263,9 +263,8 @@ def test_with_norm_in_state_dict(self):
         lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
-            cap_logger.out.startswith(
-                "The provided state dict contains normalization layers in addition to LoRA layers"
-            )
+            "The provided state dict contains normalization layers in addition to LoRA layers"
+            in cap_logger.out
         )
         self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)

@@ -284,7 +283,7 @@ def test_with_norm_in_state_dict(self):
         pipe.load_lora_weights(norm_state_dict)

         self.assertTrue(
-            cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
+            "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
         )

     def test_lora_parameter_expanded_shapes(self):
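Both test hunks make the same relaxation: they assert that the expected message appears anywhere in the captured logger output rather than at its very start, which would fail if any other line were logged first. A small illustration (the captured text here is invented for the example):

    # The captured text is invented for the illustration.
    message = "Unsupported keys found in state dict when trying to load normalization layers"
    captured = "some earlier log line\n" + message

    print(captured.startswith(message))  # False: an earlier line precedes the message
    print(message in captured)           # True: the substring check is position-independent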
