Skip to content

Commit

Permalink
Merge branch 'main' into safety_checker
Browse files Browse the repository at this point in the history
  • Loading branch information
suzukimain authored Dec 23, 2024
2 parents 6a1532b + 3c2e2aa commit c158e09
Show file tree
Hide file tree
Showing 17 changed files with 285 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/api/models/autoencoder_kl_hunyuan_video.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
```

## AutoencoderKLHunyuanVideo
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/models/hunyuan_video_transformer_3d.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## HunyuanVideoTransformer3DModel
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/models/sana_transformer2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## SanaTransformer2DModel
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/hunyuan_video.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Recommendations for inference:
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).

## HunyuanVideoPipeline
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/sana.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
Expand Down
6 changes: 6 additions & 0 deletions scripts/convert_sana_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,18 @@ def main(args):
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

# scheduler
flow_shift = 3.0

# model config
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}

for depth in range(layer_num):
# Transformer blocks.
Expand Down Expand Up @@ -176,6 +181,7 @@ def main(args):
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)

if is_accelerate_available():
Expand Down
8 changes: 7 additions & 1 deletion src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
Expand Down Expand Up @@ -101,6 +102,10 @@
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}


Expand Down Expand Up @@ -220,6 +225,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
Expand Down Expand Up @@ -297,7 +303,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=revision,
revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

Expand Down
135 changes: 135 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -162,6 +163,7 @@
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
model_type = "mochi-1-preview"

if CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"

else:
model_type = "v1"

Expand Down Expand Up @@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

return new_state_dict


def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight

def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key

if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)

def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v

def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v

def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072

if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp

elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)

TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

def update_state_dict_(state_dict, old_key, new_key):
state_dict[new_key] = state_dict.pop(old_key)

for key in list(checkpoint.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(checkpoint, key, new_key)

for key in list(checkpoint.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, checkpoint)

return checkpoint
2 changes: 1 addition & 1 deletion src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def load_model_dict_into_meta(
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if is_quantized and (
Expand Down
48 changes: 33 additions & 15 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
return next(parameter.parameters()).dtype
except StopIteration:
try:
return next(parameter.buffers()).dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5

def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples

gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
if param.is_floating_point():
return param.dtype

for buffer in parameter.buffers():
last_dtype = buffer.dtype
if buffer.is_floating_point():
return buffer.dtype

if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype

# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples

gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype

if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype


class ModelMixin(torch.nn.Module, PushToHubMixin):
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/models/transformers/sana_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,22 @@ def __init__(
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
super().__init__()

out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim

# 1. Patch Embedding
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
interpolation_scale=interpolation_scale,
)

# 2. Additional condition embeddings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch.nn as nn
import torch.nn.functional as F

from diffusers.loaders import FromOriginalModelMixin

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
Expand Down Expand Up @@ -500,7 +502,7 @@ def forward(
return hidden_states, encoder_hidden_states


class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import export_to_video
>>> model_id = "tencent/HunyuanVideo"
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
Expand Down
Loading

0 comments on commit c158e09

Please sign in to comment.