Skip to content

Commit

Permalink
Remove output replacement mechanism from Hessian computation (#850)
Browse files Browse the repository at this point in the history
Remove the output-node replacement mechanism for unsupported outputs in the Hessian computation.
Instead, the user is informed that Hessian-based information is not available for models that have unsupported output nodes (currently, only argmax).

In addition, the mixed-precision distance metric computation was modified so that it is split into two parts: one computed across the interest points and one computed across the quantized output points.

---------

Co-authored-by: Ofir Gordon <[email protected]>
  • Loading branch information
ofirgo and Ofir Gordon authored Nov 2, 2023
1 parent 4ee9b35 commit fce9b52
Show file tree
Hide file tree
Showing 20 changed files with 852 additions and 1,348 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def get_trace_hessian_calculator(self,
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover


@abstractmethod
def to_numpy(self, tensor: Any) -> np.ndarray:
"""
Expand Down Expand Up @@ -387,21 +386,20 @@ def get_node_distance_fn(self, layer_class: type,


@abstractmethod
def is_node_compatible_for_metric_outputs(self,
node: BaseNode) -> bool:
def is_output_node_compatible_for_hessian_score_computation(self,
node: BaseNode) -> bool:
"""
Checks and returns whether the given node is compatible as output for metric computation
purposes and gradient-based weights calculation.
Checks and returns whether the given node is compatible as output for Hessian-based information computation.
Args:
node: A BaseNode object.
Returns: Whether the node is compatible as output for metric computation or not.
Returns: Whether the node is compatible as output for Hessian-based information computation.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_compatible_for_metric_outputs method.') # pragma: no cover
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover

@abstractmethod
def get_node_mac_operations(self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,22 @@
# limitations under the License.
# ==============================================================================
from typing import List
import numpy as np
from model_compression_toolkit.constants import EPS

from model_compression_toolkit.constants import EPS, HESSIAN_OUTPUT_ALPHA


def normalize_weights(trace_hessian_approximations: List,
outputs_indices: List[int],
alpha: float = HESSIAN_OUTPUT_ALPHA) -> List[float]:
def normalize_weights(hessian_approximations: List) -> np.ndarray:
"""
Normalize trace Hessian approximations. Output layers or layers after the model's considered output layers
are assigned a constant normalized value. Other layers' weights are normalized by dividing the
trace Hessian approximations value by the sum of all other values.
Normalize Hessian information approximations by dividing the trace Hessian approximations value by the sum of all
other values.
Args:
trace_hessian_approximations: Approximated average jacobian-based weights for each interest point.
outputs_indices: Indices of all nodes considered as outputs.
alpha: Multiplication factor.
hessian_approximations: Approximated average Hessian-based scores for each interest point.
Returns:
Normalized list of trace Hessian approximations for each interest point.
"""
if len(trace_hessian_approximations)==1:
return [1.]

sum_without_outputs = sum(
[trace_hessian_approximations[i] for i in range(len(trace_hessian_approximations)) if i not in outputs_indices])
normalized_grads_weights = [_get_normalized_weight(grad,
i,
sum_without_outputs,
outputs_indices,
alpha)
for i, grad in enumerate(trace_hessian_approximations)]

return normalized_grads_weights


def _get_normalized_weight(grad: float,
i: int,
sum_without_outputs: float,
outputs_indices: List[int],
alpha: float) -> float:
Normalized list of Hessian info approximations for each interest point.
"""
Normalizes the node's trace Hessian approximation value. If it is an output or output
replacement node than the normalized value is a constant, otherwise, it is normalized
by dividing with the sum of all trace Hessian approximations values.
Args:
grad: The approximation value.
i: The index of the node in the sorted interest points list.
sum_without_outputs: The sum of all approximations of nodes that are not considered outputs.
outputs_indices: A list of indices of nodes that consider outputs.
alpha: A multiplication factor.
scores_vec = np.asarray(hessian_approximations)

Returns: A normalized trace Hessian approximation.
"""
return scores_vec / (np.sum(scores_vec) + EPS)

if i in outputs_indices:
return alpha / len(outputs_indices)
else:
return ((1 - alpha) * grad / (sum_without_outputs + EPS))
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def __init__(self,
"""
self.graph = graph

for output_node in graph.get_outputs():
if not fw_impl.is_node_compatible_for_metric_outputs(output_node.node):
Logger.error(f"All graph outputs should support metric outputs, but node {output_node.node} was found with layer type {output_node.node.type}.")
if not fw_impl.is_output_node_compatible_for_hessian_score_computation(output_node.node):
Logger.error(f"All graph outputs should support Hessian computation, but node {output_node.node} "
f"was found with layer type {output_node.node.type}. "
f"Try to run MCT without Hessian info computation.")

self.input_images = fw_impl.to_tensor(input_images)
self.num_iterations_for_approximation = num_iterations_for_approximation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(self,
configuration_overwrite: List[int] = None,
num_interest_points_factor: float = 1.0,
use_grad_based_weights: bool = True,
output_grad_factor: float = 0.1,
norm_weights: bool = True,
refine_mp_solution: bool = True,
metric_normalization_threshold: float = 1e10):
Expand All @@ -46,8 +45,7 @@ def __init__(self,
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
use_grad_based_weights (bool): Whether to use gradient-based weights for weighted average distance metric computation.
output_grad_factor (float): A tuning parameter to be used for gradient-based weights.
use_grad_based_weights (bool): Whether to use Hessian-based scores for weighted average distance metric computation.
norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
Expand All @@ -67,13 +65,8 @@ def __init__(self,
self.num_interest_points_factor = num_interest_points_factor

self.use_grad_based_weights = use_grad_based_weights
self.output_grad_factor = output_grad_factor
self.norm_weights = norm_weights

if use_grad_based_weights is True:
Logger.info(f"Using gradient-based weights for mixed-precision distance metric with tuning factor "
f"{output_grad_factor}")

self.metric_normalization_threshold = metric_normalization_threshold


Expand Down
Loading

0 comments on commit fce9b52

Please sign in to comment.