Skip to content

Commit

Permalink
Support TF 2.15 (#993)
Browse files Browse the repository at this point in the history
Support TF 2.15
  • Loading branch information
elad-c authored Mar 14, 2024
1 parent cab0f25 commit 539b592
Show file tree
Hide file tree
Showing 15 changed files with 112 additions and 32 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python310_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.10, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.10"
tf-version: "2.15.*"
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python311_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.11, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.11"
tf-version: "2.15.*"
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python39_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.9, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.9"
tf-version: "2.15.*"
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
| Python 3.11 | | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) |


| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 |
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) |
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) |
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) |
| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |


## Supported Features
Expand Down
14 changes: 8 additions & 6 deletions model_compression_toolkit/core/common/fusion/layer_fusing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx
fusing_patterns after filtering non-relevant fusions
"""
valid_fusing_patterns = []
for i,fusing_pattern in enumerate(fusing_patterns):
for i, fusing_pattern in enumerate(fusing_patterns):
if idx < len(fusing_pattern):
if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or fusing_pattern[idx] == node.type:
if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
node.is_match_type(fusing_pattern[idx]):
valid_fusing_patterns.append(fusing_pattern)

# Return only valid patterns for this node
Expand All @@ -44,7 +45,7 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
"""
Check if the fusion is valid: exist in fusing_patterns
Args:
fusing_patterns: supported fusings
fusing_patterns: supported fusing patterns
nodes: nodes which are participating in fusion
Returns:
whether the fusion in valid
Expand All @@ -56,8 +57,9 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
if fusion_depth != len(fusing_pattern):
continue
counter = 0
for i,layer in enumerate(fusing_pattern):
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or layer == nodes[i].type:
for i, layer in enumerate(fusing_pattern):
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
nodes[i].is_match_type(layer):
counter += 1
if counter == fusion_depth:
return True
Expand Down Expand Up @@ -107,7 +109,7 @@ def fusion(graph: Graph, tpc: TargetPlatformCapabilities) -> Graph:
if node in fused_nodes:
continue
# Start fusing search
fusing_nodes = [] # nodes that are candidates for participating in fusing
fusing_nodes = [] # nodes that are candidates for participating in fusing
patterns = copy.deepcopy(fusing_patterns)
next_nodes = [node]
for i in range(max_layers_fusing):
Expand Down
17 changes: 15 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

import copy
from typing import Dict, Any, Tuple, List
from typing import Dict, Any, Tuple, List, Type

import numpy as np

Expand Down Expand Up @@ -556,6 +556,19 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
return tpc.layer2qco.get(self.type)
return tpc.tp_model.default_qco

def is_match_type(self, _type: Type) -> bool:
"""
Check if input type matches the node type, either in instance type or in type name. Checking the
name string is required because of function types changes that occurred in TF 2.15.
Args:
_type: other node type
Returns:
Whether _type matches the self node type
"""
return _type == self.type or _type.__name__ == self.type.__name__

def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool:
"""
Check if the node matches a LayerFilterParams according to its
Expand All @@ -572,7 +585,7 @@ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool
return False

# Check the node has the same type as the layer in LayerFilterParams
if layer_filter_params.layer != self.type:
if not self.is_match_type(layer_filter_params.layer):
return False

# Get attributes from node to filter
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/core/common/graph/graph_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, operation: Any):

self.operation = operation

def apply(self, input_node_object: Any) -> bool:
def apply(self, input_node_object: BaseNode) -> bool:
"""
Check if input_node_object matches the matcher condition.
Expand All @@ -47,7 +47,7 @@ def apply(self, input_node_object: Any) -> bool:
return nothing.
"""

if input_node_object.type == self.operation:
if input_node_object.is_match_type(self.operation):
return True


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def create_add_node(add_value: float,
quantization_attr={},
layer_class=TFOpLambda,
op_call_args=[np.array(add_value, dtype=np.float32).reshape([1] * len(input_shape))],
op_call_kwargs={})
op_call_kwargs={},
functional_op=tf.add)
return add_node


Expand Down Expand Up @@ -157,7 +158,8 @@ def create_pad_node(next_node_name: str,
layer_class=TFOpLambda,
op_call_args=[],
op_call_kwargs={'paddings': num_elements_to_pad,
'constant_values': value_to_pad})
'constant_values': value_to_pad},
functional_op=tf.pad)

return pad_node

Expand Down
31 changes: 22 additions & 9 deletions model_compression_toolkit/core/keras/reader/node_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Callable, Dict
from typing import Any, List, Dict

import tensorflow as tf
from tensorflow.python.util import tf_inspect
Expand Down Expand Up @@ -45,19 +45,32 @@
is_tensor = lambda x: isinstance(x, KerasTensor)


def get_kwargs2index(tf_func: Callable) -> Dict[str, int]:
def get_tf_function_symbols() -> List[str]:
"""
Create a list of tf function symbols, as they are created in the TFOpLambda layer. The
symbols are serializations of the function names.
Returns:
A list of TF function symbols,
"""
return [TFOpLambda(f).symbol for f in [tf.add, tf.multiply, tf.subtract, tf.divide,
tf.truediv, tf.pow, tf.matmul]]


def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
"""
Positional weights are saved according to their index in the node's call arguments, so
need to know the function arguments' names in case the weights are in the kwargs.
Args:
tf_func: functional node function.
tfoplambda_layer: TFOpLambda layer.
Returns:
A dictionary with argument number and index: {arg_name: arg_index}.
"""
if tf_func in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression]:
return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tf_func).args)}
if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
else:
return {}

Expand Down Expand Up @@ -110,7 +123,7 @@ def build_node(node: KerasNode,
# a flag to indicate that.
inputs_as_list = __is_functional_inputs_a_list(op_call_args)

kwarg2index = get_kwargs2index(keras_layer.function)
kwarg2index = get_kwargs2index(keras_layer)

# Functional nodes do not have weights, but may have constants in their call_args and\or
# call kwargs. Therefore, we extract these constants and save them in the node's weights as
Expand All @@ -122,10 +135,10 @@ def build_node(node: KerasNode,
Logger.error('Functional nodes are not expected to have weights in framework')

# read weights from call args
tf_function_symbols = get_tf_function_symbols()
for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
if is_const(arg) or (
keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
tf.matmul] and
keras_layer.symbol in tf_function_symbols and
isinstance(arg, (tuple, list))):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
# remove weights and KerasTensors and weights from op_call_args
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/quantization_prep_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def quantization_preparation_runner(graph: Graph,
fw_info,
core_config.quantization_config) # Mark points for statistics collection

for _data in tqdm(representative_data_gen(), "Statistics Collection:"):
for _data in tqdm(representative_data_gen(), "Statistics Collection"):
mi.infer(_data)

if tb_w is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from keras.src.layers import Conv2D, TFOpLambda, Add, DepthwiseConv2D, Dense
else:
from keras.layers import Conv2D, TFOpLambda, Add, DepthwiseConv2D, Dense
import tensorflow as tf

from tests.keras_tests.exporter_tests.keras_fake_quant.keras_fake_quant_exporter_base_test import \
KerasFakeQuantExporterBaseTest
Expand Down Expand Up @@ -59,7 +58,7 @@ def run_checks(self):
assert self.loaded_model.layers[7].output.ref() == self.loaded_model.layers[9].input.ref()

assert isinstance(self.loaded_model.layers[10], TFOpLambda)
assert self.loaded_model.layers[10].function == tf.add
assert self.loaded_model.layers[10].symbol == 'math.add'
assert self.loaded_model.layers[10].input.ref() == self.loaded_model.layers[8].output.ref()
assert self.loaded_model.layers[10].inbound_nodes[0].call_kwargs['y'].ref() == self.loaded_model.layers[9].output.ref()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
for layer in quantized_model.layers:
if type(layer) in [layers.Conv2D, layers.DepthwiseConv2D, layers.Conv2DTranspose, layers.Dense]:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
elif isinstance(layer, TFOpLambda) and layer.function is tf.add:
elif isinstance(layer, TFOpLambda) and (layer.function is tf.add or layer.symbol == TFOpLambda(tf.add).symbol):
num_adds += 1

# check all "add"s were folded except the one with 2 tensor inputs
Expand Down
Loading

0 comments on commit 539b592

Please sign in to comment.