From e4adf42aabc9184cbe2c811783811b9cf9a6fca4 Mon Sep 17 00:00:00 2001
From: irenab
Date: Sun, 24 Nov 2024 10:14:15 +0200
Subject: [PATCH] fix tf.nn.{conv2d,convolution} substitution

---
 .../substitutions/conv_funcs_to_layer.py      | 65 +++++++++++++------
 .../linear_collapsing_test.py                 | 44 ++++++++++++-
 .../test_features_runner.py                   |  5 +-
 3 files changed, 91 insertions(+), 23 deletions(-)

diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py
index 823cf508c..b416635b7 100644
--- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py
+++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from typing import Optional, Tuple
 
 import numpy as np
 import tensorflow as tf
@@ -30,7 +31,7 @@ from model_compression_toolkit.constants import REUSE, REUSE_GROUP
 from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
     KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
-    ACTIVATION, LINEAR
+    ACTIVATION, LINEAR, DATA_FORMAT, GROUPS
 
 
 def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,35 +137,36 @@ def substitute(self,
         conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
         if len(conv_func_node.op_call_args) > 0:
             Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover
-        if STRIDES in conv_func_node.op_call_kwargs:
-            strides = conv_func_node.op_call_kwargs[STRIDES]
-            if len(strides) == 4:
-                if strides[0] > 1 or strides[3] > 1:
-                    # Non-standard strides -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[STRIDES] = strides[1:3]
-            else:
-                conv_fw_attr[STRIDES] = strides
+
+        strides = self._parse_tf_2d_kwarg(conv_func_node, STRIDES)
+        if strides is None:
+            # Non-standard strides -> skip substitution.
+            return graph
+        conv_fw_attr[STRIDES] = strides
+
         if PADDING in conv_func_node.op_call_kwargs:
             padding = conv_func_node.op_call_kwargs[PADDING]
             if not isinstance(padding, str):
                 # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
                 return graph  # pragma: no cover
             conv_fw_attr[PADDING] = padding
-        if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
-            dilations = conv_func_node.op_call_kwargs[DILATIONS]
-            if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
-                if dilations[0] > 1 or dilations[3] > 1:
-                    # Non-standard dilations -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[DILATION_RATE] = dilations[1:3]
-            else:
-                conv_fw_attr[DILATION_RATE] = dilations
+
+        dilations = self._parse_tf_2d_kwarg(conv_func_node, DILATIONS)
+        if dilations is None:
+            # Non-standard dilations -> skip substitution.
+            return graph
+        conv_fw_attr[DILATION_RATE] = dilations
+
         if b is None:
             conv_fw_attr[USE_BIAS] = False
         else:
             weights[BIAS] = b
 
+        data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
+        conv_fw_attr['data_format'] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]
+
+        conv_fw_attr[GROUPS] = 1
+
         _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
         conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                              weights, Conv2D, **_reuse_params)
@@ -172,6 +174,31 @@ def substitute(self,
         replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
         return graph
 
+    def _parse_tf_2d_kwarg(self, node, key) -> Optional[Tuple[int, int]]:
+        """
+        Extract a strides/dilations kwarg from a tf functional node and convert it to Keras (Conv2D) format.
+
+        Args:
+            node: functional conv node.
+            key: kwarg key (STRIDES or DILATIONS).
+
+        Returns:
+            Parsed spatial (H, W) value, or None if it cannot be expressed by Conv2D.
+        """
+        v = node.op_call_kwargs.get(key)
+        if v is None:
+            return 1, 1
+        if isinstance(v, int):
+            return v, v
+        if len(v) == 1:
+            return v[0], v[0]
+        if len(v) == 4:
+            if v[0] > 1 or v[-1] > 1:
+                # Stride/dilation over the batch or channel axis -> non-standard.
+                return None
+            else:
+                return tuple(v[1:3])
+        return tuple(v)
+
 
 class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
     """
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py
index d6cc325fb..f00ea9bfe 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         y = float_model.predict(input_x)
         y_hat = quantized_model.predict(input_x)
         self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        # FIXME: this doesn't test anything - the number of quantized convs in the network is exactly 0. Even if it
+        #  looked at the correct layers, it would hardly check anything.
         self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
         cs = cosine_similarity(y, y_hat)
         self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
             if type(layer) == layers.Conv2D:
                 self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
 
+
 class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
     def __init__(self, unit_test):
         super().__init__(unit_test)
@@ -107,9 +110,44 @@ def create_networks(self):
 
     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
         super().compare(quantized_model, float_model, input_x, quantization_info)
-        for layer in quantized_model.layers:
-            if type(layer) == layers.Conv2D:
-                self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')
+        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
+        self.unit_test.assertTrue(len(convs) == 1)
+        for layer in convs:
+            self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')
+
+
+class FuncConv2DCollapsingTest(FourConv2DCollapsingTest):
+    def create_networks(self):
+        # tests the combination of functional conv to Conv2D substitution with linear collapsing
+        h, w, c = self.get_input_shapes()[0][1:]
+        inputs = layers.Input(shape=(h, w, c))
+        x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 128)), 1, 'VALID')
+        x = tf.nn.conv2d(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
+        x = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
+        y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])
+        return tf.keras.models.Model(inputs=inputs, outputs=y)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
+        self.unit_test.assertTrue(len(convs) == 1)
+
+        y = float_model.predict(input_x)
+        y_hat = quantized_model.predict(input_x)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        cs = cosine_similarity(y, y_hat)
+        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
+
+
+class FuncConvolutionCollapsingTest(FuncConv2DCollapsingTest):
+    def create_networks(self):
+        # tests the combination of functional conv to Conv2D substitution with linear collapsing
+        h, w, c = self.get_input_shapes()[0][1:]
+        inputs = layers.Input(shape=(h, w, c))
+        x = tf.nn.convolution(inputs, tf.random.uniform((3, 3, c, 128)))
+        x = tf.nn.convolution(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
+        x = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
+        y = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='VALID', dilations=[1, 1])
+        return tf.keras.models.Model(inputs=inputs, outputs=y)
 
 
 class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index b11be7c04..003982326 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -56,7 +56,8 @@ from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
     InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
 from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
-    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
+    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
+    FuncConv2DCollapsingTest, FuncConvolutionCollapsingTest
 from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
     LUTActivationQuantizerTest
 from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
@@ -605,6 +606,8 @@ def test_linear_collapsing(self):
         FourConv2DCollapsingTest(self).run_test()
         SixConv2DCollapsingTest(self).run_test()
         Op2DAddConstCollapsingTest(self).run_test()
+        FuncConv2DCollapsingTest(self).run_test()
+        FuncConvolutionCollapsingTest(self).run_test()
 
     def test_const_quantization(self):
         c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)
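
For reference, below is a minimal standalone sketch (not part of the patch, using only stock
TensorFlow/Keras APIs; all variable names are local to this snippet and do not exist in MCT) of the
equivalence the substitution relies on: a tf.nn.conv2d call with a full NHWC strides list should be
reproduced by a Conv2D layer configured with the reduced spatial strides, the same padding, and the
filter tensor copied in as its kernel - i.e. what _parse_tf_2d_kwarg produces for the layer attributes.

    # Sketch only: verify tf.nn.conv2d vs. an equivalent keras Conv2D layer.
    import numpy as np
    import tensorflow as tf

    x = tf.random.uniform((1, 16, 16, 3))      # NHWC input
    k = tf.random.uniform((3, 3, 3, 8))        # (kh, kw, in_channels, out_channels) filter

    # Functional op: strides may be an int, a length-1/2 list, or a full NHWC length-4 list.
    y_func = tf.nn.conv2d(x, k, strides=[1, 2, 2, 1], padding='SAME', dilations=[1, 1, 1, 1])

    # Equivalent layer: strides/dilations reduced to the two spatial dims, kernel copied from the filter.
    conv = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(2, 2),
                                  padding='same', dilation_rate=(1, 1), use_bias=False)
    conv.build(x.shape)
    conv.set_weights([k.numpy()])
    y_layer = conv(x)

    print(np.allclose(y_func.numpy(), np.asarray(y_layer), atol=1e-5))  # expected: True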