Skip to content

Commit

Permalink
[Tests] QoL improvements to the LoRA test suite (#10304)
Browse files Browse the repository at this point in the history
* misc lora test improvements.

* updates

* fixes to tests
  • Loading branch information
sayakpaul authored Dec 23, 2024
1 parent 5fcee4a commit c34fc34
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 118 deletions.
93 changes: 20 additions & 73 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_peft_backend,
require_peft_version_greater,
require_torch_gpu,
slow,
torch_device,
Expand Down Expand Up @@ -331,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

Expand All @@ -340,85 +340,32 @@ def test_lora_parameter_expanded_shapes(self):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
for name, module in pipe.transformer.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()

_, _, inputs = self.get_dummy_inputs(with_generator=False)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)

original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

denoiser_lora_config.lora_bias = False
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")

denoiser_lora_config.lora_bias = True
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

# for now this is flux control lora specific but can be generalized later and added to ./utils.py
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.delete_adapters("adapter-1")

# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_rank
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")

lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
pipe.transformer.delete_adapters("adapter-1")

# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_alpha
}
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)

def test_lora_expanding_shape_with_normal_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
# load shape expanded LoRA (such as Control LoRA).
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

# Change the transformer config to mimic a real use case.
Expand Down
47 changes: 2 additions & 45 deletions tests/lora/test_lora_layers_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel

Expand All @@ -26,18 +24,12 @@
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend


sys.path.append(".")

from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402


@require_peft_backend
Expand Down Expand Up @@ -107,41 +99,6 @@ def get_dummy_inputs(self, with_generator=True):

return noise, input_ids, pipeline_inputs

@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")

# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)

out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]

self.assertTrue(np.isnan(out).all())

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

Expand Down
110 changes: 110 additions & 0 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,3 +1988,113 @@ def test_set_adapters_match_attention_kwargs(self):
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)

@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()

_, _, inputs = self.get_dummy_inputs(with_generator=False)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)

original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

denoiser_lora_config.lora_bias = False
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")

denoiser_lora_config.lora_bias = True
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]

if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")

denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, _ in denoiser.named_modules():
if "to_k" in name and "attn" in name and "lora" not in name:
module_name_to_rank_update = name.replace(".base_layer.", ".")
break

# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}

if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern

self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})

lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))

if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")

# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)

lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

0 comments on commit c34fc34

Please sign in to comment.