fixes to tests
sayakpaul committed Dec 23, 2024
commit 618d206 (parent d1dc1d8)
Showing 1 changed file with 16 additions and 8 deletions: tests/lora/test_lora_layers_flux.py
@@ -330,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -339,6 +340,7 @@ def test_lora_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
+        # Testing opposite direction where the LoRA params are zero-padded.
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
@@ -349,15 +351,21 @@ def test_lora_parameter_expanded_shapes(self):
             "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
         }
-        # We should error out because lora input features is less than original. We only
-        # support expanding the module, not shrinking it
-        with self.assertRaises(NotImplementedError):
-            with CaptureLogger(logger) as cap_logger:
-                pipe.load_lora_weights(lora_state_dict, "adapter-1")
-
-    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
-        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
-        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
-        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+
+    def test_normal_lora_with_expanded_lora_raises_error(self):
+        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+        # load shape expanded LoRA (such as Control LoRA).
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.
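For context on what the reworked assertions check: `load_lora_weights` handles two shape-matching directions for the Flux `x_embedder`. It expands the base `nn.Linear` when the LoRA was trained with more input features, as in Control-style LoRAs (the "Expanding the nn.Linear input/output features" log path), and it zero-pads a narrower LoRA so it fits an already-wider module (the "zero padded to match the state dict" log path). Below is a minimal, self-contained sketch of those two rules; `fit_lora_to_linear` is a hypothetical helper for illustration, not diffusers' actual implementation.

```python
import torch
import torch.nn as nn


def fit_lora_to_linear(base: nn.Linear, lora_A: torch.Tensor):
    """Reconcile a LoRA down-projection of shape (rank, lora_in) with a base
    nn.Linear whose weight has shape (out_features, in_features)."""
    rank, lora_in = lora_A.shape
    if lora_in > base.in_features:
        # Direction 1: widen the module. The old weights occupy the leading
        # columns; the new columns start at zero, so outputs on the original
        # inputs are unchanged.
        expanded = nn.Linear(lora_in, base.out_features, bias=base.bias is not None)
        with torch.no_grad():
            expanded.weight.zero_()
            expanded.weight[:, : base.in_features].copy_(base.weight)
            if base.bias is not None:
                expanded.bias.copy_(base.bias)
        return expanded, lora_A
    if lora_in < base.in_features:
        # Direction 2: zero-pad the LoRA instead. The padded input columns
        # contribute nothing, so the adapter behaves exactly as trained.
        padded = lora_A.new_zeros(rank, base.in_features)
        padded[:, :lora_in].copy_(lora_A)
        return base, padded
    return base, lora_A


# Mirrors the shapes in the test: a LoRA A twice as wide as x_embedder
# expands the module; a narrower one gets zero-padded.
base = nn.Linear(4, 8)
base, _ = fit_lora_to_linear(base, torch.randn(2, 8))
assert base.in_features == 8

_, padded_A = fit_lora_to_linear(base, torch.randn(2, 1))
assert padded_A.shape == (2, 8) and padded_A[:, 1:].abs().sum() == 0
```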

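The renamed `test_normal_lora_with_expanded_lora_raises_error` guards the opposite ordering: once a regular LoRA is attached at the module's original width, a later LoRA that would force the module to expand is rejected rather than silently invalidating the already-loaded adapter. A toy reduction of that rule follows; the function name and the error type are illustrative, not the library's.

```python
def check_can_expand(module_in_features: int, loaded_adapter_widths: list, new_lora_in: int) -> None:
    """Expansion is only safe while no regular adapters are attached."""
    if new_lora_in > module_in_features and loaded_adapter_widths:
        raise RuntimeError(
            "Cannot expand the base module: adapters trained at width "
            f"{module_in_features} are already loaded."
        )


check_can_expand(4, [], 8)       # fresh module: expansion is fine
check_can_expand(4, [4], 4)      # regular LoRA onto a regular module: fine
try:
    check_can_expand(4, [4], 8)  # regular first, expanded second: rejected
except RuntimeError as err:
    print(err)
```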