Commit: add
yiyixuxu committed Dec 23, 2024
1 parent 4b55713 commit 42d3a6a
Showing 3 changed files with 23 additions and 37 deletions.
24 changes: 16 additions & 8 deletions src/diffusers/models/transformers/sana_transformer.py
@@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (
+            scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -288,8 +302,7 @@ def __init__(
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ def custom_forward(*inputs):
                 )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
 
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
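
For orientation, here is a minimal sketch (not part of the commit) of how the new SanaModulatedNorm is called, assuming this version of diffusers is installed and using invented shapes (batch 2, 16 tokens, inner_dim 8):

import torch
import torch.nn as nn

from diffusers.models.transformers.sana_transformer import SanaModulatedNorm

dim = 8
norm = SanaModulatedNorm(dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = nn.Parameter(torch.randn(2, dim) / dim**0.5)  # learned shift/scale rows

hidden_states = torch.randn(2, 16, dim)   # (batch, seq_len, dim)
embedded_timestep = torch.randn(2, dim)   # (batch, dim), from the AdaLN-single embedder

# Inside forward: shift = table[0] + temb and scale = table[1] + temb,
# each (batch, 1, dim), broadcast over the sequence dim after the LayerNorm.
out = norm(hidden_states, embedded_timestep, scale_shift_table)
assert out.shape == hidden_states.shape
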
11 changes: 6 additions & 5 deletions tests/models/test_modeling_common.py
@@ -29,7 +29,7 @@
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
@@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
 
@@ -1172,7 +1172,7 @@ def test_model_parallelism(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
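
For context, a small sketch (not from the commit) of what the swapped-in accelerate helper reports: compute_module_sizes maps each submodule path to its size in bytes, with the empty key "" giving the whole model. The toy module here is purely illustrative:

import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

sizes = compute_module_sizes(model)
print(sizes[""])          # root model total: (16 + 4 + 8 + 2) fp32 params * 4 bytes = 120
print(sizes["0"])         # first Linear alone: (16 + 4) * 4 = 80
print(sizes["0.weight"])  # individual tensors are keyed too: 16 * 4 = 64
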
25 changes: 1 addition & 24 deletions tests/models/transformers/test_models_transformer_sana.py
@@ -33,6 +33,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]
 
     @property
     def dummy_input(self):
@@ -81,27 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()
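
The new model_split_percents override above feeds the common offload and model-parallelism tests, and the previously xfail-ed offload tests now run against the base implementations. A worked example with an invented model size shows how those fractions become per-device memory caps (the tests skip the first entry):

model_size = 100_000_000  # hypothetical total from compute_module_sizes(model)[""]
model_split_percents = [0.7, 0.7, 0.9]

max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
print(max_gpu_sizes)  # [70000000, 90000000] -> max_memory caps passed to from_pretrained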
