attempt to simplify & correct to()
yiyixuxu committed Dec 24, 2024
1 parent 42d3a6a commit 583a7e9
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -391,49 +391,34 @@ def to(self, *args, **kwargs):
         pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
-        def module_is_sequentially_offloaded(module):
+        pipeline_is_sequentially_offloaded = hasattr(self, "_all_sequential_hooks") and self._all_sequential_hooks is not None and len(self._all_sequential_hooks) > 0
+        pipeline_is_offloaded = hasattr(self, "_all_hooks") and self._all_hooks is not None and len(self._all_hooks) > 0
+        pipeline_is_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
+
+        def module_has_hooks(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and (
-                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
-                or hasattr(module._hf_hook, "hooks")
-                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
-            )
-
-        def module_is_offloaded(module):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
-                return False
-
-            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
-
-        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
-        pipeline_is_sequentially_offloaded = any(
-            module_is_sequentially_offloaded(module) for _, module in self.components.items()
-        )
+            return hasattr(module, "_hf_hook") and module._hf_hook is not None
+
+        pipeline_has_hooks = any(module_has_hooks(module) for _, module in self.components.items())
+
         if device and torch.device(device).type == "cuda":
-            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+            if pipeline_is_offloaded:
                 raise ValueError(
-                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                    f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
                 )
+            if pipeline_is_sequentially_offloaded:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now manually moving the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
             # PR: https://github.com/huggingface/accelerate/pull/3223/
-            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+            if pipeline_has_hooks and pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
                 raise ValueError(
                     "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                 )
 
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
-        if is_pipeline_device_mapped:
+        if pipeline_is_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
             )
 
-        # Display a warning in this case (the operation succeeds but the benefits are lost)
-        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
-            logger.warning(
-                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
-            )
-
         module_names, _ = self._get_signature_keys(self)
         modules = [getattr(self, n, None) for n in module_names]
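
As I read this hunk, to() now infers the offload state from pipeline-level bookkeeping (_all_hooks, _all_sequential_hooks, hf_device_map) instead of inspecting each module's accelerate hook type, and moving a model-offloaded pipeline to CUDA now raises instead of only warning. A minimal sketch of the resulting behavior; the checkpoint id is a placeholder, not something taken from this commit:

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint id; any diffusers pipeline behaves the same way here.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16)

pipe.enable_sequential_cpu_offload()
# pipe.to("cuda")  # ValueError: sequential offloading is incompatible with a manual move to GPU

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
# pipe.to("cuda")  # after this change, also a ValueError (previously only a warning)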
@@ -452,13 +437,18 @@ def module_is_offloaded(module):
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                 )
 
+            is_device_mapped = module_has_hooks(module) and hasattr(module, "hf_device_map") and module.hf_device_map is not None
+
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
-                module.to(device, dtype)
+                if is_device_mapped:
+                    logger.warning(f"{module.__class__.__name__} has a device map {module.hf_device_map} and will not be moved to {device}.")
+                else:
+                    module.to(device, dtype)
 
         if (
             module.dtype == torch.float16
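
The per-module guard added above mirrors the existing pipeline-level check: a component that carries its own hf_device_map stays where it is and only a warning is logged, while a device-mapped pipeline still has to be reset before explicit placement. A rough sketch of that flow, assuming a device-mapped load; the checkpoint id and the device_map value are placeholders:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some/checkpoint", device_map="balanced", torch_dtype=torch.float16
)
# pipe.to("cuda")  # ValueError while hf_device_map is set on the pipeline

pipe.reset_device_map()  # drop the device map first...
pipe.to("cuda")          # ...then explicit placement works again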
@@ -1014,7 +1004,10 @@ def remove_all_hooks(self):
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
                 accelerate.hooks.remove_hook_from_module(model, recurse=True)
-        self._all_hooks = []
+        if hasattr(self, "_all_hooks"):
+            self._all_hooks = []
+        if hasattr(self, "_all_sequential_hooks"):
+            self._all_sequential_hooks = []
 
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
@@ -1166,17 +1159,18 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
             if hasattr(device_mod, "empty_cache") and device_mod.is_available():
                 device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
+        self._all_sequential_hooks = []
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
                 continue
 
             if name in self._exclude_from_cpu_offload:
                 model.to(device)
             else:
                 # make sure to offload buffers if not all high level weights
                 # are of type nn.Module
                 offload_buffers = len(model._parameters) > 0
                 cpu_offload(model, device, offload_buffers=offload_buffers)
+                self._all_sequential_hooks.append(model._hf_hook)
 
     def reset_device_map(self):
         r"""
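
Taken together, enable_sequential_cpu_offload() now records every hook it attaches in self._all_sequential_hooks, which is the list the simplified checks in to() and the updated remove_all_hooks() consult. A rough sketch of that bookkeeping; the checkpoint id is a placeholder and the private attributes are shown purely for illustration:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16)

pipe.enable_sequential_cpu_offload()
assert len(pipe._all_sequential_hooks) > 0   # hooks recorded as they are attached

pipe.remove_all_hooks()                      # detaches the accelerate hooks...
assert pipe._all_sequential_hooks == []      # ...and clears the bookkeeping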
