
Commit

Merge branch 'huggingface:main' into main
SHYuanBest authored Dec 23, 2024
2 parents 5fd9a81 + 7c2f0af commit 0937753
Showing 6 changed files with 46 additions and 28 deletions.
docs/source/en/api/models/autoencoder_kl_hunyuan_video.md (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanVideo

-vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
+vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
```

## AutoencoderKLHunyuanVideo
docs/source/en/api/models/hunyuan_video_transformer_3d.md (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanVideoTransformer3DModel

-transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
+transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## HunyuanVideoTransformer3DModel
docs/source/en/api/pipelines/hunyuan_video.md (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ Recommendations for inference:
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
-- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
+- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).

## HunyuanVideoPipeline
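Taken together, the inference recommendations in this doc page correspond to a setup like the one below. This is a minimal sketch in the spirit of the pipeline docstring touched by this commit; it assumes a CUDA device, the `hunyuanvideo-community/HunyuanVideo` checkpoint layout, and illustrative values for the prompt, resolution, and frame count.

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"

# Transformer in bfloat16, the rest of the pipeline (including the VAE) in float16.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.to("cuda")

# Optionally lower the scheduler shift for smaller resolutions (the default is 7.0), e.g.:
# from diffusers import FlowMatchEulerDiscreteScheduler
# pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=5.0)

# num_frames follows the 4 * k + 1 rule (here k = 15, so 61 frames).
output = pipe(
    prompt="A cat walks on the grass, realistic style.",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```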
src/diffusers/models/modeling_utils.py (48 changes: 33 additions & 15 deletions)
@@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
-    try:
-        return next(parameter.parameters()).dtype
-    except StopIteration:
-        try:
-            return next(parameter.buffers()).dtype
-        except StopIteration:
-            # For torch.nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].dtype
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for param in parameter.parameters():
+        last_dtype = param.dtype
+        if param.is_floating_point():
+            return param.dtype
+
+    for buffer in parameter.buffers():
+        last_dtype = buffer.dtype
+        if buffer.is_floating_point():
+            return buffer.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype


class ModelMixin(torch.nn.Module, PushToHubMixin):
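The rewritten helper changes how a model's dtype is reported when its leading parameters are not floating point (as can happen with quantized checkpoints): instead of returning the dtype of the first tensor it finds, it returns the first floating dtype and only falls back to the last dtype seen. A rough, self-contained sketch of that difference, using a toy module and a local re-implementation of the new scan rather than the private diffusers helper:

```python
import torch
from torch import nn


class ToyQuantizedLinear(nn.Module):
    """Toy module whose first registered parameter is integer-typed, as in some quantized layouts."""

    def __init__(self) -> None:
        super().__init__()
        # Non-floating parameters must be created with requires_grad=False.
        self.weight_int8 = nn.Parameter(torch.zeros(8, 8, dtype=torch.int8), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(8, dtype=torch.float16))


def first_floating_dtype(module: nn.Module) -> torch.dtype:
    """Local re-implementation of the new scan: first floating dtype, else the last dtype seen."""
    last_dtype = None
    for tensor in list(module.parameters()) + list(module.buffers()):
        last_dtype = tensor.dtype
        if tensor.is_floating_point():
            return tensor.dtype
    return last_dtype


module = ToyQuantizedLinear()
print(next(module.parameters()).dtype)  # torch.int8  (what the old helper would have reported)
print(first_floating_dtype(module))     # torch.float16
```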
Unnamed file (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import export_to_video
>>> model_id = "tencent/HunyuanVideo"
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py (18 changes: 9 additions & 9 deletions)
@@ -193,15 +193,15 @@ def __init__(
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
        Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
        GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
        Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
        """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -411,16 +411,16 @@ def __init__(
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,16 +652,16 @@ def __init__(
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
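With this change, the combined Kandinsky pipelines accept the same `device` argument as the base `DiffusionPipeline.enable_sequential_cpu_offload` and forward it to both the prior and decoder sub-pipelines. A hedged usage sketch, assuming the `kandinsky-community/kandinsky-2-1` checkpoint used in the Kandinsky docs and a CUDA device:

```python
import torch
from diffusers import AutoPipelineForText2Image

# Loads the combined Kandinsky text-to-image pipeline (prior + decoder).
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)

# Sequential CPU offload, with an explicit target device forwarded to both sub-pipelines.
pipe.enable_sequential_cpu_offload(device="cuda")

image = pipe(prompt="A lighthouse on a cliff at sunset, oil painting").images[0]
image.save("kandinsky_offload.png")
```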
