Commit
add fast test
hlky committed Dec 20, 2024
1 parent 7938b42 commit 253ef7e
Showing 3 changed files with 143 additions and 1 deletion.
52 changes: 52 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -18,6 +18,8 @@
import torch

from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@
enable_full_determinism()


def create_flux_ip_adapter_state_dict(model):
    # "ip_adapter" (cross-attention weights)
    ip_cross_attn_state_dict = {}
    key_id = 0

    for name in model.attn_processors.keys():
        # only the double-stream (joint) transformer blocks receive IP-Adapter cross-attention weights
        if name.startswith("single_transformer_blocks"):
            continue

        joint_attention_dim = model.config["joint_attention_dim"]
        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
        sd = FluxIPAdapterJointAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
        ).state_dict()
        # remap the processor's "to_k_ip.0.*" / "to_v_ip.0.*" parameters to the flat
        # per-block layout expected by the IP-Adapter loader
        ip_cross_attn_state_dict.update(
            {
                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
            }
        )

        key_id += 1

    # "image_proj" (ImageProjection layer weights)
    image_projection = ImageProjection(
        cross_attention_dim=model.config["joint_attention_dim"],
        image_embed_dim=model.config["pooled_projection_dim"],
        num_image_text_embeds=4,
    )

    ip_image_projection_state_dict = {}
    sd = image_projection.state_dict()
    ip_image_projection_state_dict.update(
        {
            "proj.weight": sd["image_embeds.weight"],
            "proj.bias": sd["image_embeds.bias"],
            "norm.weight": sd["norm.weight"],
            "norm.bias": sd["norm.bias"],
        }
    )

    del sd
    ip_state_dict = {}
    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
    return ip_state_dict


class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel
    main_input_name = "hidden_states"
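
For orientation, a minimal usage sketch of the helper above against a small FluxTransformer2DModel. The config values here are illustrative stand-ins rather than the suite's actual dummy config, and the absolute import path assumes the repository's tests/ layout shown in this commit is run from the repo root:

from diffusers import FluxTransformer2DModel

from tests.models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict

# A tiny transformer config; axes_dims_rope sums to attention_head_dim.
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=(4, 4, 8),
)

# Build dummy IP-Adapter weights keyed the way the loader expects, then attach them
# through the same private hook the pipeline-level test below uses.
state_dict = create_flux_ip_adapter_state_dict(model)
model._load_ip_adapter_weights(state_dict)
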
3 changes: 2 additions & 1 deletion tests/pipelines/flux/test_pipeline_flux.py
@@ -16,13 +16,14 @@
)

from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)


class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):  # removed in this commit
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):  # added in this commit
    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
89 changes: 89 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -54,6 +54,7 @@
    get_autoencoder_tiny_config,
    get_consistency_vae_config,
)
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
    create_ip_adapter_faceid_state_dict,
    create_ip_adapter_state_dict,
@@ -483,6 +484,94 @@ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
)


class FluxIPAdapterTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
    It provides a set of common tests for pipelines that support IP Adapters.
    """

    def test_pipeline_signature(self):
        parameters = inspect.signature(self.pipeline_class.__call__).parameters

        # FluxIPAdapterMixin is the loader mixin (from diffusers.loaders) that gives the pipeline IP-Adapter support
        assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
        self.assertIn(
            "ip_adapter_image",
            parameters,
            "`ip_adapter_image` argument must be supported by the `__call__` method",
        )
        self.assertIn(
            "ip_adapter_image_embeds",
            parameters,
            "`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
        )

    def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
        return torch.randn((1, 1, image_embed_dim), device=torch_device)

    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
        inputs["negative_prompt"] = ""
        inputs["true_cfg_scale"] = 4.0
        inputs["output_type"] = "np"
        inputs["return_dict"] = False
        return inputs

    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
        r"""Tests for IP-Adapter.

        The following scenarios are tested:
          - Single IP-Adapter with scale=0 should produce the same output as no IP-Adapter.
          - Single IP-Adapter with scale!=0 should produce a different output compared to no IP-Adapter.
        """
        # Raise the tolerance for this test when it runs on CPU because we
        # compare against static slices and that can be shaky (with a very low probability).
        expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        image_embed_dim = pipe.transformer.config.pooled_projection_dim

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        if expected_pipe_slice is None:
            output_without_adapter = pipe(**inputs)[0]
        else:
            output_without_adapter = expected_pipe_slice

        adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
        pipe.transformer._load_ip_adapter_weights(adapter_state_dict)

        # forward pass with single ip adapter, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
        inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
        pipe.set_ip_adapter_scale(0.0)
        output_without_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single ip adapter, scale!=0 (deliberately large so the effect is clearly visible)
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
        inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
        pipe.set_ip_adapter_scale(42.0)
        output_with_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        self.assertLess(
            max_diff_without_adapter_scale,
            expected_max_diff,
            "Output without ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
        )
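
A hedged sketch of how another Flux-based fast-test class might opt into these checks, mirroring the FluxPipelineFastTests change above; the class name is hypothetical, and get_dummy_components / get_dummy_inputs are assumed to be supplied by the concrete test class, as PipelineTesterMixin already requires:

import unittest

from diffusers import FluxPipeline


class HypotheticalFluxIPAdapterFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # get_dummy_components() and get_dummy_inputs() would assemble the same tiny transformer,
    # text encoders, and VAE that FluxPipelineFastTests uses; with those in place,
    # test_pipeline_signature and test_ip_adapter above run unchanged.

Where a trusted CPU baseline has already been recorded, a subclass can also override test_ip_adapter and forward a reference slice through expected_pipe_slice, so the no-adapter baseline does not have to be recomputed on every run.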


class PipelineLatentTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
