diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py index 823cf508c..085082a0b 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from typing import Optional, Tuple import numpy as np import tensorflow as tf @@ -30,7 +31,7 @@ from model_compression_toolkit.constants import REUSE, REUSE_GROUP from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \ KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \ - ACTIVATION, LINEAR + ACTIVATION, LINEAR, DATA_FORMAT, GROUPS, CHANNELS_FORMAT_FIRST, CHANNELS_FORMAT_LAST def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray: @@ -136,35 +137,35 @@ def substitute(self, conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR} if len(conv_func_node.op_call_args) > 0: Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.") # pragma: no cover - if STRIDES in conv_func_node.op_call_kwargs: - strides = conv_func_node.op_call_kwargs[STRIDES] - if len(strides) == 4: - if strides[0] > 1 or strides[3] > 1: - # Non-standard strides -> skip substitution. 
- return graph # pragma: no cover - conv_fw_attr[STRIDES] = strides[1:3] - else: - conv_fw_attr[STRIDES] = strides - if PADDING in conv_func_node.op_call_kwargs: - padding = conv_func_node.op_call_kwargs[PADDING] - if not isinstance(padding, str): - # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution. - return graph # pragma: no cover - conv_fw_attr[PADDING] = padding - if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None: - dilations = conv_func_node.op_call_kwargs[DILATIONS] - if isinstance(dilations, (list, tuple)) and len(dilations) == 4: - if dilations[0] > 1 or dilations[3] > 1: - # Non-standard dilations -> skip substitution. - return graph # pragma: no cover - conv_fw_attr[DILATION_RATE] = dilations[1:3] - else: - conv_fw_attr[DILATION_RATE] = dilations + + strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES) + if strides is None: + # Non-standard strides -> skip substitution. + return graph + conv_fw_attr[STRIDES] = strides + + padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID' + if not isinstance(padding, str): + # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution. + return graph # pragma: no cover + conv_fw_attr[PADDING] = padding + + dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS) + if dilations is None: + # Non-standard dilations -> skip substitution. 
+ return graph + conv_fw_attr[DILATION_RATE] = dilations + if b is None: conv_fw_attr[USE_BIAS] = False else: weights[BIAS] = b + data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC') + conv_fw_attr[DATA_FORMAT] = {'NHWC': CHANNELS_FORMAT_LAST, 'NCHW': CHANNELS_FORMAT_FIRST}[data_format] + + conv_fw_attr[GROUPS] = 1 + _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group} conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape, weights, Conv2D, **_reuse_params) @@ -172,6 +173,31 @@ def substitute(self, replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None) return graph + def _parse_tf_stride_dilation(self, node, key) -> Optional[Tuple[int, int]]: + """ + Extract stride/dilation param from tf node and convert it to keras format (suitable for Conv2D). + + Args: + node: node + key: param key + + Returns: + Parsed value or None if non-standard. + """ + v = node.op_call_kwargs.get(key) + if v is None: + return 1, 1 + if isinstance(v, int): + return v, v + if len(v) == 1: + return v[0], v[0] + if len(v) == 4: + if v[0] > 1 and v[-1] > 1: + return None + else: + return v[1:3] + return tuple(v) + class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution): """ diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/conv_func_substitutions_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/conv_func_substitutions_test.py index f5c24a793..381d330cb 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/conv_func_substitutions_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/conv_func_substitutions_test.py @@ -39,6 +39,9 @@ class ConvFuncSubstitutionsTest(BaseKerasFeatureNetworkTest): + def __init__(self, unit_test): + super().__init__(unit_test, input_shape=(32, 32, 3)) + def get_tpc(self): tp = generate_test_tp_model({'enable_weights_quantization': False, 
'enable_activation_quantization': False}) @@ -67,6 +70,18 @@ def create_networks(self): x = tf.nn.convolution(x, np.random.random((3, 3, 2, 4)).astype(np.float32), [2, 1], padding='SAME') x = tf.nn.bias_add(x, np.random.random((4,)).astype(np.float32)) + + # default values and various formats + x = tf.nn.conv2d(x, np.random.random((3, 3, 4, 8)), 1, 'VALID') + x = tf.nn.conv2d(x, np.random.random((3, 3, 8, 16)), strides=[1], padding='SAME', dilations=1) + x = tf.nn.conv2d(x, np.random.random((3, 3, 16, 8)), strides=[1, 1], padding='VALID', dilations=[1]) + x = tf.nn.conv2d(x, filters=np.random.random((3, 3, 8, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1]) + + x = tf.nn.convolution(x, np.random.random((3, 3, 4, 16)).astype(np.float32)) + x = tf.nn.convolution(x, np.random.random((3, 3, 16, 32)).astype(np.float32), strides=[1], padding='SAME', dilations=1) + x = tf.nn.convolution(x, np.random.random((3, 3, 32, 8)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1]) + x = tf.nn.convolution(x, filters=np.random.random((3, 3, 8, 4)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1, 1]) + return tf.keras.Model(inputs=_in, outputs=x) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): @@ -75,7 +90,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= cs = cosine_similarity(out_float.numpy(), out_quant.numpy()) self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}') - self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 4, + self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 12, "Not all conv functions were substituted.") self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, DepthwiseConv2D)) == 2, "Not all dw-conv functions were substituted.") diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py 
b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py index d6cc325fb..cae35085f 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py @@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= y = float_model.predict(input_x) y_hat = quantized_model.predict(input_x) self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + # FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it + # looked at correct layers it hardly checks anything. self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!') cs = cosine_similarity(y, y_hat) self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}') @@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= if type(layer) == layers.Conv2D: self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!') + class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest): def __init__(self, unit_test): super().__init__(unit_test) @@ -107,9 +110,35 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): super().compare(quantized_model, float_model, input_x, quantization_info) - for layer in quantized_model.layers: - if type(layer) == layers.Conv2D: - self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!') + convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)] + self.unit_test.assertTrue(len(convs) == 1) + for layer in convs: + 
class FuncConvCollapsingTest(FourConv2DCollapsingTest):
    """
    Checks functional-conv to Conv2D substitution combined with linear collapsing.

    When the functional ops rely on default values, the tf layer doesn't carry those attributes, so
    they must be added explicitly to the node's attributes dict; the substitution test alone does not
    cover that path.
    """

    def create_networks(self):
        # Build the network purely from tf.nn functional convs:
        # conv2d -> convolution -> relu -> convolution -> conv2d.
        height, width, channels = self.get_input_shapes()[0][1:]
        net_in = layers.Input(shape=(height, width, channels))
        out = tf.nn.conv2d(net_in, tf.random.uniform((3, 3, channels, 16)), 1, 'SAME')
        out = tf.nn.convolution(out, tf.random.uniform((1, 1, 16, 8)))
        out = tf.nn.relu(out)
        out = tf.nn.convolution(out, tf.random.uniform((3, 3, 8, 32)))
        out = tf.nn.conv2d(out, tf.random.uniform((1, 1, 32, 4)), 1, 'VALID')
        return tf.keras.models.Model(inputs=net_in, outputs=out)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        # Exactly two Conv2D layers are expected in the resulting model
        # (presumably one collapsed conv on each side of the relu - verify).
        num_convs = sum(1 for layer in quantized_model.layers if isinstance(layer, layers.Conv2D))
        self.unit_test.assertTrue(num_convs == 2)

        # Outputs of both models must agree in shape and direction.
        float_out = float_model.predict(input_x)
        quant_out = quantized_model.predict(input_x)
        self.unit_test.assertTrue(float_out.shape == quant_out.shape, msg=f'out shape is not as expected!')
        similarity = cosine_similarity(float_out, quant_out)
        self.unit_test.assertTrue(np.isclose(similarity, 1), msg=f'fail cosine similarity check:{similarity}')
tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \ - ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest + ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \ + FuncConvCollapsingTest from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \ LUTActivationQuantizerTest from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \ @@ -605,6 +606,7 @@ def test_linear_collapsing(self): FourConv2DCollapsingTest(self).run_test() SixConv2DCollapsingTest(self).run_test() Op2DAddConstCollapsingTest(self).run_test() + FuncConvCollapsingTest(self).run_test() def test_const_quantization(self): c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)