Commit

add compute_module_persistent_sizes
yiyixuxu committed Dec 23, 2024
1 parent 13c5954 commit 574fe74
Showing 1 changed file with 78 additions and 26 deletions.
104 changes: 78 additions & 26 deletions tests/models/test_modeling_common.py
@@ -22,12 +22,14 @@
import unittest
import unittest.mock as mock
import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import requests_mock
import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    out_queue.join()


def named_persistent_module_tensors(
    module: nn.Module,
    recurse: bool = False,
):
    """
    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.

    Args:
        module (`torch.nn.Module`):
            The module we want the tensors on.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct parameters and buffers.
    """
    yield from module.named_parameters(recurse=recurse)

    for named_buffer in module.named_buffers(recurse=recurse):
        name, _ = named_buffer
        # Get parent by splitting on dots and traversing the model
        parent = module
        if "." in name:
            parent_name = name.rsplit(".", 1)[0]
            for part in parent_name.split("."):
                parent = getattr(parent, part)
            name = name.split(".")[-1]
        if name not in parent._non_persistent_buffers_set:
            yield named_buffer
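
# Usage sketch for illustration only; _ToyModule is hypothetical and not part of this commit.
# It shows that the helper yields parameters and persistent buffers while skipping buffers
# registered with persistent=False.
class _ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("running_mean", torch.zeros(4))  # persistent buffer: yielded
        self.register_buffer("scratch", torch.zeros(4), persistent=False)  # non-persistent: skipped


_toy_names = [name for name, _ in named_persistent_module_tensors(_ToyModule(), recurse=True)]
# _toy_names == ["linear.weight", "linear.bias", "running_mean"]; "scratch" does not appear.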


def compute_module_persistent_sizes(
    model: nn.Module,
    dtype: Optional[Union[str, torch.device]] = None,
    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
    """
    Compute the size of each submodule of a given model (parameters + persistent buffers).
    """
    if dtype is not None:
        dtype = _get_proper_dtype(dtype)
        dtype_size = dtype_byte_size(dtype)
    if special_dtypes is not None:
        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
    module_sizes = defaultdict(int)

    module_list = []

    module_list = named_persistent_module_tensors(model, recurse=True)

    for name, tensor in module_list:
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes_size[name]
        elif dtype is None:
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            # According to the code in set_module_tensor_to_device, these types won't be converted
            # so use their original size here
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        else:
            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
        name_parts = name.split(".")
        for idx in range(len(name_parts) + 1):
            module_sizes[".".join(name_parts[:idx])] += size

    return module_sizes
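
# Continuing the illustrative sketch above (float32 defaults assumed; _ToyModule is the
# hypothetical module from the previous sketch, not part of this commit): the returned dict
# maps every submodule name to its size in bytes, and the empty-string key "" holds the
# total for the whole model, which is how the tests below index it.
_toy_sizes = compute_module_persistent_sizes(_ToyModule())
# linear.weight: 16 * 4 bytes, linear.bias: 4 * 4 bytes, running_mean: 4 * 4 bytes
# _toy_sizes[""] == 96 and _toy_sizes["linear"] == 80; the non-persistent "scratch" adds nothing.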


class ModelUtilsTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
@@ -1012,9 +1080,7 @@ def test_cpu_offload(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1044,9 +1110,7 @@ def test_disk_offload_without_safetensors(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

@@ -1080,9 +1144,7 @@ def test_disk_offload_with_safetensors(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)

@@ -1110,9 +1172,7 @@ def test_model_parallelism(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1140,9 +1200,7 @@ def test_sharded_checkpoints(self):

base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
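# Worked example with hypothetical numbers (not taken from the test models): a persistent
# size of 200_000 bytes gives max_shard_size = int(150_000 / 1024) = 146, so save_pretrained
# is called with max_shard_size="146KB" and the state dict has to be split across multiple
# shard files.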
@@ -1174,9 +1232,7 @@ def test_sharded_checkpoints_with_variant(self):

base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1216,9 +1272,7 @@ def test_sharded_checkpoints_device_map(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1247,9 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self):
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()

-model_size = compute_module_sizes(model)[""]
-buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-model_size = model_size - buffer_size
+model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
