Skip to content

Commit

Permalink
improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Dec 1, 2024
1 parent 52bb1c5 commit ff3c49e
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 124 deletions.
14 changes: 9 additions & 5 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,14 @@ def get_weights_configurable_nodes(self,
"""
# configurability is only relevant for kernel attribute quantization
potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
return list(filter(lambda n: n.is_weights_quantization_enabled(fw_info.get_kernel_op_attributes(n.type)[0])
and not n.is_all_weights_candidates_equal(fw_info.get_kernel_op_attributes(n.type)[0])
and (not n.reuse or include_reused_nodes), potential_conf_nodes))

def is_configurable(n):
kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
return (n.is_weights_quantization_enabled(kernel_attr) and
not n.is_all_weights_candidates_equal(kernel_attr) and
(not n.reuse or include_reused_nodes))

return [n for n in potential_conf_nodes if is_configurable(n)]

def get_sorted_weights_configurable_nodes(self,
fw_info: FrameworkInfo,
Expand All @@ -571,8 +576,7 @@ def get_activation_configurable_nodes(self) -> List[BaseNode]:
Returns:
A list of nodes that their activation can be configured (namely, has one or more activation qc candidate).
"""
return list(filter(lambda n: n.is_activation_quantization_enabled()
and not n.is_all_activation_candidates_equal(), list(self)))
return [n for n in list(self) if n.is_activation_quantization_enabled() and not n.is_all_activation_candidates_equal()]

def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]:
"""
Expand Down
5 changes: 2 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,8 @@ def is_all_weights_candidates_equal(self, attr: str) -> bool:
"""
# note that if the given attribute name does not exist in the node's attributes mapping,
# the inner method would log an exception.
return all(attr_candidate ==
self.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(attr)
for attr_candidate in self.get_all_weights_attr_candidates(attr))
candidates = self.get_all_weights_attr_candidates(attr)
return all(candidate == candidates[0] for candidate in candidates[1:])

def has_kernel_weight_to_quantize(self, fw_info):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,38 @@ def filter_candidates_for_mixed_precision(graph: Graph,
"""

no_total_restrictions = (target_resource_utilization.total_memory == np.inf and
target_resource_utilization.bops == np.inf)
tru = target_resource_utilization
if tru.total_mem_restricted() or tru.bops_restricted():
return

if target_resource_utilization.weights_memory < np.inf:
if target_resource_utilization.activation_memory == np.inf and no_total_restrictions:
# Running mixed precision for weights compression only -
# filter out candidates activation only configurable node
weights_conf = graph.get_weights_configurable_nodes(fw_info)
for n in graph.get_activation_configurable_nodes():
if n not in weights_conf:
base_cfg_nbits = n.get_qco(tpc).base_config.activation_n_bits
filtered_conf = [c for c in n.candidates_quantization_cfg if
c.activation_quantization_cfg.enable_activation_quantization and
c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
if tru.weight_restricted() and not tru.activation_restricted():
# Running mixed precision for weights compression only -
# filter out candidates activation only configurable node
weights_conf = graph.get_weights_configurable_nodes(fw_info)
nodes = [n for n in graph.get_activation_configurable_nodes() if n not in weights_conf]
for n in nodes:
base_cfg_nbits = n.get_qco(tpc).base_config.activation_n_bits
filtered_conf = [c for c in n.candidates_quantization_cfg if
c.activation_quantization_cfg.enable_activation_quantization and
c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]

if len(filtered_conf) != 1:
Logger.critical(f"Running weights only mixed precision failed on layer {n.name} with multiple "
f"activation quantization configurations.") # pragma: no cover
n.candidates_quantization_cfg = filtered_conf
if len(filtered_conf) != 1:
Logger.critical(f"Running weights only mixed precision failed on layer {n.name} with multiple "
f"activation quantization configurations.") # pragma: no cover
n.candidates_quantization_cfg = filtered_conf

elif target_resource_utilization.activation_memory < np.inf:
if target_resource_utilization.weights_memory == np.inf and no_total_restrictions:
# Running mixed precision for activation compression only -
# filter out candidates weights only configurable node
activation_conf = graph.get_activation_configurable_nodes()
for n in graph.get_weights_configurable_nodes(fw_info):
if n not in activation_conf:
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
base_cfg_nbits = n.get_qco(tpc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
filtered_conf = [c for c in n.candidates_quantization_cfg if
c.weights_quantization_cfg.get_attr_config(
kernel_attr).enable_weights_quantization and
c.weights_quantization_cfg.get_attr_config(
kernel_attr).weights_n_bits == base_cfg_nbits]
if len(filtered_conf) != 1:
Logger.critical(f"Running activation only mixed precision failed on layer {n.name} with multiple "
f"weights quantization configurations.") # pragma: no cover
n.candidates_quantization_cfg = filtered_conf
elif tru.activation_restricted() and not tru.weight_restricted():
# Running mixed precision for activation compression only -
# filter out candidates weights only configurable node
activation_conf = graph.get_activation_configurable_nodes()
nodes = [n for n in graph.get_weights_configurable_nodes(fw_info) if n not in activation_conf]
for n in nodes:
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
base_cfg_nbits = n.get_qco(tpc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
filtered_conf = [c for c in n.candidates_quantization_cfg if
c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
if len(filtered_conf) != 1:
Logger.critical(f"Running activation only mixed precision failed on layer {n.name} with multiple "
f"weights quantization configurations.") # pragma: no cover
n.candidates_quantization_cfg = filtered_conf
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,17 @@ def search_bit_width(graph_to_search_cfg: Graph,

# Set graph for MP search
graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
if target_resource_utilization.bops < np.inf:
if target_resource_utilization.bops_restricted():
# Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())

# If we only run weights compression with MP than no need to consider activation quantization when computing the
# MP metric (it adds noise to the computation)
disable_activation_for_metric = (target_resource_utilization.weights_memory < np.inf and
(target_resource_utilization.activation_memory == np.inf and
target_resource_utilization.total_memory == np.inf and
target_resource_utilization.bops == np.inf)) or graph_to_search_cfg.is_single_activation_cfg()
tru = target_resource_utilization
weight_only_restricted = tru.weight_restricted() and not (tru.activation_restricted() or
tru.total_mem_restricted() or
tru.bops_restricted())
disable_activation_for_metric = weight_only_restricted or graph_to_search_cfg.is_single_activation_cfg()

# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
Expand All @@ -117,11 +118,10 @@ def search_bit_width(graph_to_search_cfg: Graph,
target_resource_utilization,
original_graph=graph_to_search_cfg)

if search_method in search_methods: # Get a specific search function
search_method_fn = search_methods.get(search_method)
else:
raise NotImplemented # pragma: no cover
if search_method not in search_methods:
raise NotImplementedError() # pragma: no cover

search_method_fn = search_methods[search_method]
# Search for the desired mixed-precision configuration
result_bit_cfg = search_method_fn(search_manager,
target_resource_utilization)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,8 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int,
Returns: Node's resource utilization vector.
"""
return self.compute_ru_functions[target][0](
self.replace_config_in_index(
self.min_ru_config,
conf_node_idx,
candidate_idx),
self.graph,
self.fw_info,
self.fw_impl)
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
return self.compute_ru_functions[target].metric_fn(cfg, self.graph, self.fw_info, self.fw_impl)

@staticmethod
def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
Expand Down Expand Up @@ -253,7 +247,7 @@ def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]:
if target == RUTarget.BOPS:
ru_vector = None
else:
ru_vector = self.compute_ru_functions[target][0]([], self.graph, self.fw_info, self.fw_impl)
ru_vector = self.compute_ru_functions[target].metric_fn([], self.graph, self.fw_info, self.fw_impl)

non_conf_ru_dict[target] = ru_vector

Expand Down Expand Up @@ -282,9 +276,9 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource
configurable_nodes_ru_vector = ru_fns[0](config, self.original_graph, self.fw_info, self.fw_impl)
non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
ru_ru = self.compute_ru_functions[ru_target][1](configurable_nodes_ru_vector, False)
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
else:
ru_ru = self.compute_ru_functions[ru_target][1](
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(
np.concatenate([configurable_nodes_ru_vector, non_configurable_nodes_ru_vector]), False)

ru_dict[ru_target] = ru_ru[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def __repr__(self):
f"Total_memory: {self.total_memory}, " \
f"BOPS: {self.bops}"

def weight_restricted(self):
return self.weights_memory < np.inf

def activation_restricted(self):
return self.activation_memory < np.inf

def total_mem_restricted(self):
return self.total_memory < np.inf

def bops_restricted(self):
return self.bops < np.inf

def get_resource_utilization_dict(self) -> Dict[RUTarget, float]:
"""
Returns: a dictionary with the ResourceUtilization object's values for each resource utilization target.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import NamedTuple

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
Expand All @@ -20,7 +22,12 @@
# When adding a RUTarget that we want to consider in our mp search,
# a matching pair of resource_utilization_tools computation function and a resource_utilization_tools
# aggregation function should be added to this dictionary
ru_functions_mapping = {RUTarget.WEIGHTS: (MpRuMetric.WEIGHTS_SIZE, MpRuAggregation.SUM),
RUTarget.ACTIVATION: (MpRuMetric.ACTIVATION_OUTPUT_SIZE, MpRuAggregation.MAX),
RUTarget.TOTAL: (MpRuMetric.TOTAL_WEIGHTS_ACTIVATION_SIZE, MpRuAggregation.TOTAL),
RUTarget.BOPS: (MpRuMetric.BOPS_COUNT, MpRuAggregation.SUM)}
class RuFunctions(NamedTuple):
metric_fn: MpRuMetric
aggregate_fn: MpRuAggregation


ru_functions_mapping = {RUTarget.WEIGHTS: RuFunctions(MpRuMetric.WEIGHTS_SIZE, MpRuAggregation.SUM),
RUTarget.ACTIVATION: RuFunctions(MpRuMetric.ACTIVATION_OUTPUT_SIZE, MpRuAggregation.MAX),
RUTarget.TOTAL: RuFunctions(MpRuMetric.TOTAL_WEIGHTS_ACTIVATION_SIZE, MpRuAggregation.TOTAL),
RUTarget.BOPS: RuFunctions(MpRuMetric.BOPS_COUNT, MpRuAggregation.SUM)}
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
# search_manager.compute_ru_functions contains a pair of ru_metric and ru_aggregation for each ru target
# get aggregated ru, considering both configurable and non-configurable nodes
if non_conf_ru_vector is None or len(non_conf_ru_vector) == 0:
aggr_ru = search_manager.compute_ru_functions[target][1](ru_sum_vector)
aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(ru_sum_vector)
else:
aggr_ru = search_manager.compute_ru_functions[target][1](np.concatenate([ru_sum_vector, non_conf_ru_vector]))
aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(np.concatenate([ru_sum_vector, non_conf_ru_vector]))

for v in aggr_ru:
if isinstance(v, float):
Expand Down Expand Up @@ -261,9 +261,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
Logger.info('Starting to evaluate metrics')
layer_to_metrics_mapping = {}

is_bops_target_resource_utilization = target_resource_utilization.bops < np.inf

if is_bops_target_resource_utilization:
if target_resource_utilization.bops_restricted():
origin_max_config = search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(search_manager.max_ru_config)
max_config_value = search_manager.compute_metric_fn(origin_max_config)
else:
Expand All @@ -284,7 +282,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
mp_model_configuration[node_idx] = bitwidth_idx

# Build a distance matrix using the function we got from the framework implementation.
if is_bops_target_resource_utilization:
if target_resource_utilization.bops_restricted():
# Reconstructing original graph's configuration from virtual graph's configuration
origin_mp_model_configuration = \
search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self,
quant_config.num_interest_points_factor)

# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
# beacause hessian weights already do normalization.
# because hessian weights already do normalization.
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)

Expand All @@ -116,14 +116,11 @@ def __init__(self,
# Build images batches for inference comparison
self.images_batches = self._get_images_batches(quant_config.num_of_images)

# Get baseline model inference on all samples
self.baseline_tensors_list = [] # setting from outside scope

# Casting images tensors to the framework tensor type.
self.images_batches = list(map(lambda in_arr: self.fw_impl.to_tensor(in_arr), self.images_batches))
self.images_batches = [self.fw_impl.to_tensor(img) for img in self.images_batches]

# Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
self._init_baseline_tensors_list()
self.baseline_tensors_list = self._init_baseline_tensors_list()

# Computing Hessian-based scores for weighted average distance metric computation (only if requested),
# and assigning distance_weighting method accordingly.
Expand Down Expand Up @@ -193,11 +190,9 @@ def compute_metric(self,

def _init_baseline_tensors_list(self):
"""
Evaluates the baseline model on all images and saves the obtained lists of tensors in a list for later use.
Initiates a class variable self.baseline_tensors_list
Evaluates the baseline model on all images and returns the obtained lists of tensors in a list for later use.
"""
self.baseline_tensors_list = [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.baseline_model,
images))
return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.baseline_model, images))
for images in self.images_batches]

def _build_models(self) -> Any:
Expand Down Expand Up @@ -454,7 +449,7 @@ def get_mp_interest_points(graph: Graph,
"""
sorted_nodes = graph.get_topo_sorted_nodes()
ip_nodes = list(filter(lambda n: interest_points_classifier(n), sorted_nodes))
ip_nodes = [n for n in sorted_nodes if interest_points_classifier(n)]

interest_points_nodes = bound_num_interest_points(ip_nodes, num_ip_factor)

Expand Down
Loading

0 comments on commit ff3c49e

Please sign in to comment.