From 574fe743bd9a332e2f16cd7e1981171864dc3c6e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Dec 2024 02:05:45 +0100 Subject: [PATCH] add compute_module_persistent_sizes --- tests/models/test_modeling_common.py | 104 ++++++++++++++++++++------- 1 file changed, 78 insertions(+), 26 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index f947fbb9b1b4..4fc14804475a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -22,12 +22,14 @@ import unittest import unittest.mock as mock import uuid -from typing import Dict, List, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union import numpy as np import requests_mock import torch -from accelerate.utils import compute_module_sizes +import torch.nn as nn +from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size from huggingface_hub import ModelCard, delete_repo, snapshot_download from huggingface_hub.utils import is_jinja_available from parameterized import parameterized @@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): out_queue.join() +def named_persistent_module_tensors( + module: nn.Module, + recurse: bool = False, +): + """ + A helper function that gathers all the tensors (parameters + persistent buffers) of a given module. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + """ + yield from module.named_parameters(recurse=recurse) + + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + # Get parent by splitting on dots and traversing the model + parent = module + if "." in name: + parent_name = name.rsplit(".", 1)[0] + for part in parent_name.split("."): + parent = getattr(parent, part) + name = name.split(".")[-1] + if name not in parent._non_persistent_buffers_set: + yield named_buffer + + +def compute_module_persistent_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the size of each submodule of a given model (parameters + persistent buffers). + """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + module_list = named_persistent_module_tensors(model, recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1012,9 +1080,7 @@ def test_cpu_offload(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1044,9 +1110,7 @@ def test_disk_offload_without_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, safe_serialization=False) @@ -1080,9 +1144,7 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) @@ -1110,9 +1172,7 @@ def test_model_parallelism(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1140,9 +1200,7 @@ def test_sharded_checkpoints(self): base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") @@ -1174,9 +1232,7 @@ def test_sharded_checkpoints_with_variant(self): base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. variant = "fp16" with tempfile.TemporaryDirectory() as tmp_dir: @@ -1216,9 +1272,7 @@ def test_sharded_checkpoints_device_map(self): torch.manual_seed(0) base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") @@ -1247,9 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self): config, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() - model_size = compute_module_sizes(model)[""] - buffer_size = compute_module_sizes(model, buffers_only=True)[""] - model_size = model_size - buffer_size + model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. variant = "fp16" with tempfile.TemporaryDirectory() as tmp_dir: