Skip to content

Commit

Permalink
Move gptq trainer init logic to common trainer init
Browse files (browse the repository at this point in the history)
  • Loading branch information
reuvenp committed Dec 4, 2024
1 parent 227e271 commit 10b80e1
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 96 deletions.
59 changes: 59 additions & 0 deletions model_compression_toolkit/gptq/common/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
get_gradual_activation_quantizer_wrapper_factory
from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps


class GPTQTrainer(ABC):
Expand Down Expand Up @@ -64,6 +68,14 @@ def __init__(self,
self.fw_impl = fw_impl
self.fw_info = fw_info
self.representative_data_gen_fn = representative_data_gen_fn

def _get_total_grad_steps():
return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs

self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config,
_get_total_grad_steps,
self.fw_linear_annealing_scheduler)

# ----------------------------------------------
# Build two models and create compare nodes
# ----------------------------------------------
Expand All @@ -81,6 +93,53 @@ def __init__(self,
f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
self.hessian_service = hessian_info_service


self.reg_func = get_regularization(self.gptq_config,
_get_total_grad_steps,
self.fw_soft_quantizer_regularization,
self.fw_linear_annealing_scheduler)
self.loss_list = []
self.input_scale = 1
if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
Logger.critical("Input scale mismatch between float and GPTQ networks. "
"Ensure both networks have matching input scales.") # pragma: no cover
else:
self.input_scale = self.gptq_user_info.input_scale

trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn(
self.fxp_model,
add_bias=self.gptq_config.train_bias)
self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model)

if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
self.fxp_weights_list)):
Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
"and the number of float and quantized weights for loss calculation. "
"Ensure all these elements align to proceed with GPTQ training.")

# In Keras we need to flatten the weights first before attaching the optimizer
if isinstance(trainable_weights[0], (list, tuple)):
trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
if isinstance(trainable_bias[0], (list, tuple)):
trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights]

self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
trainable_bias,
trainable_threshold)
hessian_cfg = self.gptq_config.hessian_weights_config

self.has_params_to_train = np.sum(
[len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample

if self.use_sample_layer_attention:
# normalization is currently not supported, make sure the config reflects it.
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
raise NotImplementedError()
self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
else:
self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)

def get_optimizer_with_param(self,
flattened_trainable_weights: List[Any],
flattened_bias_weights: List[Any],
Expand Down
97 changes: 45 additions & 52 deletions model_compression_toolkit/gptq/keras/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,10 @@ def __init__(self,
"""

def _get_total_grad_steps():
return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs

# This must be set before the model building (as it is required for activation holder construction),
# which occurs in the base constructor.
self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)
self.fw_soft_quantizer_regularization = SoftQuantizerRegularization
self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler
self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
self.fw_get_weights_for_loss_fn = get_weights_for_loss

super().__init__(graph_float,
graph_quant,
Expand All @@ -101,52 +98,48 @@ def _get_total_grad_steps():
representative_data_gen_fn=representative_data_gen,
hessian_info_service=hessian_info_service)

self.loss_list = []
self.input_scale = 1

trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
self.fxp_model,
fw_info,
add_bias=gptq_config.train_bias)

self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)

if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
self.fxp_weights_list)):
Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
"and the number of float and quantized weights for loss calculation. "
"Ensure all these elements align to proceed with GPTQ training.")

flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
trainable_quantization_parameters = trainable_threshold
self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
flattened_bias_weights,
trainable_quantization_parameters)
self.has_params_to_train = np.sum(
[len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0

if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
"Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover
else:
self.input_scale = self.gptq_user_info.input_scale

hessian_cfg = gptq_config.hessian_weights_config
self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample

if self.use_sample_layer_attention:
# normalization is currently not supported, make sure the config reflects it.
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
raise NotImplementedError()
self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
else:
self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
# self.loss_list = []
# self.input_scale = 1

# trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
# self.fxp_model,
# add_bias=gptq_config.train_bias)

# self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
#
# if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
# self.fxp_weights_list)):
# Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
# "and the number of float and quantized weights for loss calculation. "
# "Ensure all these elements align to proceed with GPTQ training.")

# flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
# flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
# trainable_quantization_parameters = trainable_threshold
#
# self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
# flattened_bias_weights,
# trainable_quantization_parameters)
# self.has_params_to_train = np.sum(
# [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0

# if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
# Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
# "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover
# else:
# self.input_scale = self.gptq_user_info.input_scale

# hessian_cfg = gptq_config.hessian_weights_config
# self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
#
# if self.use_sample_layer_attention:
# # normalization is currently not supported, make sure the config reflects it.
# if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
# raise NotImplementedError()
# self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
# else:
# self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)

self.reg_func = get_regularization(self.gptq_config,
_get_total_grad_steps,
SoftQuantizerRegularization,
KerasLinearAnnealingScheduler)

def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
"""
Expand Down
5 changes: 1 addition & 4 deletions model_compression_toolkit/gptq/keras/graph_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import tensorflow as tf
from typing import Tuple, List
from model_compression_toolkit.core.keras.constants import USE_BIAS
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from tensorflow.keras.models import Model
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
Expand All @@ -26,15 +25,13 @@


def get_gptq_trainable_parameters(fxp_model: Model,
fw_info: FrameworkInfo,
add_bias: bool = False) -> (
List[tf.Variable], List[tf.Variable], List[tf.Variable]):
"""
Get trainable parameters from all layers in a model
Args:
fxp_model: Model to get its trainable parameters.
fw_info: Framework information needed for keras kernel ops list.
add_bias: Whether to include biases of the model (if there are) or not.
Returns:
Expand All @@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
trainable_threshold.extend(quantizer_trainable_threshold)

if add_bias:
kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
and layer.layer.get_config().get(USE_BIAS)
if use_bias is not None and use_bias and layer.layer.bias is not None:
Expand Down
78 changes: 38 additions & 40 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,10 @@ def __init__(self,
representative_data_gen: Dataset to use for inputs of the models.
hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
"""
def _get_total_grad_steps():
# TODO get it from the dataset
return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs

# must be set prior to model building in the base class constructor
self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler)
self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization
self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler
self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
self.fw_get_weights_for_loss_fn = get_weights_for_loss

super().__init__(graph_float,
graph_quant,
Expand All @@ -92,40 +89,41 @@ def _get_total_grad_steps():
representative_data_gen_fn=representative_data_gen,
hessian_info_service=hessian_info_service)

self.loss_list = []
self.input_scale = 1
if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
Logger.critical("Input scale mismatch between float and GPTQ networks. "
"Ensure both networks have matching input scales.") # pragma: no cover
else:
self.input_scale = self.gptq_user_info.input_scale

trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
self.fxp_model,
add_bias=self.gptq_config.train_bias)

self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
self.fxp_weights_list)):
Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
"and float vs. quantized weights for loss calculation do not match. "
"Verify consistency across these parameters for successful GPTQ training.")

self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
trainable_bias,
trainable_threshold)
hessian_cfg = self.gptq_config.hessian_weights_config

self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
if self.use_sample_layer_attention:
# normalization is currently not supported, make sure the config reflects it.
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
raise NotImplementedError()
self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
else:
self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
# self.loss_list = []
# self.input_scale = 1
# if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
# Logger.critical("Input scale mismatch between float and GPTQ networks. "
# "Ensure both networks have matching input scales.") # pragma: no cover
# else:
# self.input_scale = self.gptq_user_info.input_scale

# trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
# self.fxp_model,
# add_bias=self.gptq_config.train_bias)


# self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
# if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
# self.fxp_weights_list)):
# Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
# "and float vs. quantized weights for loss calculation do not match. "
# "Verify consistency across these parameters for successful GPTQ training.")

# self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
# trainable_bias,
# trainable_threshold)
# hessian_cfg = self.gptq_config.hessian_weights_config

# self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
# if self.use_sample_layer_attention:
# # normalization is currently not supported, make sure the config reflects it.
# if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
# raise NotImplementedError()
# self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
# else:
# self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)


self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler)

def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
"""
Expand Down

0 comments on commit 10b80e1

Please sign in to comment.