Replace max tensor with max cut #1295

Merged · 15 commits · Dec 25, 2024
15 changes: 8 additions & 7 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
@@ -36,10 +36,10 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
The fusion process involves:
1. Creating new fused nodes to represent these groups.
2. Updating the graph structure to replace the original nodes with fused nodes.
3. Maintaining mapping mapping of original node names to their fused node names.
3. Maintaining mapping of original node names to their fused node names.

Args:
graph: Graph to sue its nodes.
graph: Graph to fuse its nodes.

Returns:
Mapping of original node names to their fused node names
@@ -54,7 +54,8 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
fused_nodes_mapping[node.name] = new_fused_node.name
return fused_nodes_mapping
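For illustration only, a toy example of the mapping `create_fused_graph` returns; the node and fused-node names below are made up and are not MCT's actual naming scheme:

```python
# Hypothetical example of the returned mapping: three original nodes
# (conv1 -> bn1 -> relu1) fused into a single fused node.
fused_nodes_mapping = {
    "conv1": "conv1_bn1_relu1_fused",
    "bn1": "conv1_bn1_relu1_fused",
    "relu1": "conv1_bn1_relu1_fused",
}
```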

def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:
@staticmethod
def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
"""
Create a new node that represents the fusion of the given nodes.

@@ -79,10 +80,10 @@ def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:

return fused_node

def _replace_nodes_with_fused_node(self,
graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
@staticmethod
def _replace_nodes_with_fused_node(graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
"""
Replace the specified nodes in the graph with a new fused node.

@@ -51,13 +51,13 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
estimate = (u_bound + l_bound) / 2
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate_factor=estimate, iter_limit=astar_n_iter)
if schedule is None:
return last_result
l_bound = estimate
else:
u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

next_u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

if l_bound * (1 + eps) >= next_u_bound:
return last_result
if l_bound * (1 + eps) >= u_bound:
return last_result

it += 1
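
For context, a minimal standalone sketch of the bisection loop as it reads after this change; the `solve` callable, bounds, and iteration limit are placeholders, not the MCT implementation itself:

```python
# Sketch: binary search over the max-cut estimate. solve(estimate) is assumed to
# return (schedule, max_cut_size, cuts), with schedule=None when no schedule fits.
def search_max_cut(solve, l_bound, u_bound, eps=0.01, n_iter=30):
    last_result = None
    for _ in range(n_iter):
        estimate = (u_bound + l_bound) / 2
        schedule, max_cut_size, cuts = solve(estimate)
        if schedule is None:
            l_bound = estimate  # estimate infeasible -> raise the lower bound and retry
        else:
            u_bound = min(estimate, max_cut_size)  # feasible -> tighten the upper bound
            last_result = (schedule, max_cut_size, cuts)
        if l_bound * (1 + eps) >= u_bound:  # bounds converged within eps
            return last_result
    return last_result
```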

@@ -154,6 +154,9 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cut_route = routes[next_cut]

if next_cut == self.target_cut:
# TODO maxcut: Why do we filter the cuts (cut_route) but not the max cut size (cut_cost)?
# This is a mismatch between max_cut and max(cuts).
# Also, unfiltered cut_route seems perfect, including input and output tensor sizes of current op.
Collaborator comment:
Regarding the first question - this seems like a bug; the max cut size should match the size of the remaining memory elements in the cut.
Regarding the filtering - it is possible that there is an issue here, but basically it is not that we are removing input tensor sizes; we remove the memory elements of operations that have no other dependent operations that haven't been computed yet. So we need to look into it further before just removing the filtering.

return self._remove_dummys_from_path(cut_route[0].op_order), cut_cost,\
list(set([self._remove_dummys_from_cut(self.clean_memory_for_next_step(c)) for c in cut_route]))

@@ -178,7 +181,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cost = self.accumulate(cut_cost, c.memory_size())
if c not in open_list:
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)
elif self.ordering(cost, costs[c]):
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
elif self.ordering(cost, costs[c]): # pragma: no cover
# If we already saw this cut during the search with a larger cost, then we want to update the order
# of the schedule in the cut
# Remove call - removes the cut with the same memory elements but different ordering from open
@@ -187,7 +191,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)

# Halt or No Solution
return None, 0, None
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return None, 0, None # pragma: no cover

@staticmethod
def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: List[Cut],
@@ -223,8 +228,7 @@ def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], rout

"""
ordered_cuts_list = sorted(open_list,
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), len(routes[c])),
reverse=False)
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), -len(routes[c])))

assert len(ordered_cuts_list) > 0
return ordered_cuts_list[0]
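
The new key keeps the primary sort ascending by estimated cost and breaks ties toward cuts with longer routes, by negating the route length instead of relying on `reverse`. A tiny illustration of the tuple-key trick with made-up numbers:

```python
# Illustration only: ascending by cost, ties broken by longer route first.
items = [("a", 10.0, 3), ("b", 10.0, 5), ("c", 9.0, 1)]  # (name, cost, route_len)
ordered = sorted(items, key=lambda t: (t[1], -t[2]))
print([name for name, _, _ in ordered])  # ['c', 'b', 'a']
```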
@@ -349,7 +353,8 @@ def ordering(cost_1, cost_2) -> bool:
Returns: True if the first cost is smaller than the second one, False otherwise.

"""
return cost_1 < cost_2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return cost_1 < cost_2 # pragma: no cover

def estimate(self, cut: Cut, estimate_factor: float) -> float:
"""
@@ -377,9 +382,10 @@ def get_init_estimate_factor(memory_graph: MemoryGraph) -> float:
Returns: An initial estimate value.

"""
l_bound = memory_graph.memory_lbound_single_op
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound
return (u_bound + l_bound) / 2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
l_bound = memory_graph.memory_lbound_single_op # pragma: no cover
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound # pragma: no cover
return (u_bound + l_bound) / 2 # pragma: no cover

@staticmethod
def _remove_dummys_from_path(path: List[BaseNode]) -> List[BaseNode]:
@@ -30,7 +30,12 @@ def __init__(self, shape: Tuple[Any], node_name: str, node_output_index: int, in
init_size_to_zero: Whether to initialize the memory tensor size to 0 or not.
"""

self.shape = shape[1:] # remove batch size (first element) from output shape
# Remove the batch size (first element) from the output shape. If the shape is a list, remove the first
# axis. If the shape is a single-element vector (e.g. the output of a size op), set it to the value minus 1 to ignore the batch.
if len(shape) == 1:
self.shape = [] if shape[0] is None else [shape[0] - 1]
else:
self.shape = shape[1:]
# The total size of a tensor is considered to be the number of elements in the tensor
self.total_size = self._get_tensor_total_size() if not init_size_to_zero else 0
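
A small standalone sketch of the new batch-stripping rule, with made-up shapes (assuming the total size is simply the product of the remaining dimensions):

```python
# Sketch of the shape handling above, not the MemoryTensor class itself.
def effective_shape(shape):
    if len(shape) == 1:
        # Single-element shape (e.g. the output of a size op): drop the batch entry by value.
        return [] if shape[0] is None else [shape[0] - 1]
    return list(shape[1:])  # Multi-dimensional shape: drop the batch axis.

print(effective_shape((8, 3, 224, 224)))  # [3, 224, 224]
print(effective_shape((4,)))              # [3]
print(effective_shape((None,)))           # []
```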

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
from typing import List
from operator import getitem

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
@@ -45,7 +46,8 @@ def __init__(self, model_graph: Graph):
tensor_to_node = []

for n in nodes:
n_outputs = [n.output_shape] if isinstance(n.output_shape, tuple) else n.output_shape
n_outputs = n.output_shape if isinstance(n.output_shape[0], (tuple, list)) else [n.output_shape]

out_edges = model_graph.out_edges(n, sort_by_attr=EDGE_SOURCE_INDEX)

for i, ot in enumerate(n_outputs):
@@ -54,7 +56,16 @@ def __init__(self, model_graph: Graph):
# Add memory tensor as current node's output
node_to_tensor.append((n, memory_tensor))

ot_edges = [oe for oe in out_edges if oe.source_index == i]
# TODO maxcut: refactor this code. it handles split->getitem generated by fx.
ot_edges = []
for oe in out_edges:
if oe.sink_node.type is getitem and len(oe.sink_node.op_call_args) == 1 and isinstance(oe.sink_node.op_call_args[0], int):
source_index = oe.sink_node.op_call_args[0]
else:
source_index = oe.source_index
if source_index == i:
ot_edges.append(oe)

for oe in ot_edges:
# Add current memory tensor as input to current node's successors
tensor_to_node.append((memory_tensor, oe.sink_node))
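
A hedged sketch of the selection rule handled here, using hypothetical stand-in node and edge classes rather than MCT's Graph API: when fx lowers a split, each consumer is a getitem node whose single int argument says which output it reads, so that argument replaces the edge's source_index:

```python
from dataclasses import dataclass
from operator import getitem

# Hypothetical minimal stand-ins for graph nodes/edges, for illustration only.
@dataclass
class FakeNode:
    type: object
    op_call_args: tuple = ()

@dataclass
class FakeEdge:
    sink_node: FakeNode
    source_index: int

def edges_for_output(out_edges, i):
    """Select the edges consuming output i, honouring fx's split->getitem pattern."""
    selected = []
    for oe in out_edges:
        args = oe.sink_node.op_call_args
        if oe.sink_node.type is getitem and len(args) == 1 and isinstance(args[0], int):
            source_index = args[0]  # the getitem argument says which split output is read
        else:
            source_index = oe.source_index
        if source_index == i:
            selected.append(oe)
    return selected

# A getitem consumer of split output 1 is matched even though the edge's source_index is 0.
edge = FakeEdge(sink_node=FakeNode(type=getitem, op_call_args=(1,)), source_index=0)
assert edges_for_output([edge], 1) == [edge]
```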
@@ -71,6 +82,7 @@ def __init__(self, model_graph: Graph):
inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
for n in nodes if n in model_graph.get_inputs()]

# TODO maxcut: why both inputs and outputs of each node, while the A* solves for node outputs only?
Collaborator comment:
We need to dig deeply into the algorithm to understand this.
If it is not giving you trouble, leave it for now.
Or, create a unit test whose expected result you know for sure, and see whether changing this code the way you think it should be changes the result.

nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
[t.total_size for t in self.operation_node_parents(n)])
for n in nodes if n not in model_graph.get_inputs()]
@@ -24,8 +24,10 @@
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
VirtualSplitWeightsNode, VirtualSplitActivationNode
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget, ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation

@@ -40,7 +42,7 @@ def __init__(self,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
sensitivity_evaluator: SensitivityEvaluation,
ru_functions: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]],
ru_functions: Dict[RUTarget, RuFunctions[MpRuMetric, MpRuAggregation]],
target_resource_utilization: ResourceUtilization,
original_graph: Graph = None):
"""
@@ -65,8 +67,11 @@ def __init__(self,
self.sensitivity_evaluator = sensitivity_evaluator
self.layer_to_bitwidth_mapping = self.get_search_space()
self.compute_metric_fn = self.get_sensitivity_metric()
self._cuts = None

self.compute_ru_functions = ru_functions
ru_types = [ru_target for ru_target, ru_value in
target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf]
self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
self.target_resource_utilization = target_resource_utilization
self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
@@ -76,6 +81,17 @@ def __init__(self,
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
original_graph=self.original_graph)

@property
def cuts(self) -> List[Cut]:
"""
Calculates graph cuts. Written as a property, so the cuts are only calculated once and
only if they are needed.

"""
if self._cuts is None:
self._cuts = calc_graph_cuts(self.original_graph)
return self._cuts

def get_search_space(self) -> Dict[int, List[int]]:
"""
The search space is a mapping from a node's index to a list of integers (possible bitwidth candidate indices
@@ -106,6 +122,21 @@ def get_sensitivity_metric(self) -> Callable:

return self.sensitivity_evaluator.compute_metric

def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
"""
Computes resource utilization for a certain mixed precision configuration.
The method computes a resource utilization vector for a specific resource utilization target.

Returns: resource utilization value.

"""
# ru_fn is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
if ru_target is RUTarget.ACTIVATION:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)

def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
"""
Computes a resource utilization vector with the values matching the minimal mp configuration
@@ -118,10 +149,10 @@ def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:

"""
min_ru = {}
for ru_target, ru_fns in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
for ru_target, ru_fn in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
min_ru[ru_target] = ru_fns[0](self.min_ru_config, self.graph, self.fw_info, self.fw_impl)
min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)

return min_ru

@@ -212,7 +243,7 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int,

"""
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
return self.compute_ru_functions[target].metric_fn(cfg, self.graph, self.fw_info, self.fw_impl)
return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg)

@staticmethod
def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
@@ -241,13 +272,15 @@ def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]:
"""

non_conf_ru_dict = {}
for target, ru_value in self.target_resource_utilization.get_resource_utilization_dict().items():
for target, ru_fns in self.compute_ru_functions.items():
# Call for the ru method of the given target - empty quantization configuration list is passed since we
# compute for non-configurable nodes
if target == RUTarget.BOPS:
ru_vector = None
elif target == RUTarget.ACTIVATION:
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
ru_vector = self.compute_ru_functions[target].metric_fn([], self.graph, self.fw_info, self.fw_impl)
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl)

non_conf_ru_dict[target] = ru_vector

@@ -266,14 +299,15 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource
"""

ru_dict = {}

for ru_target, ru_fns in self.compute_ru_functions.items():
# Passing False to ru methods and aggregations to indicates that the computations
# are not for constraints setting
if ru_target == RUTarget.BOPS:
configurable_nodes_ru_vector = ru_fns[0](config, self.original_graph, self.fw_info, self.fw_impl, False)
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False)
elif ru_target == RUTarget.ACTIVATION:
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
configurable_nodes_ru_vector = ru_fns[0](config, self.original_graph, self.fw_info, self.fw_impl)
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl)
non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
@@ -647,7 +681,7 @@ def get_weights_for_split_activation(self,
# It's ok, need to find the node's configuration
self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
else:
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{n.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{weights_node.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover

def update_config_at_original_idx(self, n: BaseNode, origin_cfg_idx: int):
"""