diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py
index f34ac4d0a..ec0f64979 100644
--- a/model_compression_toolkit/core/common/framework_implementation.py
+++ b/model_compression_toolkit/core/common/framework_implementation.py
@@ -67,7 +67,6 @@ def get_trace_hessian_calculator(self,
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s get_trace_hessian_calculator method.')  # pragma: no cover
 
-    @abstractmethod
     def to_numpy(self, tensor: Any) -> np.ndarray:
         """
@@ -387,21 +386,20 @@ def get_node_distance_fn(self, layer_class: type,
 
     @abstractmethod
-    def is_node_compatible_for_metric_outputs(self,
-                                              node: BaseNode) -> bool:
+    def is_output_node_compatible_for_hessian_score_computation(self,
+                                                                node: BaseNode) -> bool:
         """
-        Checks and returns whether the given node is compatible as output for metric computation
-        purposes and gradient-based weights calculation.
+        Checks and returns whether the given node is compatible as output for Hessian-based information computation.
 
         Args:
             node: A BaseNode object.
 
-        Returns: Whether the node is compatible as output for metric computation or not.
+        Returns: Whether the node is compatible as output for Hessian-based information computation.
 
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s is_node_compatible_for_metric_outputs method.')  # pragma: no cover
+                             f'framework\'s is_output_node_compatible_for_hessian_score_computation method.')  # pragma: no cover
 
     @abstractmethod
     def get_node_mac_operations(self,
diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_utils.py b/model_compression_toolkit/core/common/hessian/hessian_info_utils.py
index a69641f3a..7f981ba1a 100644
--- a/model_compression_toolkit/core/common/hessian/hessian_info_utils.py
+++ b/model_compression_toolkit/core/common/hessian/hessian_info_utils.py
@@ -13,63 +13,22 @@
 #  limitations under the License.
 #  ==============================================================================
 from typing import List
+import numpy as np
+from model_compression_toolkit.constants import EPS
 
-from model_compression_toolkit.constants import EPS, HESSIAN_OUTPUT_ALPHA
-
-def normalize_weights(trace_hessian_approximations: List,
-                      outputs_indices: List[int],
-                      alpha: float = HESSIAN_OUTPUT_ALPHA) -> List[float]:
+def normalize_weights(hessian_approximations: List) -> np.ndarray:
     """
-    Normalize trace Hessian approximations. Output layers or layers after the model's considered output layers
-    are assigned a constant normalized value. Other layers' weights are normalized by dividing the
-    trace Hessian approximations value by the sum of all other values.
+    Normalize Hessian information approximations by dividing each trace Hessian approximation value by the sum
+    of all values.
 
     Args:
-        trace_hessian_approximations: Approximated average jacobian-based weights for each interest point.
-        outputs_indices: Indices of all nodes considered as outputs.
-        alpha: Multiplication factor.
+        hessian_approximations: Approximated average Hessian-based scores for each interest point.
 
     Returns:
-        Normalized list of trace Hessian approximations for each interest point.
-    """
-    if len(trace_hessian_approximations)==1:
-        return [1.]
- - sum_without_outputs = sum( - [trace_hessian_approximations[i] for i in range(len(trace_hessian_approximations)) if i not in outputs_indices]) - normalized_grads_weights = [_get_normalized_weight(grad, - i, - sum_without_outputs, - outputs_indices, - alpha) - for i, grad in enumerate(trace_hessian_approximations)] - - return normalized_grads_weights - - -def _get_normalized_weight(grad: float, - i: int, - sum_without_outputs: float, - outputs_indices: List[int], - alpha: float) -> float: + Normalized list of Hessian info approximations for each interest point. """ - Normalizes the node's trace Hessian approximation value. If it is an output or output - replacement node than the normalized value is a constant, otherwise, it is normalized - by dividing with the sum of all trace Hessian approximations values. - - Args: - grad: The approximation value. - i: The index of the node in the sorted interest points list. - sum_without_outputs: The sum of all approximations of nodes that are not considered outputs. - outputs_indices: A list of indices of nodes that consider outputs. - alpha: A multiplication factor. + scores_vec = np.asarray(hessian_approximations) - Returns: A normalized trace Hessian approximation. - - """ + return scores_vec / (np.sum(scores_vec) + EPS) - if i in outputs_indices: - return alpha / len(outputs_indices) - else: - return ((1 - alpha) * grad / (sum_without_outputs + EPS)) diff --git a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py index 510810e52..90d79edcf 100644 --- a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +++ b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py @@ -47,9 +47,12 @@ def __init__(self, """ self.graph = graph + for output_node in graph.get_outputs(): - if not fw_impl.is_node_compatible_for_metric_outputs(output_node.node): - Logger.error(f"All graph outputs should support metric outputs, but node {output_node.node} was found with layer type {output_node.node.type}.") + if not fw_impl.is_output_node_compatible_for_hessian_score_computation(output_node.node): + Logger.error(f"All graph outputs should support Hessian computation, but node {output_node.node} " + f"was found with layer type {output_node.node.type}. " + f"Try to run MCT without Hessian info computation.") self.input_images = fw_impl.to_tensor(input_images) self.num_iterations_for_approximation = num_iterations_for_approximation diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py index 2199b0015..671ff537f 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py @@ -31,7 +31,6 @@ def __init__(self, configuration_overwrite: List[int] = None, num_interest_points_factor: float = 1.0, use_grad_based_weights: bool = True, - output_grad_factor: float = 0.1, norm_weights: bool = True, refine_mp_solution: bool = True, metric_normalization_threshold: float = 1e10): @@ -46,8 +45,7 @@ def __init__(self, num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model. configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one. 
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
-            use_grad_based_weights (bool): Whether to use gradient-based weights for weighted average distance metric computation.
-            output_grad_factor (float): A tuning parameter to be used for gradient-based weights.
+            use_grad_based_weights (bool): Whether to use Hessian-based scores for weighted average distance metric computation.
             norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
             refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
             metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
@@ -67,13 +65,8 @@ def __init__(self,
         self.num_interest_points_factor = num_interest_points_factor
 
         self.use_grad_based_weights = use_grad_based_weights
-        self.output_grad_factor = output_grad_factor
         self.norm_weights = norm_weights
 
-        if use_grad_based_weights is True:
-            Logger.info(f"Using gradient-based weights for mixed-precision distance metric with tuning factor "
-                        f"{output_grad_factor}")
-
         self.metric_normalization_threshold = metric_normalization_threshold
diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
index 0b0a88213..a157ff985 100644
--- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
+++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
@@ -15,7 +15,7 @@
 import copy
 import numpy as np
-from typing import Callable, Any, List
+from typing import Callable, Any, List, Tuple
 
 from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
 from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
@@ -82,22 +82,27 @@ def __init__(self,
                          f" an HessianInfoService object must be provided but is {hessian_info_service}")
         self.hessian_info_service = hessian_info_service
 
-        # Get interest points for distance measurement and a list of sorted configurable nodes names
         self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names()
+
+        # Get interest points and output points for distance measurement, and set up other helper data structures.
+        # We define a separate set of output nodes of the model for the purpose of sensitivity computation.
         self.interest_points = get_mp_interest_points(graph,
                                                       fw_impl.count_node_for_mixed_precision_interest_points,
                                                       quant_config.num_interest_points_factor)
-        self.outputs_replacement_nodes = None
-        self.output_nodes_indices = None
-        if self.quant_config.use_grad_based_weights is True:
-            # Getting output replacement (if needed) - if a model's output layer is not compatible for the task of
-            # gradients computation then we find a predecessor layer which is compatible,
-            # add it to the set of interest points and use it for the gradients' computation.
-            # Note that we need to modify the set of interest points before building the models,
-            # therefore, it is separated from the part where we compute the actual gradient weights.
-            self.outputs_replacement_nodes = get_output_replacement_nodes(graph, fw_impl)
-            self.output_nodes_indices = self._update_ips_with_outputs_replacements()
+        self.ips_distance_fns, self.ips_batch_axis = self._init_metric_points_lists(self.interest_points)
+
+        self.output_points = get_output_nodes_for_metric(graph)
+        self.out_ps_distance_fns, self.out_ps_batch_axis = self._init_metric_points_lists(self.output_points)
+
+        # Setting lists with relative position of the interest points
+        # and output points in the list of all mp model activation tensors
+        graph_sorted_nodes = self.graph.get_topo_sorted_nodes()
+        all_out_tensors_indices = [graph_sorted_nodes.index(n) for n in self.interest_points + self.output_points]
+        global_ipts_indices = [graph_sorted_nodes.index(n) for n in self.interest_points]
+        global_out_pts_indices = [graph_sorted_nodes.index(n) for n in self.output_points]
+        self.ips_act_indices = [all_out_tensors_indices.index(i) for i in global_ipts_indices]
+        self.out_ps_act_indices = [all_out_tensors_indices.index(i) for i in global_out_pts_indices]
 
         # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
         # And a baseline model.
@@ -117,16 +122,35 @@ def __init__(self,
         # Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
         self._init_baseline_tensors_list()
 
-        # Computing gradient-based weights for weighted average distance metric computation (only if requested),
+        # Computing Hessian-based scores for weighted average distance metric computation (only if requested),
         # and assigning distance_weighting method accordingly.
-        self.interest_points_gradients = None
-        if self.quant_config.use_grad_based_weights is True:
-            assert self.outputs_replacement_nodes is not None and self.output_nodes_indices is not None, \
-                f"{self.outputs_replacement_nodes} and {self.output_nodes_indices} " \
-                f"should've been assigned before computing the gradient-based weights."
+        self.interest_points_hessians = None
+        if self.quant_config.use_grad_based_weights is True:
+            self.interest_points_hessians = self._compute_hessian_based_scores()
+            self.quant_config.distance_weighting_method = lambda d: self.interest_points_hessians
+
+    def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callable], List[int]]:
+        """
+        Initiates required lists for future use when computing the sensitivity metric.
+        Each point on which the metric is computed uses a dedicated distance function based on its type.
+        In addition, all distance functions perform batch computation, so the batch axis is needed for each node.
+
+        Args:
+            points: The set of nodes in the graph for which we need to initiate the lists.
 
-        self.interest_points_gradients = self._compute_gradient_based_weights()
-        self.quant_config.distance_weighting_method = lambda d: self.interest_points_gradients
+        Returns: A list of distance functions and a list of batch axes, one for each node.
+ + """ + distance_fns_list = [] + batch_axis_list = [] + for n in points: + distance_fns_list.append(self.fw_impl.get_node_distance_fn( + layer_class=n.layer_class, + framework_attrs=n.framework_attr, + compute_distance_fn=self.quant_config.compute_distance_fn)) + batch_axis_list.append(n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) + else n.op_call_kwargs.get(AXIS)) + return distance_fns_list, batch_axis_list def compute_metric(self, mp_model_configuration: List[int], @@ -151,15 +175,16 @@ def compute_metric(self, self._configure_bitwidths_model(mp_model_configuration, node_idx) - # Compute the distance matrix - distance_matrix = self._build_distance_matrix() + # Compute the distance metric + ipts_distances, out_pts_distances = self._compute_distance() # Configure MP model back to the same configuration as the baseline model if baseline provided if baseline_mp_configuration is not None: self._configure_bitwidths_model(baseline_mp_configuration, node_idx) - return self._compute_mp_distance_measure(distance_matrix, self.quant_config.distance_weighting_method) + return self._compute_mp_distance_measure(ipts_distances, out_pts_distances, + self.quant_config.distance_weighting_method) def _init_baseline_tensors_list(self): """ @@ -188,21 +213,21 @@ def _build_models(self) -> Any: model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph, mode=ModelBuilderMode.MIXEDPRECISION, - append2output=self.interest_points, + append2output=self.interest_points + self.output_points, fw_info=self.fw_info) # Build a baseline model. baseline_model, _ = self.fw_impl.model_builder(evaluation_graph, mode=ModelBuilderMode.FLOAT, - append2output=self.interest_points) + append2output=self.interest_points + self.output_points) return baseline_model, model_mp, conf_node2layers - def _compute_gradient_based_weights(self) -> np.ndarray: + def _compute_hessian_based_scores(self) -> np.ndarray: """ - Compute gradient-based weights using trace Hessian approximations for each interest point. + Compute Hessian-based scores for each interest point. - Returns: A vector of weights, one for each interest point, + Returns: A vector of scores, one for each interest point, to be used for the distance metric weighted average computation. 
""" @@ -239,9 +264,8 @@ def _compute_gradient_based_weights(self) -> np.ndarray: approx_by_image_per_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0]) if self.quant_config.norm_weights: - approx_by_image_per_interest_point = hessian_utils.normalize_weights(trace_hessian_approximations=approx_by_image_per_interest_point, - outputs_indices=self.output_nodes_indices, - alpha=HESSIAN_OUTPUT_ALPHA) + approx_by_image_per_interest_point = \ + hessian_utils.normalize_weights(hessian_approximations=approx_by_image_per_interest_point) # Append the approximations for the current image to the main list approx_by_image.append(approx_by_image_per_interest_point) @@ -299,47 +323,42 @@ def _configure_node_bitwidth(self, for current_layer in layers_to_config: self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure]) - def _compute_distance_matrix(self, + def _compute_points_distance(self, baseline_tensors: List[Any], - mp_tensors: List[Any]): + mp_tensors: List[Any], + points_distance_fns: List[Callable], + points_batch_axis: List[int]): """ - Compute the distance between the MP model's outputs and the baseline model's outputs + Compute the distance on the given set of points outputs between the MP model and the baseline model for each image in the batch that was inferred. + Args: - baseline_tensors: Baseline model's output tensors. - mp_tensors: MP model's output tensors. + baseline_tensors: Baseline model's output tensors of the given points. + mp_tensors: MP model's output tensors pf the given points. + points_distance_fns: A list with distance function to compute the distance between each given + point's output tensors. + points_batch_axis: A list with the matching batch axis of each given point's output tensors. + Returns: - A distance matrix that maps each node's index to the distance between this node's output + A distance vector that maps each node's index in the given nodes list to the distance between this node's output and the baseline model's output for all images that were inferred. """ - assert len(baseline_tensors) == len(self.interest_points) - num_interest_points = len(baseline_tensors) - num_samples = len(baseline_tensors[0]) - distance_matrix = np.ndarray((num_interest_points, num_samples)) - - for i in range(num_interest_points): - point_node = self.interest_points[i] - point_distance_fn = \ - self.fw_impl.get_node_distance_fn(layer_class=point_node.layer_class, - framework_attrs=point_node.framework_attr, - compute_distance_fn=self.quant_config.compute_distance_fn) + distance_v = [fn(x, y, batch=True, axis=axis) for fn, x, y, axis + in zip(points_distance_fns, baseline_tensors, mp_tensors, points_batch_axis)] - axis = point_node.framework_attr.get(AXIS) if not isinstance(point_node, FunctionalNode) \ - else point_node.op_call_kwargs.get(AXIS) + return np.asarray(distance_v) - distance_matrix[i] = point_distance_fn(baseline_tensors[i], mp_tensors[i], batch=True, axis=axis) - - return distance_matrix - - def _build_distance_matrix(self): + def _compute_distance(self) -> Tuple[np.ndarray, np.ndarray]: """ - Builds a matrix that contains the distances between the baseline and MP models for each interest point. - Returns: A distance matrix. + Computing the interest points distance and the output points distance, and using them to build a + unified distance vector. + + Returns: A distance vector. """ - # List of distance matrices. 
We create a distance matrix for each sample from the representative_data_gen
-        # and merge all of them eventually.
-        distance_matrices = []
+
+        ipts_per_batch_distance = []
+        out_pts_per_batch_distance = []
 
         # Compute the distance matrix for num_of_images images.
         for images, baseline_tensors in zip(self.images_batches, self.baseline_tensors_list):
@@ -347,32 +366,58 @@
             mp_tensors = self.fw_impl.sensitivity_eval_inference(self.model_mp, images)
             mp_tensors = self.fw_impl.to_numpy(mp_tensors)
 
-            # Build distance matrix: similarity between the baseline model to the float model
+            # Compute distance: similarity between the baseline (float) model and the MP model
             # in every interest point for every image in the batch.
-            distance_matrices.append(self._compute_distance_matrix(baseline_tensors, mp_tensors))
+            ips_distance = self._compute_points_distance([baseline_tensors[i] for i in self.ips_act_indices],
+                                                         [mp_tensors[i] for i in self.ips_act_indices],
+                                                         self.ips_distance_fns,
+                                                         self.ips_batch_axis)
+            outputs_distance = self._compute_points_distance([baseline_tensors[i] for i in self.out_ps_act_indices],
+                                                             [mp_tensors[i] for i in self.out_ps_act_indices],
+                                                             self.out_ps_distance_fns,
+                                                             self.out_ps_batch_axis)
+
+            # Extend to 2D (points x samples) if needed, for the concatenation over batches below
+            ips_distance = ips_distance if len(ips_distance.shape) > 1 else ips_distance[:, None]
+            outputs_distance = outputs_distance if len(outputs_distance.shape) > 1 else outputs_distance[:, None]
+            ipts_per_batch_distance.append(ips_distance)
+            out_pts_per_batch_distance.append(outputs_distance)
 
         # Merge all distance matrices into a single distance matrix.
-        distance_matrix = np.concatenate(distance_matrices, axis=1)
+        ipts_distances = np.concatenate(ipts_per_batch_distance, axis=1)
+        out_pts_distances = np.concatenate(out_pts_per_batch_distance, axis=1)
 
-        return distance_matrix
+        return ipts_distances, out_pts_distances
 
     @staticmethod
-    def _compute_mp_distance_measure(distance_matrix: np.ndarray, metrics_weights_fn: Callable) -> float:
+    def _compute_mp_distance_measure(ipts_distances: np.ndarray,
+                                     out_pts_distances: np.ndarray,
+                                     metrics_weights_fn: Callable) -> float:
         """
         Computes the final distance value out of a distance matrix.
 
         Args:
-            distance_matrix: A matrix that contains the distances between the baseline and MP models
+            ipts_distances: A matrix that contains the distances between the baseline and MP models
                 for each interest point.
-            metrics_weights_fn:
+            out_pts_distances: A matrix that contains the distances between the baseline and MP models
+                for each output point.
+            metrics_weights_fn: A callable that produces the scores to compute weighted distance for interest points.
 
         Returns: Distance value.
         """
-        # Compute the distance between the baseline model's outputs and the MP model's outputs.
-        # The distance is the mean of distances over all images in the batch that was inferred.
-        mean_distance_per_layer = distance_matrix.mean(axis=1)
-        # Use weights such that every layer's distance is weighted differently (possibly).
-        return np.average(mean_distance_per_layer, weights=metrics_weights_fn(distance_matrix))
+        mean_ipts_distance = 0
+        if len(ipts_distances) > 0:
+            mean_distance_per_layer = ipts_distances.mean(axis=1)
+
+            # Use weights such that every layer's distance is weighted differently (possibly).
+ mean_ipts_distance = np.average(mean_distance_per_layer, weights=metrics_weights_fn(ipts_distances)) + + mean_output_distance = 0 + if len(out_pts_distances) > 0: + mean_distance_per_output = out_pts_distances.mean(axis=1) + mean_output_distance = np.average(mean_distance_per_output) + + return mean_output_distance + mean_ipts_distance def _get_images_batches(self, num_of_images: int) -> List[Any]: """ @@ -406,37 +451,14 @@ def _get_images_batches(self, num_of_images: int) -> List[Any]: f'only {samples_count} were generated') return images_batches - def _update_ips_with_outputs_replacements(self): - """ - Updates the list of interest points with the set of pre-calculated replacement outputs. - Also, returns the indices of all output nodes (original, replacements and nodes in between them) in a - topological sorted interest points list (for later use in gradients computation and normalization). - - Returns: A list of indices of the output nodes in the sorted interest points list. - - """ - - assert self.outputs_replacement_nodes is not None, \ - "Trying to update interest points list with new output nodes but outputs_replacement_nodes list is None." - - replacement_outputs_to_ip = [r_node for r_node in self.outputs_replacement_nodes if - r_node not in self.interest_points] - updated_interest_points = self.interest_points + replacement_outputs_to_ip - - # Re-sort interest points in a topological order according to the graph's sort - self.interest_points = [n for n in self.graph.get_topo_sorted_nodes() if n in updated_interest_points] - - output_indices = [self.interest_points.index(n.node) for n in self.graph.get_outputs()] - replacement_indices = [self.interest_points.index(n) for n in self.outputs_replacement_nodes] - return list(set(output_indices + replacement_indices)) - def get_mp_interest_points(graph: Graph, interest_points_classifier: Callable, num_ip_factor: float) -> List[BaseNode]: """ Gets a list of interest points for the mixed precision metric computation. - The list is constructed from a filtered set of the convolutions nodes in the graph. + The list is constructed from a filtered set of nodes in the graph. + Note that the output layers are separated from the interest point set for metric computation purposes. Args: graph: Graph to search for its MP configuration. @@ -455,38 +477,26 @@ def get_mp_interest_points(graph: Graph, # We add output layers of the model to interest points # in order to consider the model's output in the distance metric computation (and also to make sure # all configurable layers are included in the configured mp model for metric computation purposes) - output_nodes = [n.node for n in graph.get_outputs() if n.node not in interest_points_nodes and - (n.node.is_weights_quantization_enabled() or - n.node.is_activation_quantization_enabled())] - interest_points = interest_points_nodes + output_nodes + output_nodes = get_output_nodes_for_metric(graph) + + interest_points = [n for n in interest_points_nodes if n not in output_nodes] return interest_points -def get_output_replacement_nodes(graph: Graph, - fw_impl: Any) -> List[BaseNode]: +def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]: """ - If a model's output node is not compatible for the task of gradients computation we need to find a predecessor - node in the model's graph representation which is compatible and add it to the set of interest points and use it - for the gradients' computation. This method searches for this predecessor node for each output of the model. 
+    Returns a list of output nodes that are also quantized (either weights or activation)
+    to be used as a set of output points in the distance metric computation.
 
     Args:
-        graph: Graph to search for replacement output nodes.
-        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+        graph: Graph to search for output nodes.
 
-    Returns: A list of output replacement nodes.
+    Returns: A list of output nodes.
 
     """
 
-    replacement_outputs = []
-    for n in graph.get_outputs():
-        prev_node = n.node
-        while not fw_impl.is_node_compatible_for_metric_outputs(prev_node):
-            prev_node = graph.get_prev_nodes(n.node)
-            assert len(prev_node) == 1, "A none MP compatible output node has multiple inputs, " \
-                                        "which is incompatible for metric computation."
-            prev_node = prev_node[0]
-        replacement_outputs.append(prev_node)
-    return replacement_outputs
+    return [n.node for n in graph.get_outputs() if (n.node.is_weights_quantization_enabled() or
+                                                    n.node.is_activation_quantization_enabled())]
 
 
 def bound_num_interest_points(sorted_ip_list: List[BaseNode],
                               num_ip_factor: float) -> List[BaseNode]:
diff --git a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py
index 599b71148..ce9c4378b 100644
--- a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py
+++ b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py
@@ -65,7 +65,7 @@ def compute(self) -> List[float]:
             List[float]: Approximated trace of the Hessian for an interest point.
         """
         if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
-            output_list = self._get_model_output_replacement()
+            output_list = [n.node for n in self.graph.get_outputs()]
 
             # Record operations for automatic differentiation
             with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
@@ -145,51 +145,6 @@ def compute(self) -> List[float]:
         else:
             Logger.error(f"{self.hessian_request.granularity} is not supported for Keras activation hessian's trace approx calculator")
 
-
-    def _update_ips_with_outputs_replacements(self,
-                                              outputs_replacement_nodes: List[BaseNode],
-                                              interest_points: List[BaseNode]):
-        """
-        Updates the list of interest points with the set of pre-calculated replacement outputs.
-        Also, returns the indices of all output nodes (original, replacements and nodes in between them) in a
-        topological sorted interest points list (for later use in gradients computation and normalization).
-
-        Returns: A list of indices of the output nodes in the sorted interest points list.
- - """ - - replacement_outputs_to_ip = [r_node for r_node in outputs_replacement_nodes if - r_node not in interest_points] - updated_interest_points = interest_points + replacement_outputs_to_ip - - # Re-sort interest points in a topological order according to the graph's sort - interest_points = [n for n in self.graph.get_topo_sorted_nodes() if n in updated_interest_points] - - output_indices = [interest_points.index(n.node) for n in self.graph.get_outputs()] - replacement_indices = [interest_points.index(n) for n in outputs_replacement_nodes] - return list(set(output_indices + replacement_indices)) - - def _get_model_output_replacement(self) -> List[str]: - """ - If a model's output node is not compatible for the task of gradients computation we need to find a predecessor - node in the model's graph representation which is compatible and use it for the gradients' computation. - This method searches for this predecessor node for each output of the model. - - Returns: A list of output replacement nodes. - - """ - - replacement_outputs = [] - for n in self.graph.get_outputs(): - prev_node = n.node - while not self.fw_impl.is_node_compatible_for_metric_outputs(prev_node): - prev_node = self.graph.get_prev_nodes(prev_node) - assert len(prev_node) == 1, "A none compatible output node has multiple inputs, " \ - "which is incompatible for metric computation." - prev_node = prev_node[0] - replacement_outputs.append(prev_node) - return replacement_outputs - def _get_model_outputs_for_single_image(self, output_list: List[str], gradient_tape: tf.GradientTape) -> Tuple[List[tf.Tensor], List[tf.Tensor]]: diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index e84388634..0a937f290 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -500,25 +500,23 @@ def get_trace_hessian_calculator(self, else: Logger.error(f"Keras does not support hessian mode of {trace_hessian_request.mode}") - - def is_node_compatible_for_metric_outputs(self, - node: BaseNode) -> Any: + def is_output_node_compatible_for_hessian_score_computation(self, + node: BaseNode) -> Any: """ - Checks and returns whether the given node is compatible as output for metric computation - purposes and gradient-based weights calculation. + Checks and returns whether the given node is compatible as output for Hessian-based information computation. Args: node: A BaseNode object. - Returns: Whether the node is compatible as output for metric computation or not. + Returns: Whether the node is compatible as output for Hessian-based information computation. 
""" if node.layer_class == TFOpLambda: node_attr = getattr(node, 'framework_attr', None) - if node_attr is not None and (ARGMAX in node_attr[LAYER_NAME] or SOFTMAX in node_attr[LAYER_NAME]): + if node_attr is not None and (ARGMAX in node_attr[LAYER_NAME]): return False - elif node.layer_class in [tf.nn.softmax, tf.keras.layers.Softmax, tf.math.argmax]: + elif node.layer_class in [tf.math.argmax]: return False return True diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 8480f40f8..6a39affa9 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -430,16 +430,16 @@ def get_node_distance_fn(self, layer_class: type, return compute_cs return compute_mse - def is_node_compatible_for_metric_outputs(self, - node: BaseNode) -> bool: + def is_output_node_compatible_for_hessian_score_computation(self, + node: BaseNode) -> bool: """ - Checks and returns whether the given node is compatible as output for metric computation - purposes and gradient-based weights calculation. + Checks and returns whether the given node is compatible as output for Hessian-based information computation. + Args: node: A BaseNode object. - Returns: Whether the node is compatible as output for metric computation or not. + Returns: Whether the node is compatible as output for Hessian-based information computation. """ @@ -539,4 +539,3 @@ def get_trace_hessian_calculator(self, input_images=input_images, fw_impl=self, num_iterations_for_approximation=num_iterations_for_approximation) - diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 71ae33ad9..d6dce02e2 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -205,7 +205,7 @@ def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[Li for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples): approx_by_interest_point = self._get_approximations_by_interest_point(approximations, image_idx) if self.gptq_config.hessian_weights_config.norm_weights: - approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [], alpha=0) + approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point) trace_hessian_approx_by_image.append(approx_by_interest_point) return trace_hessian_approx_by_image @@ -310,28 +310,6 @@ def update_graph(self) -> Graph: raise NotImplemented(f'{self.__class__.__name__} have to implement the ' f'framework\'s update_graph method.') # pragma: no cover - def _get_model_output_replacement(self) -> List[BaseNode]: - """ - If a model's output node is not compatible for the task of gradients computation we need to find a predecessor - node in the model's graph representation which is compatible and use it for the gradients' computation. - This method searches for this predecessor node for each output of the model. - - Returns: A list of output replacement nodes. - - """ - - replacement_outputs = [] - for n in self.graph_float.get_outputs(): - prev_node = n.node - while not self.fw_impl.is_node_compatible_for_metric_outputs(prev_node): - prev_node = self.graph_float.get_prev_nodes(n.node) - assert len(prev_node) == 1, "A none compatible output node has multiple inputs, " \ - "which is incompatible for metric computation." 
- prev_node = prev_node[0] - replacement_outputs.append(prev_node) - return replacement_outputs - - def gptq_training(graph_float: Graph, graph_quant: Graph, gptq_config: GradientPTQConfig, diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index 96f38454c..b91cae883 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -246,12 +246,11 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - assert (quantization_info.mixed_precision_cfg == [1, 1]).all() - for i in range(32): # quantized per channel - self.unit_test.assertTrue( - np.unique(conv_layers[0].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16) + self.unit_test.assertTrue((quantization_info.mixed_precision_cfg != 0).any()) + for i in range(32): # quantized per channel self.unit_test.assertTrue( + np.unique(conv_layers[0].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16 or np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16) # Verify final KPI diff --git a/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py b/tests/keras_tests/function_tests/test_hessian_info_calculator.py similarity index 60% rename from tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py rename to tests/keras_tests/function_tests/test_hessian_info_calculator.py index 968fd227c..b3d3a4b46 100644 --- a/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py +++ b/tests/keras_tests/function_tests/test_hessian_info_calculator.py @@ -42,6 +42,7 @@ def basic_model(input_shape, layer): outputs = ReLU()(x_bn) return keras.Model(inputs=inputs, outputs=outputs) + def reused_model(input_shape): reused_layer = Conv2D(filters=3, kernel_size=2, padding='same') inputs = Input(shape=input_shape[1:]) @@ -49,6 +50,7 @@ def reused_model(input_shape): x = reused_layer(x) return keras.Model(inputs=inputs, outputs=x) + def get_multiple_outputs_model(input_shape): inputs = Input(shape=input_shape[1:]) x = Conv2D(filters=2, kernel_size=3)(inputs) @@ -57,6 +59,7 @@ def get_multiple_outputs_model(input_shape): out2 = Conv2D(2, 4)(out1) return keras.Model(inputs=inputs, outputs=[out1, out2]) + def get_multiple_outputs_to_intermediate_node_model(input_shape): inputs = Input(shape=input_shape[1:]) x = Conv2D(filters=2, kernel_size=3)(inputs) @@ -66,6 +69,7 @@ def get_multiple_outputs_to_intermediate_node_model(input_shape): outputs = x_split[0] + x_split[1] return keras.Model(inputs=inputs, outputs=outputs) + def get_multiple_inputs_model(input_shape): inputs = Input(shape=input_shape[1:]) inputs2 = Input(shape=input_shape[1:]) @@ -76,33 +80,33 @@ def get_multiple_inputs_model(input_shape): outputs = x+x2 return keras.Model(inputs=[inputs, inputs2], outputs=outputs) + def representative_dataset(input_shape, num_of_inputs=1): yield [np.random.randn(*input_shape).astype(np.float32)] * num_of_inputs -class TestHessianInfoCalculatorWeights(unittest.TestCase): - - def _fetch_scores(self, hessian_info, target_node, granularity, num_scores=1): - request = 
hessian_common.TraceHessianRequest(mode=hessian_common.HessianMode.WEIGHTS, +class TestHessianInfoCalculatorBase(unittest.TestCase): + def _fetch_scores(self, hessian_info, target_node, granularity, mode, num_scores=1): + request = hessian_common.TraceHessianRequest(mode=mode, granularity=granularity, target_node=target_node) info = hessian_info.fetch_hessian(request, num_scores) assert len(info) == num_scores, f"fetched {num_scores} score but {len(info)} scores were fetched" return np.mean(np.stack(info), axis=0) - def _test_score_shape(self, hessian_service, interest_point, granularity, expected_shape, num_scores=1): + def _test_score_shape(self, hessian_service, interest_point, granularity, mode, expected_shape, num_scores=1): score = self._fetch_scores(hessian_info=hessian_service, target_node=interest_point, # linear op granularity=granularity, + mode=mode, num_scores=num_scores) self.assertTrue(isinstance(score, np.ndarray), f"scores expected to be a numpy array but is {type(score)}") self.assertTrue(score.shape == expected_shape, f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}") # per tensor return score - def test_conv2d_granularity(self): - input_shape = (1, 8, 8, 3) - in_model = basic_model(input_shape, layer=Conv2D(filters=2, kernel_size=3)) + def _setup(self, layer, input_shape=(1, 8, 8, 3)): + in_model = basic_model(input_shape, layer=layer) keras_impl = KerasImplementation() _repr_dataset = functools.partial(representative_dataset, input_shape=input_shape) @@ -112,6 +116,13 @@ def test_conv2d_granularity(self): _repr_dataset, generate_keras_tpc) + return graph, _repr_dataset, keras_impl + + +class TestHessianInfoCalculatorWeights(TestHessianInfoCalculatorBase): + + def test_conv2d_granularity(self): + graph, _repr_dataset, keras_impl = self._setup(layer=Conv2D(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] hessian_service = hessian_common.HessianInfoService(graph=graph, @@ -120,29 +131,22 @@ def test_conv2d_granularity(self): self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(1,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(2,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(3, 3, 3, 2)) del hessian_service def test_dense_granularity(self): - input_shape = (1, 8) - in_model = basic_model(input_shape, layer=Dense(2)) - keras_impl = KerasImplementation() - _repr_dataset = functools.partial(representative_dataset, - input_shape=input_shape) - graph = prepare_graph_with_configs(in_model, - keras_impl, - DEFAULT_KERAS_INFO, - _repr_dataset, - generate_keras_tpc) - + graph, _repr_dataset, keras_impl = self._setup(layer=Dense(2), input_shape=(1, 8)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] hessian_service = hessian_common.HessianInfoService(graph=graph, @@ -152,29 +156,22 @@ def test_dense_granularity(self): self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(1,)) 
self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(2,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(8, 2)) del hessian_service def test_conv2dtranspose_granularity(self): - input_shape = (1, 8, 8, 3) - in_model = basic_model(input_shape, layer=Conv2DTranspose(filters=2, kernel_size=3)) - keras_impl = KerasImplementation() - _repr_dataset = functools.partial(representative_dataset, - input_shape=input_shape) - graph = prepare_graph_with_configs(in_model, - keras_impl, - DEFAULT_KERAS_INFO, - _repr_dataset, - generate_keras_tpc) - + graph, _repr_dataset, keras_impl = self._setup(layer=Conv2DTranspose(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] hessian_service = hessian_common.HessianInfoService(graph=graph, @@ -184,29 +181,22 @@ def test_conv2dtranspose_granularity(self): self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(1,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(2,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(3, 3, 2, 3)) del hessian_service def test_depthwiseconv2d_granularity(self): - input_shape = (1, 8, 8, 3) - in_model = basic_model(input_shape, layer=DepthwiseConv2D(kernel_size=3)) - keras_impl = KerasImplementation() - _repr_dataset = functools.partial(representative_dataset, - input_shape=input_shape) - graph = prepare_graph_with_configs(in_model, - keras_impl, - DEFAULT_KERAS_INFO, - _repr_dataset, - generate_keras_tpc) - + graph, _repr_dataset, keras_impl = self._setup(layer=DepthwiseConv2D(kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] hessian_service = hessian_common.HessianInfoService(graph=graph, @@ -216,14 +206,17 @@ def test_depthwiseconv2d_granularity(self): self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(1,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(3,)) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(3, 3, 3, 1)) del hessian_service @@ -251,27 +244,29 @@ def test_reused_layer(self): fw_impl=keras_impl) node1_approx = self._test_score_shape(hessian_service, interest_points[0], - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - expected_shape=(2, 2, 3, 3)) + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, + expected_shape=(1,)) node2_approx = self._test_score_shape(hessian_service, interest_points[1], - 
granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - expected_shape=(2, 2, 3, 3)) + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, + expected_shape=(1,)) self.assertTrue(np.all(node1_approx==node2_approx), f'Approximations of nodes of a reused layer ' f'should be equal') node1_count = hessian_service.count_saved_info_of_request( hessian_common.TraceHessianRequest(target_node=interest_points[0], mode=hessian_common.HessianMode.WEIGHTS, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT)) + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) self.assertTrue(node1_count == 1) node2_count = hessian_service.count_saved_info_of_request( hessian_common.TraceHessianRequest(target_node=interest_points[1], mode=hessian_common.HessianMode.WEIGHTS, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT)) + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) self.assertTrue(node2_count == 1) - self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list)==1) + self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list) == 1) del hessian_service ######################################################### @@ -303,14 +298,17 @@ def _test_advanced_graph(self, float_model, _repr_dataset): self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(1,)) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(2,)) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS, expected_shape=(3, 3, 3, 2)) del hessian_service @@ -337,3 +335,176 @@ def test_multiple_outputs_to_intermediate_node(self): _repr_dataset = functools.partial(representative_dataset, input_shape=input_shape) self._test_advanced_graph(in_model, _repr_dataset) + + +class TestHessianInfoCalculatorActivation(TestHessianInfoCalculatorBase): + + def test_conv2d_granularity(self): + graph, _repr_dataset, keras_impl = self._setup(layer=Conv2D(filters=2, kernel_size=3)) + sorted_graph_nodes = graph.get_topo_sorted_nodes() + interest_points = [n for n in sorted_graph_nodes] + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + self._test_score_shape(hessian_service, + interest_points[1], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + del hessian_service + + def test_dense_granularity(self): + graph, _repr_dataset, keras_impl = self._setup(layer=Dense(2), input_shape=(1, 8)) + sorted_graph_nodes = graph.get_topo_sorted_nodes() + interest_points = [n for n in sorted_graph_nodes] + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + + self._test_score_shape(hessian_service, + interest_points[1], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + del hessian_service + + def test_conv2dtranspose_granularity(self): + graph, _repr_dataset, keras_impl = self._setup(layer=Conv2DTranspose(filters=2, kernel_size=3)) + sorted_graph_nodes = 
graph.get_topo_sorted_nodes() + interest_points = [n for n in sorted_graph_nodes] + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + + self._test_score_shape(hessian_service, + interest_points[1], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + del hessian_service + + def test_depthwiseconv2d_granularity(self): + graph, _repr_dataset, keras_impl = self._setup(layer=DepthwiseConv2D(kernel_size=3)) + sorted_graph_nodes = graph.get_topo_sorted_nodes() + interest_points = [n for n in sorted_graph_nodes] + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + + self._test_score_shape(hessian_service, + interest_points[1], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + del hessian_service + + def test_reused_layer(self): + input_shape = (1, 8, 8, 3) + in_model = reused_model(input_shape) + _repr_dataset = functools.partial(representative_dataset, + input_shape=input_shape) + + keras_impl = KerasImplementation() + graph = prepare_graph_with_configs(in_model, + keras_impl, + DEFAULT_KERAS_INFO, + _repr_dataset, + generate_keras_tpc) + + sorted_graph_nodes = graph.get_topo_sorted_nodes() + + # Two nodes representing the same reused layer + interest_points = [n for n in sorted_graph_nodes if n.type == Conv2D] + self.assertTrue(len(interest_points)==2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") + + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + node1_approx = self._test_score_shape(hessian_service, + interest_points[0], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + node2_approx = self._test_score_shape(hessian_service, + interest_points[1], + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + self.assertTrue(np.all(node1_approx == node2_approx), f'Approximations of nodes of a reused layer ' + f'should be equal') + + node1_count = hessian_service.count_saved_info_of_request( + hessian_common.TraceHessianRequest(target_node=interest_points[0], + mode=hessian_common.HessianMode.ACTIVATION, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) + self.assertTrue(node1_count == 1) + + node2_count = hessian_service.count_saved_info_of_request( + hessian_common.TraceHessianRequest(target_node=interest_points[1], + mode=hessian_common.HessianMode.ACTIVATION, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) + self.assertTrue(node2_count == 1) + self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list) == 1) + + del hessian_service + + ######################################################### + # The following part checks different possible graph + # properties (#inputs/#outputs, for example). 
+ ######################################################## + + def _test_advanced_graph(self, float_model, _repr_dataset): + ######################################################################## + # Since we want to test some models with different properties (e.g., multiple inputs/outputs) + # we can no longer assume we're fetching interest point #1 like in the linear ops + # tests. Instead, this function assumes the first Conv2D interest point is the interest point. + ####################################################################### + keras_impl = KerasImplementation() + graph = prepare_graph_with_configs(float_model, + keras_impl, + DEFAULT_KERAS_INFO, + _repr_dataset, + generate_keras_tpc) + + sorted_graph_nodes = graph.get_topo_sorted_nodes() + + # This test assumes the first Conv2D interest point is the node that + # we fetch its scores and test their shapes correctness. + interest_points = [n for n in sorted_graph_nodes if n.type==Conv2D][0] + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=_repr_dataset, + fw_impl=keras_impl) + self._test_score_shape(hessian_service, + interest_points, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + expected_shape=(1,)) + + del hessian_service + + def test_multiple_inputs(self): + input_shape = (1, 8, 8, 3) + in_model = get_multiple_inputs_model(input_shape) + _repr_dataset = functools.partial(representative_dataset, + input_shape=input_shape, + num_of_inputs=2) + self._test_advanced_graph(in_model, _repr_dataset) + + def test_multiple_outputs(self): + input_shape = (1, 8, 8, 3) + in_model = get_multiple_outputs_model(input_shape) + _repr_dataset = functools.partial(representative_dataset, + input_shape=input_shape) + self._test_advanced_graph(in_model, _repr_dataset) + + def test_multiple_outputs_to_intermediate_node(self): + input_shape = (1, 8, 8, 3) + in_model = get_multiple_outputs_to_intermediate_node_model(input_shape) + _repr_dataset = functools.partial(representative_dataset, + input_shape=input_shape) + self._test_advanced_graph(in_model, _repr_dataset) diff --git a/tests/keras_tests/function_tests/test_model_gradients.py b/tests/keras_tests/function_tests/test_model_gradients.py deleted file mode 100644 index 05a7eb1f2..000000000 --- a/tests/keras_tests/function_tests/test_model_gradients.py +++ /dev/null @@ -1,207 +0,0 @@ -import functools - -import keras -import unittest - -from keras.layers import Dense -from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, Input, SeparableConv2D, Reshape -from tensorflow import initializers -import numpy as np -import tensorflow as tf - -import model_compression_toolkit.core.common.hessian as hessian_common -from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO -from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc - -import model_compression_toolkit as mct -from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs - -tp = mct.target_platform - - -def basic_derivative_model(input_shape): - inputs = Input(shape=input_shape) - outputs = 2 * inputs + 1 - return keras.Model(inputs=inputs, outputs=outputs) - - -def basic_model(input_shape): - random_uniform = initializers.random_uniform(0, 1) - inputs = Input(shape=input_shape) - x = Conv2D(2, 3, 
padding='same', name="conv2d")(inputs) - x_bn = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', - moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, - name="bn1")(x) - outputs = ReLU()(x_bn) - return keras.Model(inputs=inputs, outputs=outputs) - - -def advenced_model(input_shape): - random_uniform = initializers.random_uniform(0, 1) - inputs = Input(shape=input_shape, name='input1') - x = Conv2D(2, 3, padding='same', name="conv2d_1")(inputs) - x_bn = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', - moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, - name="bn1")(x) - x_relu = ReLU(name='relu1')(x_bn) - x_2 = Conv2D(2, 3, padding='same', name="conv2d_2")(x_relu) - x_bn2 = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', - moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, - name='bn2')(x_2) - x_reshape = Reshape((-1,), name='reshape1')(x_bn2) - x_bn3 = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', - moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, - name='bn3')( - x_reshape) - outputs = ReLU(name='relu2')(x_bn3) - return keras.Model(inputs=inputs, outputs=outputs) - - -def multiple_output_model(input_shape): - inputs = Input(shape=input_shape) - x = Dense(2)(inputs) - x = Conv2D(2, 4)(x) - x = BatchNormalization()(x) - out1 = ReLU(max_value=6.0)(x) - out2 = Conv2D(2, 4)(out1) - return keras.Model(inputs=inputs, outputs=[out1, out2]) - - -def inputs_as_list_model(input_shape): - input1 = Input(shape=input_shape) - input2 = Input(shape=input_shape) - x_stack = tf.stack([input1, input2]) - x_conv = Conv2D(2, 3, padding='same', name="conv2d")(x_stack) - x_bn = BatchNormalization()(x_conv) - outputs = ReLU()(x_bn) - return keras.Model(inputs=[input1, input2], outputs=outputs) - -def multiple_outputs_node_model(input_shape): - inputs = Input(shape=input_shape) - x_conv = Conv2D(2, 3, padding='same', name="conv2d")(inputs) - x_bn = BatchNormalization()(x_conv) - x_relu = ReLU()(x_bn) - x_split = tf.split(x_relu, num_or_size_splits=2, axis=-1) - outputs = x_split[0]+x_split[1] - return keras.Model(inputs=inputs, outputs=outputs) - - -def model_with_output_replacements(input_shape): - random_uniform = initializers.random_uniform(0, 1) - inputs = Input(shape=input_shape) - x = Conv2D(2, 3, padding='same', name="conv2d")(inputs) - x_bn = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', - moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, - name="bn1")(x) - x_relu = ReLU()(x_bn) - x_soft = tf.nn.softmax(x_relu) - outputs = tf.math.argmax(x_soft) - - return keras.Model(inputs=inputs, outputs=outputs) - - -def representative_dataset(num_of_inputs=1): - yield [np.random.randn(1, 8, 8, 3).astype(np.float32)]*num_of_inputs - -def _get_normalized_hessian_trace_approx(graph, interest_points, keras_impl, alpha, num_of_inputs=1): - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=functools.partial(representative_dataset, num_of_inputs=num_of_inputs), - fw_impl=keras_impl) - x = [] - for interest_point in interest_points: - request = hessian_common.TraceHessianRequest(mode=hessian_common.HessianMode.ACTIVATION, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - target_node=interest_point) - 
hessian_data = hessian_service.fetch_hessian(request, 1) - hessian_data_per_image = hessian_data[0] - assert isinstance(hessian_data_per_image, list) - assert len(hessian_data_per_image) == 1 - x.append(hessian_data_per_image[0]) - x = hessian_common.hessian_utils.normalize_weights(x, alpha=alpha, outputs_indices=[len(interest_points) - 1]) - del hessian_service - return x - - -class TestModelGradients(unittest.TestCase): - # TODO: change tests to ignore the normalization and check - # closeness to ACTUAL hessian values on small trained models. - - def _run_model_grad_test(self, graph, keras_impl, output_indices=None, num_of_inputs=1): - sorted_graph_nodes = graph.get_topo_sorted_nodes() - interest_points = [n for n in sorted_graph_nodes] - all_output_indices = [len(interest_points) - 1] if output_indices is None else output_indices - x = _get_normalized_hessian_trace_approx(graph, interest_points, keras_impl, alpha=0.3, num_of_inputs=num_of_inputs) - - # Checking that the weights were computed and normalized correctly - # In rare occasions, the output tensor has all zeros, so the gradients for all interest points are zeros. - # This is a pathological case that is not possible in real networks, so we just extend the assertion to prevent - # the test from failing in this rare cases. - self.assertTrue(np.isclose(np.sum(x), 1) or all([y == 0 for i, y in enumerate(x) if i not in all_output_indices])) - - def test_jacobian_trace_calculation(self): - input_shape = (8, 8, 3) - in_model = basic_derivative_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, generate_keras_tpc) - - sorted_graph_nodes = graph.get_topo_sorted_nodes() - interest_points = [n for n in sorted_graph_nodes] - - x = _get_normalized_hessian_trace_approx(graph, interest_points, keras_impl, alpha=0) - - # These are the expected values of the normalized gradients (gradients should be 2 and 1 - # with respect to input and mult layer, respectively) - self.assertTrue(np.isclose(x[0], np.float32(0.8), 1e-1)) - self.assertTrue(np.isclose(x[1], np.float32(0.2), 1e-1)) - self.assertTrue(np.isclose(x[2], np.float32(0.0))) - - y = _get_normalized_hessian_trace_approx(graph, interest_points, keras_impl, alpha=1) - - self.assertTrue(np.isclose(y[0], np.float32(0.0))) - self.assertTrue(np.isclose(y[1], np.float32(0.0))) - self.assertTrue(np.isclose(y[2], np.float32(1.0))) - - - def test_basic_model_grad(self): - input_shape = (8, 8, 3) - in_model = basic_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, generate_keras_tpc) - - self._run_model_grad_test(graph, keras_impl) - - def test_advanced_model_grad(self): - input_shape = (8, 8, 3) - in_model = advenced_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, generate_keras_tpc) - - self._run_model_grad_test(graph, keras_impl) - - def test_multiple_outputs_grad(self): - input_shape = (8, 8, 3) - in_model = multiple_output_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, generate_keras_tpc) - - sorted_graph_nodes = graph.get_topo_sorted_nodes() - self._run_model_grad_test(graph, keras_impl, output_indices=[len(sorted_graph_nodes) - 1, - len(sorted_graph_nodes) - 2]) - 
- - def test_inputs_as_list_model_grad(self): - input_shape = (8, 8, 3) - in_model = inputs_as_list_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, generate_keras_tpc) - self._run_model_grad_test(graph, keras_impl, num_of_inputs=2) - - def test_multiple_outputs_node(self): - input_shape = (8, 8, 3) - in_model = multiple_outputs_node_model(input_shape) - keras_impl = KerasImplementation() - graph = prepare_graph_with_configs(in_model, keras_impl, DEFAULT_KERAS_INFO, representative_dataset, - generate_keras_tpc) - self._run_model_grad_test(graph, keras_impl) - diff --git a/tests/keras_tests/function_tests/test_sensitivity_eval_non_supported_output.py b/tests/keras_tests/function_tests/test_sensitivity_eval_non_supported_output.py new file mode 100644 index 000000000..fc4d85f60 --- /dev/null +++ b/tests/keras_tests/function_tests/test_sensitivity_eval_non_supported_output.py @@ -0,0 +1,76 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import unittest +import numpy as np +import tensorflow as tf + + +from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2 +from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO +from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters +import model_compression_toolkit.core.common.hessian as hess + +keras = tf.keras +layers = keras.layers + + +def argmax_output_model(input_shape): + inputs = layers.Input(shape=input_shape) + x = layers.Conv2D(3, 3)(inputs) + x = layers.BatchNormalization()(x) + x = layers.Conv2D(3, 3)(x) + x = layers.ReLU()(x) + outputs = tf.argmax(x, axis=-1) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + +def representative_dataset(): + yield [np.random.randn(1, 8, 8, 3).astype(np.float32)] + + +class TestSensitivityEvalWithNonSupportedOutputNodes(unittest.TestCase): + + def verify_test_for_model(self, model): + keras_impl = KerasImplementation() + graph = prepare_graph_with_quantization_parameters(model, + keras_impl, + DEFAULT_KERAS_INFO, + representative_dataset, + generate_keras_tpc, + input_shape=(1, 8, 8, 3), + mixed_precision_enabled=True) + + hessian_info_service = hess.HessianInfoService(graph=graph, + representative_dataset=representative_dataset, + fw_impl=keras_impl) + + se = keras_impl.get_sensitivity_evaluator(graph, + MixedPrecisionQuantizationConfigV2(use_grad_based_weights=True), + representative_dataset, + DEFAULT_KERAS_INFO, + hessian_info_service=hessian_info_service) +
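+ # tf.argmax is non-differentiable, so Hessian-score approximations that backpropagate from the + # model's outputs cannot be computed through it; constructing the sensitivity evaluator with + # Hessian-based scores on such a graph is therefore expected to raise an exception.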
+ def test_not_supported_output_argmax(self): + model = argmax_output_model((8, 8, 3)) + with self.assertRaises(Exception) as e: + self.verify_test_for_model(model) + self.assertTrue("All graph outputs should support Hessian computation" in str(e.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py index efe0dbc91..b742d4bcc 100644 --- a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py +++ b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py @@ -79,7 +79,8 @@ def softmax_model(input_shape): x = tf.keras.layers.BatchNormalization()(x) x = tf.keras.layers.Softmax(axis=2)(x) x = tf.keras.layers.Dense(32)(x) - outputs = tf.nn.softmax(x, axis=-1) + x = tf.nn.softmax(x, axis=-1) + outputs = tf.keras.layers.Reshape((-1,))(x) model = tf.keras.Model(inputs=inputs, outputs=outputs) return model @@ -104,7 +105,9 @@ def test_nonfiltered_interest_points_set(self): ip_nodes = list(filter(lambda n: KerasImplementation().count_node_for_mixed_precision_interest_points(n), sorted_nodes)) - self.assertTrue(len(ips) == len(ip_nodes), + # Note that the model's output node shouldn't be included in the sensitivity evaluation list of + # interest points (it is included in a separate list) + self.assertTrue(len(ips) == len(ip_nodes) - 1, f"Filtered interest points list should include exactly {len(ip_nodes)}, but it" f"includes {len(ips)}") diff --git a/tests/pytorch_tests/function_tests/model_gradients_test.py b/tests/pytorch_tests/function_tests/model_gradients_test.py deleted file mode 100644 index 02b11d73a..000000000 --- a/tests/pytorch_tests/function_tests/model_gradients_test.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch -from torch.nn import Conv2d, BatchNorm2d, ReLU, Linear - -from model_compression_toolkit.core.pytorch.utils import to_torch_tensor -import numpy as np - -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc -from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs -from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest -import model_compression_toolkit.core.common.hessian as hessian_common - -""" -This test checks the BatchNorm info collection. 
-""" - - -def bn_weight_change(bn: torch.nn.Module): - bw_shape = bn.weight.shape - delattr(bn, 'weight') - delattr(bn, 'bias') - delattr(bn, 'running_var') - delattr(bn, 'running_mean') - bn.register_buffer('weight', torch.rand(bw_shape)) - bn.register_buffer('bias', torch.rand(bw_shape)) - bn.register_buffer('running_var', torch.abs(torch.rand(bw_shape))) - bn.register_buffer('running_mean', torch.rand(bw_shape)) - return bn - - -class basic_derivative_model(torch.nn.Module): - def __init__(self): - super(basic_derivative_model, self).__init__() - - def forward(self, inp): - x = torch.mul(inp, 2) - x = x + 1 - return x - - -class basic_model(torch.nn.Module): - def __init__(self): - super(basic_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn(x) - x = self.relu(x) - return x - - -class advanced_model(torch.nn.Module): - def __init__(self): - super(advanced_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn1 = BatchNorm2d(3) - self.relu1 = ReLU() - - self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn2 = BatchNorm2d(3) - self.relu2 = ReLU() - - self.bn3 = BatchNorm2d(3) - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn1(x) - x = self.relu1(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.relu2(x) - - x = torch.reshape(x, [-1]) - x = self.bn3(x) - return x - - -class multiple_output_model(torch.nn.Module): - def __init__(self): - super(multiple_output_model, self).__init__() - self.linear = Linear(32, 32) - self.conv1 = Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - self.conv2 = Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) - - def forward(self, inp): - x = self.linear(inp) - x = self.conv1(x) - x = self.bn(x) - x1 = self.relu(x) - x2 = self.conv2(x1) - return x1, x2 - - -class model_with_output_replacements(torch.nn.Module): - def __init__(self): - super(model_with_output_replacements, self).__init__() - - self.conv = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv(inp) - x = self.bn(x) - x = self.relu(x) - x = torch.argmax(x) - return x - - -class node_with_multiple_outputs_model(torch.nn.Module): - def __init__(self): - super(node_with_multiple_outputs_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn(x) - x = self.relu(x) - y, z, w = torch.split(x, split_size_or_sections=1, dim=1) - return y, z, w - - -class non_differentiable_node_model(torch.nn.Module): - def __init__(self): - super(non_differentiable_node_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn(x) - x = self.relu(x) - y = torch.argmax(x) # this is a dummy operation just to cover edge case in the gradients computation - - return x * y - - -def generate_inputs(inputs_shape): - inputs = [] - for in_shape in inputs_shape: - t = torch.randn(*in_shape) - t.requires_grad_() - inputs.append(t) - inputs = to_torch_tensor(inputs) - return inputs - - -class ModelGradientsCalculationTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def 
create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self, n_iters=1): - input_shapes = self.create_inputs_shape() - for _ in range(n_iters): - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = basic_derivative_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0) - - self.unit_test.assertTrue(np.isclose(model_grads[0], 0.8, 1e-1)) - self.unit_test.assertTrue(np.isclose(model_grads[1], 0.2, 1e-1)) - self.unit_test.assertTrue(model_grads[2] == 0.0) - - - -def _get_normalized_hessian_trace_approx(representative_data_gen, - graph, - interest_points, - pytorch_impl, - alpha): - - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=representative_data_gen, - fw_impl=pytorch_impl) - x = [] - for interest_point in interest_points: - request = hessian_common.TraceHessianRequest(mode=hessian_common.HessianMode.ACTIVATION, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - target_node=interest_point) - hessian_data = hessian_service.fetch_hessian(request, 1) - hessian_data_per_image = hessian_data[0] - assert isinstance(hessian_data_per_image, list) - assert len(hessian_data_per_image) == 1 - x.append(hessian_data_per_image[0]) - x = hessian_common.hessian_utils.normalize_weights(x, alpha=alpha, outputs_indices=[len(interest_points) - 1]) - return x - - -class ModelGradientsBasicModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - - def run_test(self, seed=0): - model_float = basic_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - # Checking that the weights where computed and normalized correctly - self.unit_test.assertTrue(np.isclose(np.sum(model_grads), 1)) - - -class ModelGradientsAdvancedModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = basic_model() - pytorch_impl = PytorchImplementation() - graph = 
prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - # Checking that the weights where computed and normalized correctly - self.unit_test.assertTrue(np.isclose(np.sum(model_grads), 1)) - - -class ModelGradientsMultipleOutputsTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = multiple_output_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - # Checking that the weights where computed and normalized correctly - self.unit_test.assertTrue(np.isclose(np.sum(model_grads), 1)) - - -class ModelGradientsOutputReplacementTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = model_with_output_replacements() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - ipts = [n for n in graph.get_topo_sorted_nodes()] - with self.unit_test.assertRaises(Exception) as e: - _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - self.unit_test.assertTrue("All graph outputs should support metric outputs" in str(e.exception)) - - -class ModelGradientsMultipleOutputsModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0, **kwargs): - model_float = node_with_multiple_outputs_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = 
_get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - # Checking that the weights where computed and normalized correctly - self.unit_test.assertTrue(np.isclose(np.sum(model_grads), 1)) - - -class ModelGradientsNonDifferentiableNodeModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0, **kwargs): - model_float = non_differentiable_node_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [n for n in graph.get_topo_sorted_nodes()] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - # Checking that the weights where computed and normalized correctly - self.unit_test.assertTrue(np.isclose(np.sum(model_grads), 1)) - - -class ModelGradientsSinglePointTest(ModelGradientsBasicModelTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def run_test(self, seed=0): - model_float = basic_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - - ipts = [graph.get_topo_sorted_nodes()[-1]] - model_grads = _get_normalized_hessian_trace_approx(representative_data_gen=self.representative_data_gen, - graph=graph, - interest_points=ipts, - pytorch_impl=pytorch_impl, - alpha=0.3) - - self.unit_test.assertTrue(len(model_grads) == 1 and model_grads[0] == 1.0) \ No newline at end of file diff --git a/tests/pytorch_tests/function_tests/test_function_runner.py b/tests/pytorch_tests/function_tests/test_function_runner.py index efd17dabc..ecfc7ab50 100644 --- a/tests/pytorch_tests/function_tests/test_function_runner.py +++ b/tests/pytorch_tests/function_tests/test_function_runner.py @@ -24,16 +24,15 @@ TestKPIDataBasicPartialBitwidth, TestKPIDataComplexPartialBitwidth, TestKPIDataComplesAllBitwidth from tests.pytorch_tests.function_tests.layer_fusing_test import LayerFusingTest1, LayerFusingTest2, LayerFusingTest3, \ LayerFusingTest4 -from tests.pytorch_tests.function_tests.model_gradients_test import ModelGradientsBasicModelTest, \ - ModelGradientsCalculationTest, ModelGradientsAdvancedModelTest, ModelGradientsOutputReplacementTest, \ - ModelGradientsMultipleOutputsModelTest, ModelGradientsNonDifferentiableNodeModelTest, \ - ModelGradientsMultipleOutputsTest, ModelGradientsSinglePointTest from tests.pytorch_tests.function_tests.set_layer_to_bitwidth_test import TestSetLayerToBitwidthWeights, \ TestSetLayerToBitwidthActivation -from tests.pytorch_tests.function_tests.test_sensitivity_eval_output_replacement import \ - TestSensitivityEvalWithArgmaxOutputReplacementNodes, TestSensitivityEvalWithSoftmaxOutputReplacementNodes -from tests.pytorch_tests.function_tests.test_hessian_info_weights import WeightsHessianTraceBasicModelTest, 
WeightsHessianTraceAdvanceModelTest, \ -WeightsHessianTraceMultipleOutputsModelTest, WeightsHessianTraceReuseModelTest +from tests.pytorch_tests.function_tests.test_sensitivity_eval_non_supported_output import \ + TestSensitivityEvalWithArgmaxNode +from tests.pytorch_tests.function_tests.test_hessian_info_calculator import WeightsHessianTraceBasicModelTest, \ + WeightsHessianTraceAdvanceModelTest, \ + WeightsHessianTraceMultipleOutputsModelTest, WeightsHessianTraceReuseModelTest, \ + ActivationHessianTraceBasicModelTest, ActivationHessianTraceAdvanceModelTest, \ + ActivationHessianTraceMultipleOutputsModelTest, ActivationHessianTraceReuseModelTest class FunctionTestRunner(unittest.TestCase): @@ -101,22 +100,18 @@ def test_kpi_data_complex_partial(self): """ TestKPIDataComplexPartialBitwidth(self).run_test() - def test_model_gradients(self): + def test_activation_hessian_trace(self): """ - This test checks the Model Gradients Pytorch computation. + This test checks the activation hessian trace approximation in Pytorch. """ - ModelGradientsBasicModelTest(self).run_test() - ModelGradientsCalculationTest(self).run_test() - ModelGradientsAdvancedModelTest(self).run_test() - ModelGradientsMultipleOutputsTest(self).run_test() - ModelGradientsOutputReplacementTest(self).run_test() - ModelGradientsMultipleOutputsModelTest(self).run_test() - ModelGradientsNonDifferentiableNodeModelTest(self).run_test() - ModelGradientsSinglePointTest(self).run_test() + ActivationHessianTraceBasicModelTest(self).run_test() + ActivationHessianTraceAdvanceModelTest(self).run_test() + ActivationHessianTraceMultipleOutputsModelTest(self).run_test() + ActivationHessianTraceReuseModelTest(self).run_test() def test_weights_hessian_trace(self): """ - This test checks the weighes hessian trace approximation in Pytorch. + This test checks the weights hessian trace approximation in Pytorch. """ WeightsHessianTraceBasicModelTest(self).run_test() WeightsHessianTraceAdvanceModelTest(self).run_test() @@ -140,12 +135,11 @@ def test_mixed_precision_set_bitwidth(self): TestSetLayerToBitwidthWeights(self).run_test() TestSetLayerToBitwidthActivation(self).run_test() - def test_sensitivity_eval_outputs_replacement(self): + def test_sensitivity_eval_not_supported_output(self): """ - This test checks the functionality output replacement nodes in sensitivity evaluation for mixed precision.. + This test verifies failure on non-supported output nodes in mixed precision with Hessian-based scores. """ - TestSensitivityEvalWithArgmaxOutputReplacementNodes(self).run_test() - TestSensitivityEvalWithSoftmaxOutputReplacementNodes(self).run_test() + TestSensitivityEvalWithArgmaxNode(self).run_test() def test_get_gptq_config(self): """ diff --git a/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py new file mode 100644 index 000000000..b4edf1214 --- /dev/null +++ b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py @@ -0,0 +1,361 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch +from torch.nn import Conv2d, BatchNorm2d, ReLU, Linear, Hardswish + +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor +import numpy as np + +from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO +from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc +from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs +from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest +import model_compression_toolkit.core.common.hessian as hessian_common + +""" +This test checks the Hessian trace approximation computation (for weights and for activations) in Pytorch +""" + + +class basic_model(torch.nn.Module): + def __init__(self): + super(basic_model, self).__init__() + self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn = BatchNorm2d(3) + self.relu = ReLU() + + def forward(self, inp): + x = self.conv1(inp) + x = self.bn(x) + x = self.relu(x) + return x + + +class advanced_model(torch.nn.Module): + def __init__(self): + super(advanced_model, self).__init__() + self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn1 = BatchNorm2d(3) + self.relu1 = ReLU() + self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn2 = BatchNorm2d(3) + self.relu2 = ReLU() + self.dense = Linear(8, 7) + + def forward(self, inp): + x = self.conv1(inp) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + x = self.dense(x) + return x + + +class multiple_outputs_model(torch.nn.Module): + def __init__(self): + super(multiple_outputs_model, self).__init__() + self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn1 = BatchNorm2d(3) + self.relu1 = ReLU() + self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn2 = BatchNorm2d(3) + self.hswish = Hardswish() + self.dense = Linear(8, 7) + + def forward(self, inp): + x = self.conv1(inp) + x = self.bn1(x) + x1 = self.relu1(x) + x2 = self.conv2(x1) + x2 = self.bn2(x2) + x3 = self.hswish(x2) + x3 = self.dense(x3) + return x1, x2, x3 + + +class reused_model(torch.nn.Module): + def __init__(self): + super(reused_model, self).__init__() + self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) + self.bn1 = BatchNorm2d(3) + self.relu = ReLU() + + def forward(self, inp): + x = self.conv1(inp) + x1 = self.bn1(x) + x1 = self.relu(x1) + x_split = torch.split(x1, split_size_or_sections=4, dim=-1) + x1 = self.conv1(x_split[0]) + x2 = x_split[1] + x1 = self.relu(x1) + y = torch.concat([x1, x2], dim=-1) + return y + + +def generate_inputs(inputs_shape): + inputs = [] + for in_shape in inputs_shape: + t = torch.randn(*in_shape) + t.requires_grad_() + inputs.append(t) + inputs = to_torch_tensor(inputs) + return inputs + + +def get_expected_shape(t_shape, granularity): + # Expected score shape: the full tensor shape per element, a single scalar per tensor, + # and otherwise one value per output channel (first axis). + if granularity == hessian_common.HessianInfoGranularity.PER_ELEMENT: + return t_shape + elif granularity == hessian_common.HessianInfoGranularity.PER_TENSOR: + return (1,) + else: + return (t_shape[0],) + + +class BaseHessianTraceBasicModelTest(BasePytorchTest): + + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def create_inputs_shape(self): + return [[self.val_batch_size, 3, 8, 8]] + + @staticmethod + def generate_inputs(input_shapes): + return generate_inputs(input_shapes) + + def representative_data_gen(self): + input_shapes = self.create_inputs_shape() + yield self.generate_inputs(input_shapes) +
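+ # fetch_hessian(request, n) is expected to return a list of n approximations for the requested + # node; their mean is a single score whose shape should match get_expected_shape for the + # requested granularity.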
+ def test_hessian_trace_approx(self, + hessian_service, + interest_point, + mode, + granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + num_scores=1): + request = hessian_common.TraceHessianRequest(mode=mode, + granularity=granularity, + target_node=interest_point) + expected_shape = get_expected_shape(interest_point.weights['weight'].shape, granularity) + info = hessian_service.fetch_hessian(request, num_scores) + self.unit_test.assertTrue(isinstance(info, list)) + self.unit_test.assertTrue(len(info) == num_scores, + f"Expected {num_scores} scores but {len(info)} scores were fetched") + + # Average the fetched approximations into a single score before checking its shape + score = np.mean(np.stack(info), axis=0) + self.unit_test.assertTrue(score.shape == expected_shape, + f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}") + + def _setup(self, model_class=basic_model): + model_float = model_class() + pytorch_impl = PytorchImplementation() + graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, + self.representative_data_gen, generate_pytorch_tpc) + + return graph, pytorch_impl + + +class WeightsHessianTraceBasicModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup() + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS) + + +class WeightsHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 2 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=advanced_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=1, + granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=3, + granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS) + + +class WeightsHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=multiple_outputs_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=1, + granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=3, + granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS) + + +class WeightsHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=reused_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=1, + granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.WEIGHTS) + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=3, + granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, + mode=hessian_common.HessianMode.WEIGHTS) + +
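+# The tests below fetch activation-mode (rather than weights-mode) Hessian approximations. In these +# tests, activation scores are requested at PER_TENSOR granularity only, so each fetched score is +# expected to have shape (1,).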
+class ActivationHessianTraceBasicModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup() + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION) + + +class ActivationHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 2 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=advanced_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION) + + +class ActivationHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) 
+ self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=multiple_outputs_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION) + + +class ActivationHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest): + def __init__(self, unit_test): + super().__init__(unit_test) + self.val_batch_size = 1 + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup(model_class=reused_model) + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + num_scores=2, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION) diff --git a/tests/pytorch_tests/function_tests/test_hessian_info_weights.py b/tests/pytorch_tests/function_tests/test_hessian_info_weights.py deleted file mode 100644 index f68ee18a1..000000000 --- a/tests/pytorch_tests/function_tests/test_hessian_info_weights.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import torch -from torch.nn import Conv2d, BatchNorm2d, ReLU, Linear, Hardswish - -from model_compression_toolkit.core.pytorch.utils import to_torch_tensor -import numpy as np - -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc -from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs -from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest -import model_compression_toolkit.core.common.hessian as hessian_common - -""" -This test checks the model gradients computation -""" - - -class basic_model(torch.nn.Module): - def __init__(self): - super(basic_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn(x) - x = self.relu(x) - return x - - -class advanced_model(torch.nn.Module): - def __init__(self): - super(advanced_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn1 = BatchNorm2d(3) - self.relu1 = ReLU() - self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn2 = BatchNorm2d(3) - self.relu2 = ReLU() - self.dense = Linear(8, 7) - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn1(x) - x = self.relu1(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.relu2(x) - x = self.dense(x) - return x - - -class multiple_outputs_model(torch.nn.Module): - def __init__(self): - super(multiple_outputs_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn1 = BatchNorm2d(3) - self.relu1 = ReLU() - self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn2 = BatchNorm2d(3) - self.hswish = Hardswish() - self.dense = Linear(8, 7) - - def forward(self, inp): - x = self.conv1(inp) - x = self.bn1(x) - x1 = self.relu1(x) - x2 = self.conv2(x1) - x2 = self.bn2(x2) - x3 = self.hswish(x2) - x3 = self.dense(x3) - return x1, x2, x3 - - -class reused_model(torch.nn.Module): - def __init__(self): - super(reused_model, self).__init__() - self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) - self.bn1 = BatchNorm2d(3) - self.relu = ReLU() - - def forward(self, inp): - x = self.conv1(inp) - x1 = self.bn1(x) - x1 = self.relu(x1) - x_split = torch.split(x1, split_size_or_sections=4, dim=-1) - x1 = self.conv1(x_split[0]) - x2 = x_split[1] - x1 = self.relu(x1) - y = torch.concat([x1, x2], dim=-1) - return y - - -def generate_inputs(inputs_shape): - inputs = [] - for in_shape in inputs_shape: - t = torch.randn(*in_shape) - t.requires_grad_() - inputs.append(t) - inputs = to_torch_tensor(inputs) - return inputs - - -def get_expected_shape(weights_shape, granularity): - if granularity==hessian_common.HessianInfoGranularity.PER_ELEMENT: - return weights_shape - elif granularity==hessian_common.HessianInfoGranularity.PER_TENSOR: - return (1,) - else: - return (weights_shape[0],) - - -def test_weights_hessian_trace_approx(hessian_service, - interest_point, - granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - num_scores=1): - request = hessian_common.TraceHessianRequest(mode=hessian_common.HessianMode.WEIGHTS, - granularity=granularity, - target_node=interest_point) - expected_shape = 
get_expected_shape(interest_point.weights['weight'].shape, granularity) - info = hessian_service.fetch_hessian(request, num_scores) - score = info[0] - assert isinstance(info, list) - assert len(info) == num_scores, f"fetched {num_scores} score but {len(info)} scores were fetched" - assert isinstance(score, np.ndarray), f"scores expected to be a numpy array but is {type(score)}" - assert score.shape == expected_shape, f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}" - - -class WeightsHessianTraceBasicModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 8, 8]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = basic_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, - fw_impl=pytorch_impl) - ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights)>0] - for ipt in ipts: - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT) - - -class WeightsHessianTraceAdvanceModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 2 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 8, 8]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = advanced_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, - fw_impl=pytorch_impl) - ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights)>0] - for ipt in ipts: - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=1, - granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=2, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=3, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT) - - -class WeightsHessianTraceMultipleOutputsModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 8, 8]] - - @staticmethod - def 
generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = multiple_outputs_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, - fw_impl=pytorch_impl) - ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights)>0] - for ipt in ipts: - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=1, - granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=2, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=3, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT) - - -class WeightsHessianTraceReuseModelTest(BasePytorchTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.val_batch_size = 1 - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 8, 8]] - - @staticmethod - def generate_inputs(input_shapes): - return generate_inputs(input_shapes) - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def run_test(self, seed=0): - model_float = reused_model() - pytorch_impl = PytorchImplementation() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, generate_pytorch_tpc) - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, - fw_impl=pytorch_impl) - ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights)>0] - for ipt in ipts: - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=1, - granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=2, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR) - test_weights_hessian_trace_approx(hessian_service, - interest_point=ipt, - num_scores=3, - granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT) - - diff --git a/tests/pytorch_tests/function_tests/test_sensitivity_eval_output_replacement.py b/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py similarity index 67% rename from tests/pytorch_tests/function_tests/test_sensitivity_eval_output_replacement.py rename to tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py index 42e13b599..7746e4b6d 100644 --- a/tests/pytorch_tests/function_tests/test_sensitivity_eval_output_replacement.py +++ b/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py @@ -40,24 +40,7 @@ def forward(self, inp): return output -class softmax_output_model(torch.nn.Module): - def __init__(self): - super(softmax_output_model, self).__init__() - self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=(3, 3)) - self.bn1 = torch.nn.BatchNorm2d(3) - self.conv2 = torch.nn.Conv2d(3, 4, kernel_size=(5, 5)) - self.relu = torch.nn.ReLU() - - def forward(self, 
inp): - x = self.conv1(inp) - x = self.bn1(x) - x = self.conv2(x) - x = self.relu(x) - output = torch.softmax(x, dim=-1) - return output - - -class TestSensitivityEvalWithOutputReplacementBase(BasePytorchTest): +class TestSensitivityEvalWithNonSupportedOutputBase(BasePytorchTest): def create_inputs_shape(self): return [[1, 3, 16, 16]] @@ -88,18 +71,8 @@ def verify_test_for_model(self, model): DEFAULT_PYTORCH_INFO, hessian_info_service=hessian_info_service) - # If the output replacement nodes for MP sensitivity evaluation has been computed correctly then the ReLU layer - # should be added to the interest points and included in the output nodes list for metric computation purposes. - relu_node = graph.get_topo_sorted_nodes()[-2] - self.unit_test.assertTrue(relu_node.type == torch.nn.ReLU) - self.unit_test.assertIn(relu_node, se.interest_points) - self.unit_test.assertEqual(len(se.outputs_replacement_nodes), 1) - self.unit_test.assertIn(relu_node, se.outputs_replacement_nodes) - self.unit_test.assertEqual(se.output_nodes_indices, [2, 3]) - - -class TestSensitivityEvalWithArgmaxOutputReplacementNodes(TestSensitivityEvalWithOutputReplacementBase): +class TestSensitivityEvalWithArgmaxNode(TestSensitivityEvalWithNonSupportedOutputBase): def __init__(self, unit_test): super().__init__(unit_test) @@ -108,19 +81,4 @@ def run_test(self, seed=0, **kwargs): model = argmax_output_model() with self.unit_test.assertRaises(Exception) as e: self.verify_test_for_model(model) - self.unit_test.assertTrue("All graph outputs should support metric outputs" in str(e.exception)) - - - - -class TestSensitivityEvalWithSoftmaxOutputReplacementNodes(TestSensitivityEvalWithOutputReplacementBase): - - def __init__(self, unit_test): - super().__init__(unit_test) - - def run_test(self, seed=0, **kwargs): - model = softmax_output_model() - with self.unit_test.assertRaises(Exception) as e: - self.verify_test_for_model(model) - self.unit_test.assertTrue("All graph outputs should support metric outputs" in str(e.exception)) - + self.unit_test.assertTrue("All graph outputs should support Hessian computation" in str(e.exception)) diff --git a/tests/test_suite.py b/tests/test_suite.py index 371ebe9b7..47213fcc0 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -36,7 +36,8 @@ "torchvision") is not None if found_tf: - from tests.keras_tests.function_tests.test_hessian_info_calculator_weights import TestHessianInfoCalculatorWeights + from tests.keras_tests.function_tests.test_hessian_info_calculator import TestHessianInfoCalculatorWeights, \ + TestHessianInfoCalculatorActivation from tests.keras_tests.function_tests.test_hessian_service import TestHessianService from tests.keras_tests.feature_networks_tests.test_features_runner import FeatureNetworkTest from tests.keras_tests.function_tests.test_quantization_configurations import TestQuantizationConfigurations @@ -61,7 +62,8 @@ from tests.keras_tests.function_tests.test_activation_weights_composition_substitution import \ TestActivationWeightsComposition from tests.keras_tests.function_tests.test_graph_max_cut import TestGraphMaxCut - from tests.keras_tests.function_tests.test_model_gradients import TestModelGradients + from tests.keras_tests.function_tests.test_sensitivity_eval_non_supported_output import \ + TestSensitivityEvalWithNonSupportedOutputNodes from tests.keras_tests.function_tests.test_set_layer_to_bitwidth import TestKerasSetLayerToBitwidth from tests.keras_tests.function_tests.test_export_keras_fully_quantized_model import TestKerasFakeQuantExporter from 
tests.keras_tests.function_tests.test_kpi_data import TestKPIData @@ -107,6 +109,7 @@ # Add TF tests only if tensorflow is installed if found_tf: suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianInfoCalculatorWeights)) + suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianInfoCalculatorActivation)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianService)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestGPTQModelBuilderWithActivationHolder)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(ExporterTestsRunner)) @@ -127,9 +130,9 @@ suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestKerasTPModel)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestWeightsActivationSplit)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestActivationWeightsComposition)) - suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestModelGradients)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestGraphMaxCut)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestKerasSetLayerToBitwidth)) + suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestSensitivityEvalWithNonSupportedOutputNodes)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestKerasFakeQuantExporter)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestKPIData)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestFileLogger))