From ffc0eaab6d8ae7176a34ebfff3f225c2e37ba187 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 23 Dec 2024 11:03:04 +0530
Subject: [PATCH] Bump minimum TorchAO version to 0.7.0 (#10293)

* bump min torchao version to 0.7.0

* update
---
 .../quantizers/torchao/torchao_quantizer.py |  5 +
 src/diffusers/utils/testing_utils.py        |  4 +-
 tests/quantization/torchao/test_torchao.py  | 94 +++++++++----------
 3 files changed, 52 insertions(+), 51 deletions(-)

diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 8b28a403e6f0..25cd4ad448e7 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -93,6 +93,11 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
             )
+        torchao_version = version.parse(importlib.metadata.version("torchao"))
+        if torchao_version < version.parse("0.7.0"):
+            raise RuntimeError(
+                f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+            )

         self.offload = False

diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 3448b4d28d1f..3ae74cddcbbf 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -490,11 +490,11 @@ def decorator(test_case):
     return decorator


-def require_torchao_version_greater(torchao_version):
+def require_torchao_version_greater_or_equal(torchao_version):
     def decorator(test_case):
         correct_torchao_version = is_torchao_available() and version.parse(
             version.parse(importlib.metadata.version("torchao")).base_version
-        ) > version.parse(torchao_version)
+        ) >= version.parse(torchao_version)
         return unittest.skipUnless(
             correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
         )(test_case)

diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 6f9980c006ac..418fc997a215 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -36,7 +36,7 @@
     nightly,
     require_torch,
     require_torch_gpu,
-    require_torchao_version_greater,
+    require_torchao_version_greater_or_equal,
     slow,
     torch_device,
 )
@@ -74,13 +74,13 @@ def forward(self, input, *args, **kwargs):

 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
-    from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.utils import get_model_size_in_bytes


 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -125,7 +125,7 @@ def test_repr(self):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -139,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -212,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
     def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)

         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
@@ -276,7 +278,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
-        self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

     def test_device_map(self):
         """
@@ -341,21 +342,33 @@ def test_device_map(self):

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
-        quantized_model = FluxTransformer2DModel.from_pretrained(
+        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
             "hf-internal-testing/tiny-flux-pipe",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )

-        unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
+        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
         self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
         self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
         self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)

-        quantized_layer = quantized_model.proj_out
+        quantized_layer = quantized_model_with_not_convert.proj_out
         self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
-        self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
+
+        quantization_config = TorchAoConfig("int8_weight_only")
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+
+        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+        size_quantized = get_model_size_in_bytes(quantized_model)
+
+        self.assertTrue(size_quantized < size_quantized_with_not_convert)

     def test_training(self):
         quantization_config = TorchAoConfig("int8_weight_only")
@@ -406,23 +419,6 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

-    @staticmethod
-    def _get_memory_footprint(module):
-        quantized_param_memory = 0.0
-        unquantized_param_memory = 0.0
-
-        for param in module.parameters():
-            if param.__class__.__name__ == "AffineQuantizedTensor":
-                data, scale, zero_point = param.layout_tensor.get_plain()
-                quantized_param_memory += data.numel() + data.element_size()
-                quantized_param_memory += scale.numel() + scale.element_size()
-                quantized_param_memory += zero_point.numel() + zero_point.element_size()
-            else:
-                unquantized_param_memory += param.data.numel() * param.data.element_size()
-
-        total_memory = quantized_param_memory + unquantized_param_memory
-        return total_memory, quantized_param_memory, unquantized_param_memory
-
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the
@@ -433,20 +429,18 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
         transformer_bf16 = self.get_dummy_components(None)["transformer"]

-        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
-        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
-            transformer_int4wo_gs32
-        )
-        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
-        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
-
-        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
-        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
-        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
-        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+        total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+        # Latter has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(quantized_int8wo < quantized_int4wo)
+        self.assertTrue(total_int8wo < total_int4wo)
+        # int4wo does not quantize too many layers because of default group size, but for the layers it does
+        # there is additional overhead of scales and zero points
+        self.assertTrue(total_bf16 < total_int4wo)

     def test_wrong_config(self):
         with self.assertRaises(ValueError):
@@ -456,7 +450,7 @@ def test_wrong_config(self):
 # This class is not to be run as a test by itself. See the tests that follow this class
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
     quant_method, quant_method_kwargs = None, None
@@ -565,7 +559,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
@@ -581,11 +575,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -617,7 +613,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):

     def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
-        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()

         inputs = self.get_dummy_inputs(torch_device)
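
Note: the runtime guard added to `validate_environment` above can be exercised on its own. Below is a minimal standalone sketch of the same check, assuming `torchao` is installed; it uses only `importlib.metadata` and `packaging.version`, the same modules the quantizer already relies on.

    import importlib.metadata

    from packaging import version

    # Parse the installed torchao version (not torch) and compare it against the new 0.7.0 floor.
    torchao_version = version.parse(importlib.metadata.version("torchao"))
    if torchao_version < version.parse("0.7.0"):
        raise RuntimeError(
            f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. "
            "Please upgrade with `pip install -U torchao`."
        )
    print(f"torchao {torchao_version} satisfies the >= 0.7.0 requirement")

With torchao 0.7.0 or newer installed this prints the success line; with 0.6.x it raises the same RuntimeError the quantizer now raises at model load time, mirroring how the updated `require_torchao_version_greater_or_equal("0.7.0")` decorator gates the tests.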