diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 027ab5fecefd..7add82ea876c 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +class SanaModulatedNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + shift, scale = ( + scale_shift_table[None] + temb[:, None].to(scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + class SanaTransformerBlock(nn.Module): r""" Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). @@ -288,8 +302,7 @@ def __init__( # 4. Output blocks self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) - - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -462,13 +475,8 @@ def custom_forward(*inputs): ) # 3. Normalization - shift, scale = ( - self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) - ).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) + hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table) - # 4. Modulation - hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) # 5. Unpatchify diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4fc14804475a..629de2a16083 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -29,7 +29,7 @@ import requests_mock import torch import torch.nn as nn -from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size +from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes from huggingface_hub import ModelCard, delete_repo, snapshot_download from huggingface_hub.utils import is_jinja_available from parameterized import parameterized @@ -1080,7 +1080,7 @@ def test_cpu_offload(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_persistent_sizes(model)[""] + model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_persistent_sizes(model)[""] + model_size = compute_module_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, safe_serialization=False) @@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_persistent_sizes(model)[""] + model_size = compute_module_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) @@ -1172,7 +1172,7 @@ def test_model_parallelism(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_persistent_sizes(model)[""] + model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1183,6 +1183,7 @@ def test_model_parallelism(self): new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + print(f" new_model.hf_device_map:{new_model.hf_device_map}") self.check_device_map_is_respected(new_model, new_model.hf_device_map) diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py index 83db153dadea..d96c474d7163 100644 --- a/tests/models/transformers/test_models_transformer_sana.py +++ b/tests/models/transformers/test_models_transformer_sana.py @@ -33,6 +33,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = SanaTransformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True + model_split_percents = [0.7, 0.7, 0.9] @property def dummy_input(self): @@ -81,27 +82,3 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"SanaTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_cpu_offload(self): - return super().test_cpu_offload() - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_disk_offload_with_safetensors(self): - return super().test_disk_offload_with_safetensors() - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_disk_offload_without_safetensors(self): - return super().test_disk_offload_without_safetensors()