diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 04a39e2be..bee23d565 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -27,7 +27,11 @@ from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points +from model_compression_toolkit.gptq.common.gradual_activation_quantization import \ + get_gradual_activation_quantizer_wrapper_factory +from model_compression_toolkit.gptq.common.regularization_factory import get_regularization from model_compression_toolkit.logger import Logger +from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps class GPTQTrainer(ABC): @@ -64,6 +68,14 @@ def __init__(self, self.fw_impl = fw_impl self.fw_info = fw_info self.representative_data_gen_fn = representative_data_gen_fn + + def _get_total_grad_steps(): + return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs + + self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config, + _get_total_grad_steps, + self.fw_linear_annealing_scheduler) + # ---------------------------------------------- # Build two models and create compare nodes # ---------------------------------------------- @@ -81,6 +93,53 @@ def __init__(self, f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover self.hessian_service = hessian_info_service + + self.reg_func = get_regularization(self.gptq_config, + _get_total_grad_steps, + self.fw_soft_quantizer_regularization, + self.fw_linear_annealing_scheduler) + self.loss_list = [] + self.input_scale = 1 + if self.float_user_info.input_scale != self.gptq_user_info.input_scale: + Logger.critical("Input scale mismatch between float and GPTQ networks. " + "Ensure both networks have matching input scales.") # pragma: no cover + else: + self.input_scale = self.gptq_user_info.input_scale + + trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn( + self.fxp_model, + add_bias=self.gptq_config.train_bias) + self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model) + + if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len( + self.fxp_weights_list)): + Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, " + "and the number of float and quantized weights for loss calculation. " + "Ensure all these elements align to proceed with GPTQ training.") + + # In Keras we need to flatten the weights first before attaching the optimizer + if isinstance(trainable_weights[0], (list, tuple)): + trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights] + if isinstance(trainable_bias[0], (list, tuple)): + trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights] + + self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights, + trainable_bias, + trainable_threshold) + hessian_cfg = self.gptq_config.hessian_weights_config + + self.has_params_to_train = np.sum( + [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0 + self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample + + if self.use_sample_layer_attention: + # normalization is currently not supported, make sure the config reflects it. + if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: + raise NotImplementedError() + self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen) + else: + self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen) + def get_optimizer_with_param(self, flattened_trainable_weights: List[Any], flattened_bias_weights: List[Any], diff --git a/model_compression_toolkit/gptq/keras/gptq_training.py b/model_compression_toolkit/gptq/keras/gptq_training.py index 45b60b191..581204641 100644 --- a/model_compression_toolkit/gptq/keras/gptq_training.py +++ b/model_compression_toolkit/gptq/keras/gptq_training.py @@ -85,13 +85,10 @@ def __init__(self, """ - def _get_total_grad_steps(): - return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs - - # This must be set before the model building (as it is required for activation holder construction), - # which occurs in the base constructor. - self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory( - gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler) + self.fw_soft_quantizer_regularization = SoftQuantizerRegularization + self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler + self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters + self.fw_get_weights_for_loss_fn = get_weights_for_loss super().__init__(graph_float, graph_quant, @@ -101,52 +98,48 @@ def _get_total_grad_steps(): representative_data_gen_fn=representative_data_gen, hessian_info_service=hessian_info_service) - self.loss_list = [] - self.input_scale = 1 - - trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters( - self.fxp_model, - fw_info, - add_bias=gptq_config.train_bias) - - self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model) - - if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len( - self.fxp_weights_list)): - Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, " - "and the number of float and quantized weights for loss calculation. " - "Ensure all these elements align to proceed with GPTQ training.") - - flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights] - flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights] - trainable_quantization_parameters = trainable_threshold - self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights, - flattened_bias_weights, - trainable_quantization_parameters) - self.has_params_to_train = np.sum( - [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0 - - if self.float_user_info.input_scale != self.gptq_user_info.input_scale: - Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. " - "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover - else: - self.input_scale = self.gptq_user_info.input_scale - - hessian_cfg = gptq_config.hessian_weights_config - self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample - - if self.use_sample_layer_attention: - # normalization is currently not supported, make sure the config reflects it. - if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: - raise NotImplementedError() - self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen) - else: - self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen) + # self.loss_list = [] + # self.input_scale = 1 + + # trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters( + # self.fxp_model, + # add_bias=gptq_config.train_bias) + + # self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model) + # + # if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len( + # self.fxp_weights_list)): + # Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, " + # "and the number of float and quantized weights for loss calculation. " + # "Ensure all these elements align to proceed with GPTQ training.") + + # flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights] + # flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights] + # trainable_quantization_parameters = trainable_threshold + # + # self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights, + # flattened_bias_weights, + # trainable_quantization_parameters) + # self.has_params_to_train = np.sum( + # [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0 + + # if self.float_user_info.input_scale != self.gptq_user_info.input_scale: + # Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. " + # "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover + # else: + # self.input_scale = self.gptq_user_info.input_scale + + # hessian_cfg = gptq_config.hessian_weights_config + # self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample + # + # if self.use_sample_layer_attention: + # # normalization is currently not supported, make sure the config reflects it. + # if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: + # raise NotImplementedError() + # self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen) + # else: + # self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen) - self.reg_func = get_regularization(self.gptq_config, - _get_total_grad_steps, - SoftQuantizerRegularization, - KerasLinearAnnealingScheduler) def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset: """ diff --git a/model_compression_toolkit/gptq/keras/graph_info.py b/model_compression_toolkit/gptq/keras/graph_info.py index 4b394adae..f3f0dec6b 100644 --- a/model_compression_toolkit/gptq/keras/graph_info.py +++ b/model_compression_toolkit/gptq/keras/graph_info.py @@ -16,7 +16,6 @@ import tensorflow as tf from typing import Tuple, List from model_compression_toolkit.core.keras.constants import USE_BIAS -from model_compression_toolkit.core.common.framework_info import FrameworkInfo from tensorflow.keras.models import Model from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq @@ -26,7 +25,6 @@ def get_gptq_trainable_parameters(fxp_model: Model, - fw_info: FrameworkInfo, add_bias: bool = False) -> ( List[tf.Variable], List[tf.Variable], List[tf.Variable]): """ @@ -34,7 +32,6 @@ def get_gptq_trainable_parameters(fxp_model: Model, Args: fxp_model: Model to get its trainable parameters. - fw_info: Framework information needed for keras kernel ops list. add_bias: Whether to include biases of the model (if there are) or not. Returns: @@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model, trainable_threshold.extend(quantizer_trainable_threshold) if add_bias: - kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer)) + kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer)) use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \ and layer.layer.get_config().get(USE_BIAS) if use_bias is not None and use_bias and layer.layer.bias is not None: diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index ee91a36d3..ad69afd8e 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -76,13 +76,10 @@ def __init__(self, representative_data_gen: Dataset to use for inputs of the models. hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model. """ - def _get_total_grad_steps(): - # TODO get it from the dataset - return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs - - # must be set prior to model building in the base class constructor - self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory( - gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler) + self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization + self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler + self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters + self.fw_get_weights_for_loss_fn = get_weights_for_loss super().__init__(graph_float, graph_quant, @@ -92,40 +89,41 @@ def _get_total_grad_steps(): representative_data_gen_fn=representative_data_gen, hessian_info_service=hessian_info_service) - self.loss_list = [] - self.input_scale = 1 - if self.float_user_info.input_scale != self.gptq_user_info.input_scale: - Logger.critical("Input scale mismatch between float and GPTQ networks. " - "Ensure both networks have matching input scales.") # pragma: no cover - else: - self.input_scale = self.gptq_user_info.input_scale - - trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters( - self.fxp_model, - add_bias=self.gptq_config.train_bias) - - self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model) - if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len( - self.fxp_weights_list)): - Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, " - "and float vs. quantized weights for loss calculation do not match. " - "Verify consistency across these parameters for successful GPTQ training.") - - self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights, - trainable_bias, - trainable_threshold) - hessian_cfg = self.gptq_config.hessian_weights_config - - self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample - if self.use_sample_layer_attention: - # normalization is currently not supported, make sure the config reflects it. - if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: - raise NotImplementedError() - self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen) - else: - self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen) + # self.loss_list = [] + # self.input_scale = 1 + # if self.float_user_info.input_scale != self.gptq_user_info.input_scale: + # Logger.critical("Input scale mismatch between float and GPTQ networks. " + # "Ensure both networks have matching input scales.") # pragma: no cover + # else: + # self.input_scale = self.gptq_user_info.input_scale + + # trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters( + # self.fxp_model, + # add_bias=self.gptq_config.train_bias) + + + # self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model) + # if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len( + # self.fxp_weights_list)): + # Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, " + # "and float vs. quantized weights for loss calculation do not match. " + # "Verify consistency across these parameters for successful GPTQ training.") + + # self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights, + # trainable_bias, + # trainable_threshold) + # hessian_cfg = self.gptq_config.hessian_weights_config + + # self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample + # if self.use_sample_layer_attention: + # # normalization is currently not supported, make sure the config reflects it. + # if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: + # raise NotImplementedError() + # self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen) + # else: + # self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen) + - self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler) def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader: """