Skip to content

Commit

Permalink
Remove fw info from API signatures (#994)
Browse files Browse the repository at this point in the history
* Remove fw info from API signatures:  keras_kpi_data, pytorch_kpi_data, keras_gradient_post_training_quantization, keras_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_init_experimental.

---------

Co-authored-by: reuvenp <[email protected]>
  • Loading branch information
reuvenperetz and reuvenp authored Mar 14, 2024
1 parent a5b5365 commit cab0f25
Show file tree
Hide file tree
Showing 24 changed files with 58 additions and 100 deletions.
4 changes: 1 addition & 3 deletions model_compression_toolkit/core/keras/kpi_data_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
def keras_kpi_data(in_model: Model,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
"""
Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.
Expand All @@ -46,7 +45,6 @@ def keras_kpi_data(in_model: Model,
in_model (Model): Keras model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
Returns:
Expand Down Expand Up @@ -82,7 +80,7 @@ def keras_kpi_data(in_model: Model,
representative_data_gen,
core_config,
target_platform_capabilities,
fw_info,
DEFAULT_KERAS_INFO,
fw_impl)

else:
Expand Down
4 changes: 1 addition & 3 deletions model_compression_toolkit/core/pytorch/kpi_data_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
def pytorch_kpi_data(in_model: Module,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
"""
Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.
Expand All @@ -49,7 +48,6 @@ def pytorch_kpi_data(in_model: Module,
in_model (Model): PyTorch model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
Returns:
Expand Down Expand Up @@ -85,7 +83,7 @@ def pytorch_kpi_data(in_model: Module,
representative_data_gen,
core_config,
target_platform_capabilities,
fw_info,
DEFAULT_PYTORCH_INFO,
fw_impl)

else:
Expand Down
13 changes: 5 additions & 8 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ def get_keras_gptq_config(n_epochs: int,
regularization_factor=regularization_factor)


def keras_gradient_post_training_quantization(in_model: Model,
representative_data_gen: Callable,
def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
gptq_config: GradientPTQConfig,
gptq_representative_data_gen: Callable = None,
core_config: CoreConfig = CoreConfig(),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
"""
Quantize a trained Keras model using post-training quantization. The model is quantized using a
Expand All @@ -142,7 +140,6 @@ def keras_gradient_post_training_quantization(in_model: Model,
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
Returns:
Expand Down Expand Up @@ -192,22 +189,22 @@ def keras_gradient_post_training_quantization(in_model: Model,
"""
KerasModelValidation(model=in_model,
fw_info=fw_info).validate()
fw_info=DEFAULT_KERAS_INFO).validate()

if core_config.mixed_precision_enable:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.error("Given quantization config to mixed-precision facade is not of type "
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
"API, or pass a valid mixed precision configuration.") # pragma: no cover

tb_w = init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

fw_impl = GPTQKerasImplemantation()

tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=fw_info,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
tb_w=tb_w)
Expand All @@ -217,7 +214,7 @@ def keras_gradient_post_training_quantization(in_model: Model,
gptq_config,
representative_data_gen,
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
fw_info,
DEFAULT_KERAS_INFO,
fw_impl,
tb_w,
hessian_info_service=hessian_info_service)
Expand Down
14 changes: 6 additions & 8 deletions model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(),
qat_config: QATConfig = QATConfig(),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
"""
Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
Expand All @@ -111,7 +110,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
representative_data_gen (Callable): Dataset used for initial calibration.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
qat_config (QATConfig): QAT configuration
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
Returns:
Expand Down Expand Up @@ -159,7 +157,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
Pass the model, the representative dataset generator, the configuration and the target KPI to get a
quantized model:
>>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=config)
>>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=core_config)
Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
Expand All @@ -174,32 +172,32 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
f"project https://github.com/sony/model_optimization")

KerasModelValidation(model=in_model,
fw_info=fw_info).validate()
fw_info=DEFAULT_KERAS_INFO).validate()

if core_config.mixed_precision_enable:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.error("Given quantization config to mixed-precision facade is not of type "
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
"or pass a valid mixed precision configuration.")

tb_w = init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

fw_impl = KerasImplementation()

# Ignore hessian service since is not used in QAT at the moment
tg, bit_widths_config, _ = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=fw_info,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)

_qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
qat_model, user_info = KerasModelBuilder(graph=tg,
fw_info=fw_info,
fw_info=DEFAULT_KERAS_INFO,
wrapper=_qat_wrapper,
get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
qat_config=qat_config)).build_model()
Expand Down
8 changes: 3 additions & 5 deletions model_compression_toolkit/qat/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(),
qat_config: QATConfig = QATConfig(),
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
"""
Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
Expand All @@ -99,7 +98,6 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
representative_data_gen (Callable): Dataset used for initial calibration.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
qat_config (QATConfig): QAT configuration
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.
Returns:
Expand Down Expand Up @@ -150,7 +148,7 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
"or pass a valid mixed precision configuration.")

tb_w = init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
fw_impl = PytorchImplementation()

# Ignore trace hessian service as we do not use it here
Expand All @@ -162,12 +160,12 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
tpc=target_platform_capabilities,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)

_qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

qat_model, user_info = PyTorchModelBuilder(graph=tg,
fw_info=fw_info,
fw_info=DEFAULT_PYTORCH_INFO,
wrapper=_qat_wrapper,
get_activation_quantizer_holder_fn=partial(
get_activation_quantizer_holder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,11 @@ def representative_data_gen():
core_config=core_config,
target_platform_capabilities=tpc
)
ptq_gptq_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
model_float,
representative_data_gen,
gptq_config=self.get_gptq_config(),
core_config=core_config,
target_platform_capabilities=tpc
)
ptq_gptq_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model_float,
representative_data_gen,
gptq_config=self.get_gptq_config(),
core_config=core_config,
target_platform_capabilities=tpc)

self.compare(ptq_model, ptq_gptq_model, input_x=x, quantization_info=quantization_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ def get_tpc(self):

def run_test(self, **kwargs):
model_float = self.create_networks()
ptq_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model_float,
self.representative_data_gen,
fw_info=self.get_fw_info(),
target_platform_capabilities=self.get_tpc())
ptq_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(
model_float, self.representative_data_gen, target_platform_capabilities=self.get_tpc())

ptq_model2 = None
if self.test_loading:
Expand Down Expand Up @@ -197,11 +195,9 @@ def get_qat_config(self):

def run_test(self, **kwargs):
model_float = self.create_networks()
ptq_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model_float,
self.representative_data_gen,
fw_info=self.get_fw_info(),
qat_config=self.get_qat_config(),
target_platform_capabilities=self.get_tpc())
ptq_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(
model_float, self.representative_data_gen, qat_config=self.get_qat_config(),
target_platform_capabilities=self.get_tpc())

# PTQ model
in_tensor = np.random.randn(1, *ptq_model.input_shape[1:])
Expand Down Expand Up @@ -300,10 +296,7 @@ def run_test(self, **kwargs):
mct.core.KPI(weights_memory=self.kpi_weights,
activation_memory=self.kpi_activation)))
qat_ready_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(
model_float,
self.representative_data_gen_experimental,
core_config=config,
fw_info=self.get_fw_info(),
model_float, self.representative_data_gen_experimental, core_config=config,
target_platform_capabilities=self.get_tpc())

self.compare(qat_ready_model, quantization_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def export_qat_model(self):
model = self.get_model()
images = next(self.get_dataset())

self.qat_ready, _, _ = mct.qat.keras_quantization_aware_training_init_experimental(model,
self.get_dataset)
self.qat_ready, _, _ = mct.qat.keras_quantization_aware_training_init_experimental(model, self.get_dataset)
_qat_ready_model_path = tempfile.mkstemp('.h5')[1]
keras.models.save_model(self.qat_ready, _qat_ready_model_path)
self.qat_ready = mct_load(_qat_ready_model_path)
Expand Down
Loading

0 comments on commit cab0f25

Please sign in to comment.