diff --git a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py index d466996c6..c4862e5cb 100644 --- a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py @@ -47,13 +47,15 @@ def compute_kpi_data(in_model: Any, """ + # We assume that the kpi_data API is used to compute the model KPI for a mixed precision scenario, + # so we run graph preparation with mixed precision enabled. transformed_graph = graph_preparation_runner(in_model, representative_data_gen, core_config.quantization_config, fw_info, fw_impl, tpc, - mixed_precision_enable=core_config.mixed_precision_enable) + mixed_precision_enable=True) # Compute parameters sum weights_params = compute_nodes_weights_params(graph=transformed_graph, fw_info=fw_info) diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py index c44c5a570..6b47b94ec 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py @@ -16,13 +16,11 @@ from typing import List, Callable from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting -from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI class MixedPrecisionQuantizationConfig: def __init__(self, - target_kpi: KPI = None, compute_distance_fn: Callable = None, distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG, num_of_images: int = 32, @@ -36,7 +34,6 @@ def __init__(self, Class with mixed precision parameters to quantize the input model. Args: - target_kpi (KPI): KPI to constraint the search of the mixed-precision configuration for the model. compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer. distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric. num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model. @@ -49,7 +46,6 @@ def __init__(self, """ - self.target_kpi = target_kpi self.compute_distance_fn = compute_distance_fn self.distance_weighting_method = distance_weighting_method self.num_of_images = num_of_images @@ -67,13 +63,21 @@ def __init__(self, self.metric_normalization_threshold = metric_normalization_threshold - def set_target_kpi(self, target_kpi: KPI): + self._mixed_precision_enable = False + + def set_mixed_precision_enable(self): + """ + Set a flag in mixed precision config indicating that mixed precision is enabled. """ - Setting target KPI in mixed precision config. - Args: - target_kpi: A target KPI to set. + self._mixed_precision_enable = True + @property + def mixed_precision_enable(self): """ + A property that indicates whether mixed precision quantization is enabled.
- self.target_kpi = target_kpi + Returns: True if mixed precision quantization is enabled + + """ + return self._mixed_precision_enable diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py index 8546bd0d3..c21b08c6c 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py @@ -47,6 +47,7 @@ class BitWidthSearchMethod(Enum): def search_bit_width(graph_to_search_cfg: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, + target_kpi: KPI, mp_config: MixedPrecisionQuantizationConfig, representative_data_gen: Callable, search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING, @@ -63,6 +64,7 @@ def search_bit_width(graph_to_search_cfg: Graph, graph_to_search_cfg: Graph to search a MP configuration for. fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). fw_impl: FrameworkImplementation object with specific framework methods implementation. + target_kpi: Target KPI to bound the feasible solution space, such that the chosen configuration does not violate it. mp_config: Mixed-precision quantization configuration. representative_data_gen: Dataset to use for retrieving images for the models inputs. search_method: BitWidthSearchMethod to define which searching method to use. @@ -74,7 +76,6 @@ def search_bit_width(graph_to_search_cfg: Graph, bit-width index on the node). """ - target_kpi = mp_config.target_kpi # target_kpi have to be passed. If it was not passed, the facade is not supposed to get here by now. if target_kpi is None: diff --git a/model_compression_toolkit/core/common/quantization/core_config.py b/model_compression_toolkit/core/common/quantization/core_config.py index aaee6db80..1259a605f 100644 --- a/model_compression_toolkit/core/common/quantization/core_config.py +++ b/model_compression_toolkit/core/common/quantization/core_config.py @@ -30,14 +30,19 @@ def __init__(self, Args: quantization_config (QuantizationConfig): Config for quantization. - mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization (optional, default=None). + mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization. + If None, a default MixedPrecisionQuantizationConfig is used. debug_config (DebugConfig): Config for debugging and editing the network quantization process.
""" self.quantization_config = quantization_config - self.mixed_precision_config = mixed_precision_config self.debug_config = debug_config + if mixed_precision_config is None: + self.mixed_precision_config = MixedPrecisionQuantizationConfig() + else: + self.mixed_precision_config = mixed_precision_config + @property def mixed_precision_enable(self): - return self.mixed_precision_config is not None + return self.mixed_precision_config is not None and self.mixed_precision_config.mixed_precision_enable diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index 41e141f76..0df28ed74 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -71,7 +71,7 @@ def set_quantization_configs_to_node(node: BaseNode, quant_config: Quantization configuration to generate the node's configurations from. fw_info: Information needed for quantization about the specific framework. tpc: TargetPlatformCapabilities to get default OpQuantizationConfig. - mixed_precision_enable: is mixed precision enabled + mixed_precision_enable: is mixed precision enabled. """ node_qc_options = node.get_qco(tpc) diff --git a/model_compression_toolkit/core/graph_prep_runner.py b/model_compression_toolkit/core/graph_prep_runner.py index 397405889..28b1b6a78 100644 --- a/model_compression_toolkit/core/graph_prep_runner.py +++ b/model_compression_toolkit/core/graph_prep_runner.py @@ -57,7 +57,8 @@ def graph_preparation_runner(in_model: Any, fw_impl: FrameworkImplementation object with a specific framework methods implementation. tpc: TargetPlatformCapabilities object that models the inference target platform and the attached framework operator's information. - tb_w: TensorboardWriter object for logging + tb_w: TensorboardWriter object for logging. + mixed_precision_enable: is mixed precision enabled. Returns: An internal graph representation of the input model. @@ -103,7 +104,7 @@ def get_finalized_graph(initial_graph: Graph, kernel channels indices, groups of layers by how they should be quantized, etc.) tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation. - mixed_precision_enable: is mixed precision enabled. + mixed_precision_enable: is mixed precision enabled. Returns: Graph object that represents the model, after applying all required modifications to it. """ diff --git a/model_compression_toolkit/core/pytorch/kpi_data_facade.py b/model_compression_toolkit/core/pytorch/kpi_data_facade.py index faa99ff47..06e10e899 100644 --- a/model_compression_toolkit/core/pytorch/kpi_data_facade.py +++ b/model_compression_toolkit/core/pytorch/kpi_data_facade.py @@ -38,7 +38,7 @@ def pytorch_kpi_data(in_model: Module, representative_data_gen: Callable, - core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()), + core_config: CoreConfig = CoreConfig(), target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI: """ Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization. 
diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 95285a029..94b6d1aec 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -47,6 +47,7 @@ def core_runner(in_model: Any, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, tpc: TargetPlatformCapabilities, + target_kpi: KPI = None, tb_w: TensorboardWriter = None): """ Quantize a trained model using post-training quantization. @@ -66,6 +67,7 @@ def core_runner(in_model: Any, fw_impl: FrameworkImplementation object with a specific framework methods implementation. tpc: TargetPlatformCapabilities object that models the inference target platform and the attached framework operator's information. + target_kpi: KPI to constrain the search of the mixed-precision configuration for the model. tb_w: TensorboardWriter object for logging Returns: @@ -81,6 +83,13 @@ def core_runner(in_model: Any, Logger.warning('representative_data_gen generates a batch size of 1 which can be slow for optimization:' ' consider increasing the batch size') + # Checking whether to run mixed precision quantization + if target_kpi is not None: + if core_config.mixed_precision_config is None: + Logger.critical("A target_kpi was provided, which means that mixed precision quantization is " + "enabled, but the provided MixedPrecisionQuantizationConfig is None.") + core_config.mixed_precision_config.set_mixed_precision_enable() + graph = graph_preparation_runner(in_model, representative_data_gen, core_config.quantization_config, @@ -105,13 +114,12 @@ def core_runner(in_model: Any, # Finalize bit widths ###################################### if core_config.mixed_precision_enable: - if core_config.mixed_precision_config.target_kpi is None: - Logger.critical(f"Trying to run Mixed Precision quantization without providing a valid target KPI.") if core_config.mixed_precision_config.configuration_overwrite is None: bit_widths_config = search_bit_width(tg, fw_info, fw_impl, + target_kpi, core_config.mixed_precision_config, representative_data_gen, hessian_info_service=hessian_info_service) diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index 6c9ed6dac..d35d99e14 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -116,6 +116,7 @@ def get_keras_gptq_config(n_epochs: int, def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable, gptq_config: GradientPTQConfig, gptq_representative_data_gen: Callable = None, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]: """ @@ -139,6 +140,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da representative_data_gen (Callable): Dataset used for calibration. gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer). gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. @@ -166,6 +168,12 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da >>> config = mct.core.CoreConfig() + If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model + with different bitwidths for different layers. + The candidate bitwidths for quantization should be defined in the target platform model: + + >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)) + For mixed-precision set a target KPI object: Create a KPI object to limit our returned model's size. Note that this value affects only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, @@ -173,19 +181,13 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits. - If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model - with different bitwidths for different layers. - The candidates bitwidth for quantization should be defined in the target platform model: - - >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, target_kpi=kpi)) - Create GPTQ config: >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1) Pass the model with the representative dataset generator to get a quantized model: - >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, core_config=config) + >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config) """ KerasModelValidation(model=in_model, diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 0c39b57f1..2c1692a51 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -94,6 +94,7 @@ def get_pytorch_gptq_config(n_epochs: int, def pytorch_gradient_post_training_quantization(model: Module, representative_data_gen: Callable, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), gptq_config: GradientPTQConfig = None, gptq_representative_data_gen: Callable = None, @@ -117,6 +118,7 @@ def pytorch_gradient_post_training_quantization(model: Module, Args: model (Module): Pytorch model to quantize. representative_data_gen (Callable): Dataset used for calibration. + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer). gptq_representative_data_gen (Callable): Dataset used for GPTQ training.
If None defaults to representative_data_gen @@ -174,6 +176,7 @@ def pytorch_gradient_post_training_quantization(model: Module, fw_info=DEFAULT_PYTORCH_INFO, fw_impl=fw_impl, tpc=target_platform_capabilities, + target_kpi=target_kpi, tb_w=tb_w) # ---------------------- # diff --git a/model_compression_toolkit/ptq/keras/quantization_facade.py b/model_compression_toolkit/ptq/keras/quantization_facade.py index db0d7d0ac..33dce3e7c 100644 --- a/model_compression_toolkit/ptq/keras/quantization_facade.py +++ b/model_compression_toolkit/ptq/keras/quantization_facade.py @@ -42,6 +42,7 @@ def keras_post_training_quantization(in_model: Model, representative_data_gen: Callable, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC): """ @@ -60,6 +61,7 @@ def keras_post_training_quantization(in_model: Model, Args: in_model (Model): Keras model to quantize. representative_data_gen (Callable): Dataset used for calibration. + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. @@ -135,6 +137,7 @@ def keras_post_training_quantization(in_model: Model, fw_info=fw_info, fw_impl=fw_impl, tpc=target_platform_capabilities, + target_kpi=target_kpi, tb_w=tb_w) tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w) diff --git a/model_compression_toolkit/ptq/pytorch/quantization_facade.py b/model_compression_toolkit/ptq/pytorch/quantization_facade.py index 436221da4..1fbcab16e 100644 --- a/model_compression_toolkit/ptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/ptq/pytorch/quantization_facade.py @@ -41,6 +41,7 @@ def pytorch_post_training_quantization(in_module: Module, representative_data_gen: Callable, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC): """ @@ -59,6 +60,7 @@ def pytorch_post_training_quantization(in_module: Module, Args: in_module (Module): Pytorch module to quantize. representative_data_gen (Callable): Dataset used for calibration. + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. 
@@ -107,6 +109,7 @@ def pytorch_post_training_quantization(in_module: Module, fw_info=DEFAULT_PYTORCH_INFO, fw_impl=fw_impl, tpc=target_platform_capabilities, + target_kpi=target_kpi, tb_w=tb_w) tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w) diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py index 0e10c93d3..f00b2374d 100644 --- a/model_compression_toolkit/qat/keras/quantization_facade.py +++ b/model_compression_toolkit/qat/keras/quantization_facade.py @@ -87,6 +87,7 @@ def qat_wrapper(n: common.BaseNode, def keras_quantization_aware_training_init_experimental(in_model: Model, representative_data_gen: Callable, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), qat_config: QATConfig = QATConfig(), target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC): @@ -108,6 +109,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, Args: in_model (Model): Keras model to quantize. representative_data_gen (Callable): Dataset used for initial calibration. + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. qat_config (QATConfig): QAT configuration target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. @@ -157,7 +159,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, Pass the model, the representative dataset generator, the configuration and the target KPI to get a quantized model: - >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=core_config) + >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=config) Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary: @@ -191,6 +193,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, fw_info=DEFAULT_KERAS_INFO, fw_impl=fw_impl, tpc=target_platform_capabilities, + target_kpi=target_kpi, tb_w=tb_w) tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w) diff --git a/model_compression_toolkit/qat/pytorch/quantization_facade.py b/model_compression_toolkit/qat/pytorch/quantization_facade.py index 6f44c919a..c78b541c4 100644 --- a/model_compression_toolkit/qat/pytorch/quantization_facade.py +++ b/model_compression_toolkit/qat/pytorch/quantization_facade.py @@ -75,6 +75,7 @@ def qat_wrapper(n: common.BaseNode, def pytorch_quantization_aware_training_init_experimental(in_model: Module, representative_data_gen: Callable, + target_kpi: KPI = None, core_config: CoreConfig = CoreConfig(), qat_config: QATConfig = QATConfig(), target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC): @@ -96,6 +97,7 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module, Args: in_model (Model): Pytorch model to quantize. representative_data_gen (Callable): Dataset used for initial calibration. + target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired. 
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. qat_config (QATConfig): QAT configuration target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to. @@ -158,6 +160,7 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module, fw_info=DEFAULT_PYTORCH_INFO, fw_impl=fw_impl, tpc=target_platform_capabilities, + target_kpi=target_kpi, tb_w=tb_w) tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w) diff --git a/tests/common_tests/base_feature_test.py b/tests/common_tests/base_feature_test.py index 7ec66ed0d..64fcfe473 100644 --- a/tests/common_tests/base_feature_test.py +++ b/tests/common_tests/base_feature_test.py @@ -43,6 +43,7 @@ def run_test(self): core_config = self.get_core_config() ptq_model, quantization_info = self.get_ptq_facade()(model_float, self.representative_data_gen_experimental, + target_kpi=self.get_kpi(), core_config=core_config, target_platform_capabilities=self.get_tpc() ) diff --git a/tests/common_tests/helpers/prep_graph_for_func_test.py b/tests/common_tests/helpers/prep_graph_for_func_test.py index 338dd9e6a..e7b14c7d4 100644 --- a/tests/common_tests/helpers/prep_graph_for_func_test.py +++ b/tests/common_tests/helpers/prep_graph_for_func_test.py @@ -93,6 +93,7 @@ def prepare_graph_with_quantization_parameters(in_model, def prepare_graph_set_bit_widths(in_model, fw_impl, representative_data_gen, + target_kpi, n_iter, quant_config, fw_info, @@ -107,6 +108,9 @@ def prepare_graph_set_bit_widths(in_model, debug_config=DebugConfig(analyze_similarity=analyze_similarity, network_editor=network_editor)) + if target_kpi is not None: + core_config.mixed_precision_config.set_mixed_precision_enable() + tb_w = init_tensorboard_writer(fw_info) # convert old representative dataset generation to a generator @@ -133,12 +137,13 @@ def _representative_data_gen(): # Finalize bit widths ###################################### if core_config.mixed_precision_enable: - assert core_config.mixed_precision_config.target_kpi is not None + if core_config.mixed_precision_config.configuration_overwrite is None: bit_widths_config = search_bit_width(tg, fw_info, fw_impl, + target_kpi, core_config.mixed_precision_config, _representative_data_gen) else: diff --git a/tests/keras_tests/custom_layers_tests/test_sony_ssd_postprocess_layer.py b/tests/keras_tests/custom_layers_tests/test_sony_ssd_postprocess_layer.py index 6fe8181fc..6e7f8d68d 100644 --- a/tests/keras_tests/custom_layers_tests/test_sony_ssd_postprocess_layer.py +++ b/tests/keras_tests/custom_layers_tests/test_sony_ssd_postprocess_layer.py @@ -49,11 +49,11 @@ def test_custom_layer(self): core_config = mct.core.CoreConfig( mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig( - use_hessian_based_scores=False, - target_kpi=mct.core.KPI(weights_memory=6000))) + use_hessian_based_scores=False)) q_model, _ = mct.ptq.keras_post_training_quantization(model, get_rep_dataset(2, (1, 8, 8, 3)), - core_config=core_config) + core_config=core_config, + target_kpi=mct.core.KPI(weights_memory=6000)) # verify the custom layer is in the quantized model self.assertTrue(isinstance(q_model.layers[-1], SSDPostProcess), 'Custom layer should be in the quantized model') diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py 
index f02bd0be4..73868df77 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py @@ -114,14 +114,18 @@ def representative_data_gen(): ptq_model, quantization_info = mct.ptq.keras_post_training_quantization( model_float, representative_data_gen, + target_kpi=self.get_kpi(), + core_config=core_config, + target_platform_capabilities=tpc + ) + ptq_gptq_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization( + model_float, + representative_data_gen, + gptq_config=self.get_gptq_config(), + target_kpi=self.get_kpi(), core_config=core_config, target_platform_capabilities=tpc ) - ptq_gptq_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model_float, - representative_data_gen, - gptq_config=self.get_gptq_config(), - core_config=core_config, - target_platform_capabilities=tpc) self.compare(ptq_model, ptq_gptq_model, input_x=x, quantization_info=quantization_info) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py index 718e4a405..010c52083 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py @@ -47,8 +47,7 @@ def get_tpc(self): name="mp_bopts_test") def get_mixed_precision_config(self): - return MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi()) + return MixedPrecisionQuantizationConfig(num_of_images=1) def get_input_shapes(self): return [[self.val_batch_size, 16, 16, 3]] diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py index 7dd7710d7..baccc1fdb 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py @@ -65,8 +65,7 @@ def get_quantization_config(self): activation_channel_equalization=False) def get_mixed_precision_config(self): - return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi()) + return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1) def get_input_shapes(self): return [[self.val_batch_size, 16, 16, 3]] @@ -424,8 +423,7 @@ def get_quantization_config(self): input_scaling=False, activation_channel_equalization=False) def get_mixed_precision_config(self): - return mct.core.MixedPrecisionQuantizationConfig(num_of_images=self.num_of_inputs, - target_kpi=self.get_kpi()) + return mct.core.MixedPrecisionQuantizationConfig(num_of_images=self.num_of_inputs) def create_networks(self): inputs_1 = layers.Input(shape=self.get_input_shapes()[0][1:]) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py index 2232b3115..2760607f4 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py @@ -45,6 +45,9 @@ def prepare_graph_for_first_network_editor(in_model, representative_data_gen, core_config, fw_info, fw_impl, tpc, target_kpi=None, tb_w=None): + if target_kpi is 
not None: +        core_config.mixed_precision_config.set_mixed_precision_enable() + +    transformed_graph = graph_preparation_runner(in_model, representative_data_gen, core_config.quantization_config, @@ -91,6 +94,10 @@ def prepare_graph_for_second_network_editor(in_model, representative_data_gen, c tpc=tpc, target_kpi=target_kpi, tb_w=tb_w) + + if target_kpi is not None: + core_config.mixed_precision_config.set_mixed_precision_enable() + ###################################### # Calculate quantization params ###################################### @@ -133,13 +140,14 @@ def prepare_graph_for_second_network_editor(in_model, representative_data_gen, c ###################################### # Finalize bit widths ###################################### - if core_config.mixed_precision_enable: - assert target_kpi is not None + if target_kpi is not None: + assert core_config.mixed_precision_enable if core_config.mixed_precision_config.configuration_overwrite is None: bit_widths_config = search_bit_width(tg_with_bias, fw_info, fw_impl, + target_kpi, core_config.mixed_precision_config, representative_data_gen) else: diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py index be08d4363..de1dfc70a 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py @@ -291,12 +291,12 @@ def __init__(self, unit_test, kpi_weights=np.inf, kpi_activation=np.inf, expecte def run_test(self, **kwargs): model_float = self.create_networks() - config = mct.core.CoreConfig( - mixed_precision_config=MixedPrecisionQuantizationConfig(target_kpi= - mct.core.KPI(weights_memory=self.kpi_weights, - activation_memory=self.kpi_activation))) + config = mct.core.CoreConfig() qat_ready_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental( - model_float, self.representative_data_gen_experimental, core_config=config, + model_float, + self.representative_data_gen_experimental, + mct.core.KPI(weights_memory=self.kpi_weights, activation_memory=self.kpi_activation), + core_config=config, target_platform_capabilities=self.get_tpc()) self.compare(qat_ready_model, quantization_info) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py index c2a1d13b2..c2b2db035 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py @@ -52,7 +52,7 @@ def get_quantization_config(self): input_scaling=True, activation_channel_equalization=True) def get_mixed_precision_config(self): - return MixedPrecisionQuantizationConfig(target_kpi=self.get_kpi()) + return MixedPrecisionQuantizationConfig() def create_networks(self): layer = layers.Conv2D(3, 4) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index f255d1e30..0d2033115 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -45,8 +45,7 @@ def
get_quantization_config(self): input_scaling=True, activation_channel_equalization=True) def get_mixed_precision_config(self): - return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi()) + return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1) def get_input_shapes(self): return [[self.val_batch_size, 16, 16, 3]] @@ -109,8 +108,7 @@ def get_kpi(self): def get_mixed_precision_config(self): return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - distance_weighting_method=self.distance_metric, - target_kpi=self.get_kpi()) + distance_weighting_method=self.distance_metric) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) @@ -234,7 +232,6 @@ def __init__(self, unit_test): def get_mixed_precision_config(self): return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi(), use_hessian_based_scores=False) def get_kpi(self): @@ -375,7 +372,7 @@ def get_quantization_config(self): input_scaling=False, activation_channel_equalization=False) def get_mixed_precision_config(self): - return mct.core.MixedPrecisionQuantizationConfig(target_kpi=self.get_kpi()) + return mct.core.MixedPrecisionQuantizationConfig() class MixedPrecisionActivationDisabled(MixedPercisionBaseTest): @@ -391,8 +388,7 @@ def get_quantization_config(self): activation_channel_equalization=False) def get_mixed_precision_config(self): - return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi()) + return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1) def get_tpc(self): base_config, _, default_config = get_op_quantization_configs() @@ -425,7 +421,6 @@ def __init__(self, unit_test): def get_mixed_precision_config(self): return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi(), distance_weighting_method=MpDistanceWeighting.LAST_LAYER, use_hessian_based_scores=False) diff --git a/tests/keras_tests/function_tests/test_kpi_data.py b/tests/keras_tests/function_tests/test_kpi_data.py index 25ec1a132..83f364130 100644 --- a/tests/keras_tests/function_tests/test_kpi_data.py +++ b/tests/keras_tests/function_tests/test_kpi_data.py @@ -93,9 +93,9 @@ def prep_test(model, mp_bitwidth_candidates_list, random_datagen): mp_bitwidth_candidates_list=mp_bitwidth_candidates_list, name="kpi_data_test") - kpi_data = mct.core.keras_kpi_data(in_model=model, representative_data_gen=random_datagen, - core_config=mct.core.CoreConfig( - mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig()), + kpi_data = mct.core.keras_kpi_data(in_model=model, + representative_data_gen=random_datagen, + core_config=mct.core.CoreConfig(), target_platform_capabilities=tpc) return kpi_data diff --git a/tests/keras_tests/models_tests/test_networks_runner.py b/tests/keras_tests/models_tests/test_networks_runner.py index e0c45b084..5c3e5dc07 100644 --- a/tests/keras_tests/models_tests/test_networks_runner.py +++ b/tests/keras_tests/models_tests/test_networks_runner.py @@ -107,11 +107,12 @@ def representative_data_gen(): learning_rate=0.0001), optimizer_rest=tf.keras.optimizers.Adam( learning_rate=0.0001), loss=multiple_tensors_mse_loss) - ptq_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(self.model_float, - representative_data_gen, - gptq_config=arc, - core_config=core_config, - target_platform_capabilities=tpc) + ptq_model, quantization_info = 
mct.gptq.keras_gradient_post_training_quantization( + self.model_float, + representative_data_gen, + core_config=core_config, + gptq_config=arc, + target_platform_capabilities=tpc) else: ptq_model, quantization_info = mct.ptq.keras_post_training_quantization(self.model_float, representative_data_gen, diff --git a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py index 2f9371a5e..1baede0e1 100644 --- a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py +++ b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py @@ -256,11 +256,11 @@ def rep_data(): rep_data, target_platform_capabilities=tpc) core_config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=2, - use_hessian_based_scores=False, - target_kpi=mct.core.KPI(np.inf))) + use_hessian_based_scores=False)) quantized_model, _ = mct.ptq.keras_post_training_quantization(model, rep_data, core_config=core_config, + target_kpi=mct.core.KPI(np.inf), target_platform_capabilities=tpc) def test_get_keras_supported_version(self): diff --git a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py index f2d3929ad..918c009f0 100644 --- a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py +++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py @@ -227,7 +227,7 @@ def dummy_representative_dataset(): graph.set_tpc(tpc) graph = set_quantization_configuration_to_graph(graph=graph, quant_config=core_config.quantization_config, - mixed_precision_enable=core_config.mixed_precision_enable) + mixed_precision_enable=True) for node in graph.nodes: node.prior_info = keras_impl.get_node_prior_info(node=node, @@ -255,6 +255,7 @@ def representative_data_gen(): cfg = search_bit_width(graph_to_search_cfg=graph, fw_info=DEFAULT_KERAS_INFO, fw_impl=keras_impl, + target_kpi=KPI(np.inf), mp_config=core_config.mixed_precision_config, representative_data_gen=representative_data_gen, search_method=BitWidthSearchMethod.INTEGER_PROGRAMMING) @@ -263,23 +264,23 @@ def representative_data_gen(): cfg = search_bit_width(graph_to_search_cfg=graph, fw_info=DEFAULT_KERAS_INFO, fw_impl=keras_impl, + target_kpi=KPI(np.inf), mp_config=core_config.mixed_precision_config, representative_data_gen=representative_data_gen, search_method=None) - core_config.mixed_precision_config.target_kpi = None with self.assertRaises(Exception): cfg = search_bit_width(graph_to_search_cfg=graph, fw_info=DEFAULT_KERAS_INFO, fw_impl=keras_impl, + target_kpi=None, mp_config=core_config.mixed_precision_config, representative_data_gen=representative_data_gen, search_method=BitWidthSearchMethod.INTEGER_PROGRAMMING) def test_mixed_precision_search_facade(self): core_config_avg_weights = CoreConfig(quantization_config=DEFAULTCONFIG, - mixed_precision_config=MixedPrecisionQuantizationConfig(KPI(np.inf), - compute_mse, + mixed_precision_config=MixedPrecisionQuantizationConfig(compute_mse, MpDistanceWeighting.AVG, num_of_images=1, use_hessian_based_scores=False)) @@ -287,8 +288,7 @@ def test_mixed_precision_search_facade(self): self.run_search_bitwidth_config_test(core_config_avg_weights) core_config_last_layer = CoreConfig(quantization_config=DEFAULTCONFIG, - mixed_precision_config=MixedPrecisionQuantizationConfig(KPI(np.inf), - compute_mse, + mixed_precision_config=MixedPrecisionQuantizationConfig(compute_mse, MpDistanceWeighting.LAST_LAYER, num_of_images=1, 
use_hessian_based_scores=False)) diff --git a/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py index 143d78776..9a4f132d3 100644 --- a/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py +++ b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py @@ -110,8 +110,7 @@ def plot_tensor_sizes(self): cfg = mct.core.DEFAULTCONFIG mp_cfg = mct.core.MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse, distance_weighting_method=MpDistanceWeighting.AVG, - use_hessian_based_scores=False, - target_kpi=mct.core.KPI(np.inf)) + use_hessian_based_scores=False) # compare max tensor size with plotted max tensor size tg = prepare_graph_set_bit_widths(in_model=model, @@ -121,6 +120,7 @@ def plot_tensor_sizes(self): tpc=tpc, network_editor=[], quant_config=cfg, + target_kpi=mct.core.KPI(), n_iter=1, analyze_similarity=True, mp_cfg=mp_cfg) @@ -145,11 +145,11 @@ def rep_data(): yield [np.random.randn(1, 8, 8, 3)] mp_qc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - use_hessian_based_scores=False, - target_kpi=mct.core.KPI(np.inf)) + use_hessian_based_scores=False) core_config = mct.core.CoreConfig(mixed_precision_config=mp_qc) quantized_model, _ = mct.ptq.keras_post_training_quantization(self.model, rep_data, + target_kpi=mct.core.KPI(np.inf), core_config=core_config, target_platform_capabilities=tpc) @@ -162,6 +162,7 @@ def rep_data(): self.model = MultipleOutputsNet() quantized_model, _ = mct.ptq.keras_post_training_quantization(self.model, rep_data, + target_kpi=mct.core.KPI(np.inf), core_config=core_config, target_platform_capabilities=tpc) diff --git a/tests/pytorch_tests/function_tests/kpi_data_test.py b/tests/pytorch_tests/function_tests/kpi_data_test.py index 75475a308..1130e4bb0 100644 --- a/tests/pytorch_tests/function_tests/kpi_data_test.py +++ b/tests/pytorch_tests/function_tests/kpi_data_test.py @@ -112,9 +112,9 @@ def prep_test(model, mp_bitwidth_candidates_list, random_datagen): test_name='kpi_data_test', tpc_name='kpi_data_test') - kpi_data = mct.core.pytorch_kpi_data(in_model=model, representative_data_gen=random_datagen, - core_config=mct.core.CoreConfig( - mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig()), + kpi_data = mct.core.pytorch_kpi_data(in_model=model, + representative_data_gen=random_datagen, + core_config=mct.core.CoreConfig(), target_platform_capabilities=tpc_dict['kpi_data_test']) return kpi_data diff --git a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py index 77defb655..2854cb94c 100644 --- a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py +++ b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py @@ -240,12 +240,13 @@ def rep_data(): rep_data, target_platform_capabilities=tpc) - mp_qc = MixedPrecisionQuantizationConfig(target_kpi=mct.core.KPI(np.inf)) + mp_qc = MixedPrecisionQuantizationConfig() mp_qc.num_of_images = 1 core_config = mct.core.CoreConfig(quantization_config=mct.core.QuantizationConfig(), mixed_precision_config=mp_qc) quantized_model, _ = mct.ptq.pytorch_post_training_quantization(model, rep_data, + target_kpi=mct.core.KPI(np.inf), target_platform_capabilities=tpc, core_config=core_config) diff --git a/tests/pytorch_tests/model_tests/base_pytorch_test.py b/tests/pytorch_tests/model_tests/base_pytorch_test.py index 35fd7a2ae..c1ebb5a22 100644 --- a/tests/pytorch_tests/model_tests/base_pytorch_test.py +++ 
b/tests/pytorch_tests/model_tests/base_pytorch_test.py @@ -67,7 +67,8 @@ def get_core_configs(self): base_quant_config = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING, mct.core.QuantizationErrorMethod.NOCLIPPING, False, True) base_core_config = mct.core.CoreConfig(quantization_config=base_quant_config, - debug_config=self.get_debug_config()) + mixed_precision_config=self.get_mixed_precision_config(), + debug_config=self.get_debug_config()) return { 'no_quantization': base_core_config, 'all_32bit': base_core_config, @@ -145,6 +146,7 @@ def representative_data_gen_experimental(): ptq_model, quantization_info = mct.ptq.pytorch_post_training_quantization(in_module=model_float, representative_data_gen=representative_data_gen_experimental, + target_kpi=self.get_kpi(), core_config=core_config, target_platform_capabilities=tpc) diff --git a/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py b/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py index d95292a41..25f03b963 100644 --- a/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py @@ -120,6 +120,7 @@ def representative_data_gen_experimental(): ptq_model, quantization_info = mct.ptq.pytorch_post_training_quantization( in_module=model_float, representative_data_gen=representative_data_gen_experimental, + target_kpi=self.get_kpi(), core_config=core_config, target_platform_capabilities=tpc ) diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py index ef7dd51cc..41e76a7d2 100644 --- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py @@ -49,8 +49,7 @@ def get_core_configs(self): qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=False, weights_bias_correction=True, input_scaling=False, activation_channel_equalization=False) - mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - target_kpi=self.get_kpi()) + mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1) return {"mixed_precision_activation_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)} @@ -128,8 +127,7 @@ def get_kpi(self): return KPI(np.inf, np.inf) def get_mixed_precision_config(self): - return MixedPrecisionQuantizationConfig(num_of_images=4, - target_kpi=self.get_kpi()) + return MixedPrecisionQuantizationConfig(num_of_images=4) def create_feature_network(self, input_shape): return MixedPrecisionMultipleInputsNet(input_shape) diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_bops_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_bops_test.py index 4807051c8..b02ed1ad0 100644 --- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_bops_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_bops_test.py @@ -117,7 +117,7 @@ def get_core_configs(self): qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=False, weights_bias_correction=True, input_scaling=False, activation_channel_equalization=False) - mpc = MixedPrecisionQuantizationConfig(num_of_images=1, 
target_kpi=self.get_kpi()) + mpc = MixedPrecisionQuantizationConfig(num_of_images=1) return {"mixed_precision_bops_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)} diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py index dd5c8089f..0e6ebf2d6 100644 --- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py @@ -48,7 +48,7 @@ def get_core_configs(self): qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=False, weights_bias_correction=True, input_scaling=False, activation_channel_equalization=False) - mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, target_kpi=self.get_kpi()) + mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1) return {"mixed_precision_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)} @@ -92,8 +92,7 @@ def get_core_configs(self): relu_bound_to_power_of_2=False, weights_bias_correction=True, input_scaling=False, activation_channel_equalization=False) mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, - distance_weighting_method=self.distance_metric, - target_kpi=self.get_kpi()) + distance_weighting_method=self.distance_metric) return {"mixed_precision_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)} diff --git a/tests/pytorch_tests/model_tests/feature_models/qat_test.py b/tests/pytorch_tests/model_tests/feature_models/qat_test.py index 7d96db950..fb6fdcc92 100644 --- a/tests/pytorch_tests/model_tests/feature_models/qat_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/qat_test.py @@ -261,10 +261,11 @@ def get_tpc(self): def run_test(self): self._gen_fixed_input() model_float = self.create_networks() + config = mct.core.CoreConfig() kpi = mct.core.KPI() # inf memory - config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig(target_kpi=kpi)) qat_ready_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model_float, self.representative_data_gen_experimental, + kpi, core_config=config, target_platform_capabilities=self.get_tpc()) @@ -305,10 +306,11 @@ def get_tpc(self): def run_test(self): self._gen_fixed_input() model_float = self.create_networks() + config = mct.core.CoreConfig() kpi = mct.core.KPI(weights_memory=50, activation_memory=40) - config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig(target_kpi=kpi)) qat_ready_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model_float, self.representative_data_gen_experimental, + kpi, core_config=config, target_platform_capabilities=self.get_tpc()) diff --git a/tutorials/notebooks/keras/gptq/example_keras_mobilenet_gptq_mixed_precision.py b/tutorials/notebooks/keras/gptq/example_keras_mobilenet_gptq_mixed_precision.py index bb50c0f94..cf0521225 100644 --- a/tutorials/notebooks/keras/gptq/example_keras_mobilenet_gptq_mixed_precision.py +++ b/tutorials/notebooks/keras/gptq/example_keras_mobilenet_gptq_mixed_precision.py @@ -135,7 +135,6 @@ def representative_data_gen() -> list: # examples: # weights_compression_ratio = 0.75 - About 0.75 of the model's weights memory size when quantized with 8 bits. 
kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio) - config.mixed_precision_config.set_target_kpi(kpi) # Create a GPTQ quantization configuration and set the number of training iterations. gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=args.num_gptq_training_iterations, @@ -145,7 +144,8 @@ def representative_data_gen() -> list: representative_data_gen, gptq_config=gptq_config, core_config=config, - target_platform_capabilities=target_platform_cap) + target_platform_capabilities=target_platform_cap, + target_kpi=kpi) # Export quantized model to TFLite and Keras. # For more details please see: https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/exporter/README.md diff --git a/tutorials/notebooks/keras/ptq/example_keras_effdet_lite0.ipynb b/tutorials/notebooks/keras/ptq/example_keras_effdet_lite0.ipynb index ce9a52a6d..20b269a39 100644 --- a/tutorials/notebooks/keras/ptq/example_keras_effdet_lite0.ipynb +++ b/tutorials/notebooks/keras/ptq/example_keras_effdet_lite0.ipynb @@ -387,12 +387,12 @@ "# set weights memory size, so the quantized model will fit the IMX500 memory\n", "kpi = mct.core.KPI(weights_memory=2674291)\n", "# set MixedPrecision configuration for compressing the weights\n", - "mp_config = mct.core.MixedPrecisionQuantizationConfig(use_hessian_based_scores=False,\n", - " target_kpi=kpi)\n", + "mp_config = mct.core.MixedPrecisionQuantizationConfig(use_hessian_based_scores=False)\n", "core_config = mct.core.CoreConfig(mixed_precision_config=mp_config)\n", "quant_model, _ = mct.ptq.keras_post_training_quantization(\n", " model,\n", " get_representative_dataset(20),\n", + " target_kpi=kpi,\n", " core_config=core_config,\n", " target_platform_capabilities=tpc)" ], diff --git a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.ipynb b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.ipynb index d5a0854b8..3e67bc606 100644 --- a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.ipynb +++ b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.ipynb @@ -266,8 +266,7 @@ "# while the bias will not)\n", "# examples:\n", "weights_compression_ratio = 0.75 # About 0.75 of the model's weights memory size when quantized with 8 bits.\n", - "kpi = mct.core.KPI(kpi_data.weights_memory * weights_compression_ratio)\n", - "core_config.mixed_precision_config.set_target_kpi(kpi)" + "kpi = mct.core.KPI(kpi_data.weights_memory * weights_compression_ratio)" ], "metadata": { "collapsed": false @@ -297,6 +296,7 @@ "quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(\n", " float_model,\n", " representative_dataset_gen,\n", + " target_kpi=kpi,\n", " core_config=core_config,\n", " target_platform_capabilities=tpc)" ] diff --git a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.py b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.py index 7df6e0ee0..934cfb9ce 100644 --- a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.py +++ b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision.py @@ -128,12 +128,12 @@ def representative_data_gen() -> list: # examples: # weights_compression_ratio = 0.75 - About 0.75 of the model's weights memory size when quantized with 8 bits. 
     kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio)
-    configuration.mixed_precision_config.set_target_kpi(kpi)
     # It is also possible to constrain only part of the KPI metric, e.g., by providing only a weights_memory target
     # in the passed KPI object, e.g., kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)
     quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model,
                                                                                   representative_data_gen,
+                                                                                  target_kpi=kpi,
                                                                                   core_config=configuration,
                                                                                   target_platform_capabilities=target_platform_cap)
diff --git a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision_lut.py b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision_lut.py
index 1664e119a..efcf0f41f 100644
--- a/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision_lut.py
+++ b/tutorials/notebooks/keras/ptq/example_keras_mobilenet_mixed_precision_lut.py
@@ -134,10 +134,10 @@ def representative_data_gen() -> list:
     # weights_compression_ratio = 0.4 - About 0.4 of the model's weights memory size when quantized with 8 bits.
     kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio)
     # Note that in this example, activations are quantized with a fixed bit-width of 8 bits (non mixed-precision).
-    configuration.mixed_precision_config.set_target_kpi(kpi)
 
     quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model,
                                                                                   representative_data_gen,
+                                                                                  target_kpi=kpi,
                                                                                   core_config=configuration,
                                                                                   target_platform_capabilities=target_platform_cap)
diff --git a/tutorials/notebooks/keras/ptq/example_keras_yolov8n.ipynb b/tutorials/notebooks/keras/ptq/example_keras_yolov8n.ipynb
index 3cc4412bb..f32db0e74 100644
--- a/tutorials/notebooks/keras/ptq/example_keras_yolov8n.ipynb
+++ b/tutorials/notebooks/keras/ptq/example_keras_yolov8n.ipynb
@@ -297,11 +297,11 @@
    "                                config,\n",
    "                                target_platform_capabilities=tpc)\n",
    "kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)\n",
-   "config.mixed_precision_config.set_target_kpi(kpi)\n",
    "\n",
    "# Perform post training quantization\n",
    "quant_model, _ = mct.ptq.keras_post_training_quantization(model,\n",
    "                                                          representative_dataset_gen,\n",
+   "                                                          target_kpi=kpi,\n",
    "                                                          core_config=config,\n",
    "                                                          target_platform_capabilities=tpc)\n",
    "print('Quantized model is ready')"
diff --git a/tutorials/notebooks/keras/ptq/keras_yolov8n_for_imx500.ipynb b/tutorials/notebooks/keras/ptq/keras_yolov8n_for_imx500.ipynb
index 9ef8a4a04..f0fb703a6 100644
--- a/tutorials/notebooks/keras/ptq/keras_yolov8n_for_imx500.ipynb
+++ b/tutorials/notebooks/keras/ptq/keras_yolov8n_for_imx500.ipynb
@@ -193,11 +193,11 @@
    "                                config,\n",
    "                                target_platform_capabilities=tpc)\n",
    "kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)\n",
-   "config.mixed_precision_config.set_target_kpi(kpi)\n",
    "\n",
    "# Perform post training quantization\n",
    "quant_model, _ = mct.ptq.keras_post_training_quantization(model,\n",
    "                                                          representative_dataset_gen,\n",
+   "                                                          target_kpi=kpi,\n",
    "                                                          core_config=config,\n",
    "                                                          target_platform_capabilities=tpc)\n",
    "print('Quantized model is ready')"
diff --git a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision.py b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision.py
index 759f3c258..a5bf0ff02 100644
--- a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision.py
+++ b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision.py
@@ -121,13 +121,13 @@ def representative_data_gen() -> list:
     # examples:
     # weights_compression_ratio = 0.75 - About 0.75 of the model's weights memory size when quantized with 8 bits.
     kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio)
-    configuration.mixed_precision_config.set_target_kpi(kpi)
     # It is also possible to constrain only part of the KPI metric, e.g., by providing only a weights_memory target
     # in the passed KPI object, e.g., kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)
     quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(model,
                                                                                     representative_data_gen,
+                                                                                    target_kpi=kpi,
                                                                                     core_config=configuration,
                                                                                     target_platform_capabilities=target_platform_cap)
diff --git a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision_lut.py b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision_lut.py
index bcea906be..1acff8dc2 100644
--- a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision_lut.py
+++ b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenet_mixed_precision_lut.py
@@ -126,10 +126,10 @@ def representative_data_gen() -> list:
     # weights_compression_ratio = 0.4 - About 0.4 of the model's weights memory size when quantized with 8 bits.
     kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio)
     # Note that in this example, activations are quantized with a fixed bit-width of 8 bits (non mixed-precision).
-    configuration.mixed_precision_config.set_target_kpi(kpi)
 
     quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(model,
                                                                                     representative_data_gen,
+                                                                                    target_kpi=kpi,
                                                                                     core_config=configuration,
                                                                                     target_platform_capabilities=target_platform_cap)
diff --git a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.ipynb b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.ipynb
index 0a4ca7abc..6e967f7fd 100644
--- a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.ipynb
+++ b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.ipynb
@@ -515,8 +515,7 @@
    "# while the bias will not)\n",
    "# examples:\n",
    "# weights_compression_ratio = 0.75 - About 0.75 of the model's weights memory size when quantized with 8 bits.\n",
-   "kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)\n",
-   "configuration.mixed_precision_config.set_target_kpi(kpi)"
+   "kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)"
   ]
  },
 {
@@ -538,6 +537,7 @@
   "source": [
    "quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(model,\n",
    "                                                                                representative_data_gen,\n",
+   "                                                                                target_kpi=kpi,\n",
    "                                                                                core_config=configuration,\n",
    "                                                                                target_platform_capabilities=target_platform_cap)\n",
    "    "
diff --git a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.py b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.py
index d82c510ce..57c0ae9a1 100644
--- a/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.py
+++ b/tutorials/notebooks/pytorch/ptq/example_pytorch_mobilenetv2_cifar100_mixed_precision.py
@@ -239,12 +239,12 @@ def representative_data_gen() -> list:
     # examples:
     # weights_compression_ratio = 0.75 - About 0.75 of the model's weights memory size when quantized with 8 bits.
     kpi = mct.core.KPI(kpi_data.weights_memory * args.weights_compression_ratio)
-    configuration.mixed_precision_config.set_target_kpi(kpi)
     # It is also possible to constrain only part of the KPI metric, e.g., by providing only a weights_memory target
     # in the passed KPI object, e.g., kpi = mct.core.KPI(kpi_data.weights_memory * 0.75)
     quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(model,
                                                                                     representative_data_gen,
+                                                                                    target_kpi=kpi,
                                                                                     core_config=configuration,
                                                                                     target_platform_capabilities=target_platform_cap)
 
     # Finally, we evaluate the quantized model:
diff --git a/tutorials/quick_start/keras_fw/quant.py b/tutorials/quick_start/keras_fw/quant.py
index 2133fc3c6..1365c929c 100644
--- a/tutorials/quick_start/keras_fw/quant.py
+++ b/tutorials/quick_start/keras_fw/quant.py
@@ -101,10 +101,10 @@ def quantize(model: tf.keras.Model,
                                        shift_negative_activation_correction=True),
                                    mixed_precision_config=mp_conf)
         target_kpi = get_target_kpi(model, mp_wcr, representative_data_gen, core_conf, tpc)
-        core_conf.mixed_precision_config.set_target_kpi(target_kpi)
     else:
         core_conf = CoreConfig(quantization_config=mct.core.QuantizationConfig(
             shift_negative_activation_correction=True))
+        target_kpi = None
 
     # Quantize model
     if args.get('gptq', False):
@@ -117,10 +117,13 @@ def quantize(model: tf.keras.Model,
         gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=n_epochs,
                                                    optimizer=Adam(learning_rate=args['gptq_lr']))
         quantized_model, quantization_info = \
-            mct.gptq.keras_gradient_post_training_quantization(model, representative_data_gen=representative_data_gen,
+            mct.gptq.keras_gradient_post_training_quantization(model,
+                                                               representative_data_gen=representative_data_gen,
+                                                               target_kpi=target_kpi,
+                                                               core_config=core_conf,
                                                                gptq_config=gptq_conf,
                                                                gptq_representative_data_gen=representative_data_gen,
-                                                               core_config=core_conf, target_platform_capabilities=tpc)
+                                                               target_platform_capabilities=tpc)
 
     else:
@@ -128,6 +131,7 @@ def quantize(model: tf.keras.Model,
         quantized_model, quantization_info = \
             mct.ptq.keras_post_training_quantization(model,
                                                      representative_data_gen=representative_data_gen,
+                                                     target_kpi=target_kpi,
                                                      core_config=core_conf,
                                                      target_platform_capabilities=tpc)
diff --git a/tutorials/quick_start/pytorch_fw/quant.py b/tutorials/quick_start/pytorch_fw/quant.py
index fd8a02058..497ce3ac8 100644
--- a/tutorials/quick_start/pytorch_fw/quant.py
+++ b/tutorials/quick_start/pytorch_fw/quant.py
@@ -102,10 +102,10 @@ def quantize(model: nn.Module,
                                        shift_negative_activation_correction=True),
                                    mixed_precision_config=mp_conf)
         target_kpi = get_target_kpi(model, mp_wcr, representative_data_gen, core_conf, tpc)
-        core_conf.mixed_precision_config.set_target_kpi(target_kpi)
     else:
         core_conf = CoreConfig(quantization_config=mct.core.QuantizationConfig(
             shift_negative_activation_correction=True))
+        target_kpi = None
 
     # Quantize model
     if args.get('gptq', False):
@@ -120,6 +120,7 @@
         quantized_model, quantization_info = \
             mct.gptq.pytorch_gradient_post_training_quantization(model,
                                                                  representative_data_gen=representative_data_gen,
+                                                                 target_kpi=target_kpi,
                                                                  core_config=core_conf,
                                                                  gptq_config=gptq_conf,
                                                                  gptq_representative_data_gen=representative_data_gen,
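
All of the tutorial changes above apply one and the same migration: the removed set_target_kpi(...) call on the mixed-precision config is replaced by an explicit target_kpi argument on the quantization API itself. Below is a minimal sketch of the new calling pattern, not part of this diff: the MobileNetV2 model and the random-data generator are illustrative placeholders (any float Keras model and a real representative dataset would take their place), and the 0.75 budget mirrors the ratio used in the tutorials.

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2

    # Placeholder float model; substitute your own Keras model.
    model = MobileNetV2()

    def representative_data_gen():
        # Placeholder generator for illustration only; in practice, yield
        # batches of real, preprocessed input images.
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

    # The mixed-precision config no longer carries a KPI target.
    mp_config = mct.core.MixedPrecisionQuantizationConfig()
    core_config = mct.core.CoreConfig(mixed_precision_config=mp_config)

    # Measure the model's KPI data and derive a weights-memory budget from it
    # (the default target platform capabilities are used here).
    kpi_data = mct.core.keras_kpi_data(model, representative_data_gen, core_config)
    kpi = mct.core.KPI(weights_memory=kpi_data.weights_memory * 0.75)

    # The KPI target is now passed directly to the quantization call.
    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        model,
        representative_data_gen,
        target_kpi=kpi,
        core_config=core_config)

The PyTorch flow is symmetric (mct.core.pytorch_kpi_data and mct.ptq.pytorch_post_training_quantization), and the GPTQ entry points accept the same target_kpi argument, as the quick_start hunks above show.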