Skip to content

Commit

Permalink
fix tf.nn.{conv2d,convolution} substitution (#1275)
Browse files Browse the repository at this point in the history
* fix tf.nn.{conv2d,convolution} substitution
  • Loading branch information
irenaby authored Nov 24, 2024
1 parent 432ae5b commit efd3310
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
Expand All @@ -30,7 +31,7 @@
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
ACTIVATION, LINEAR
ACTIVATION, LINEAR, DATA_FORMAT, GROUPS, CHANNELS_FORMAT_FIRST, CHANNELS_FORMAT_LAST


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
Expand Down Expand Up @@ -136,42 +137,67 @@ def substitute(self,
conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
if len(conv_func_node.op_call_args) > 0:
Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.") # pragma: no cover
if STRIDES in conv_func_node.op_call_kwargs:
strides = conv_func_node.op_call_kwargs[STRIDES]
if len(strides) == 4:
if strides[0] > 1 or strides[3] > 1:
# Non-standard strides -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[STRIDES] = strides[1:3]
else:
conv_fw_attr[STRIDES] = strides
if PADDING in conv_func_node.op_call_kwargs:
padding = conv_func_node.op_call_kwargs[PADDING]
if not isinstance(padding, str):
# Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[PADDING] = padding
if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
dilations = conv_func_node.op_call_kwargs[DILATIONS]
if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
if dilations[0] > 1 or dilations[3] > 1:
# Non-standard dilations -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[DILATION_RATE] = dilations[1:3]
else:
conv_fw_attr[DILATION_RATE] = dilations

strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
if strides is None:
# Non-standard strides -> skip substitution.
return graph
conv_fw_attr[STRIDES] = strides

padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
if not isinstance(padding, str):
# Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[PADDING] = padding

dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
if dilations is None:
# Non-standard dilations -> skip substitution.
return graph
conv_fw_attr[DILATION_RATE] = dilations

if b is None:
conv_fw_attr[USE_BIAS] = False
else:
weights[BIAS] = b

data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
conv_fw_attr[DATA_FORMAT] = {'NHWC': CHANNELS_FORMAT_LAST, 'NCHW': CHANNELS_FORMAT_FIRST}[data_format]

conv_fw_attr[GROUPS] = 1

_reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
weights, Conv2D, **_reuse_params)

replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
return graph

def _parse_tf_stride_dilation(self, node, key) -> Optional[Tuple[int, int]]:
    """
    Extract stride/dilation param from tf node and convert it to keras format (suitable for Conv2D).

    Accepted input forms (read from ``node.op_call_kwargs[key]``):
      - missing / None  -> (1, 1)
      - int             -> (v, v)
      - 1-element seq   -> (v[0], v[0])
      - 2-element seq   -> returned as a tuple
      - 4-element seq   -> the two middle entries, provided that neither the first nor the
                           last entry is greater than 1

    Args:
        node: node whose ``op_call_kwargs`` dict holds the call arguments of the tf op.
        key: param key to look up in ``op_call_kwargs``.

    Returns:
        Parsed value as a tuple, or None if non-standard (4-element value whose first or
        last entry is greater than 1), in which case the caller skips the substitution.
    """
    v = node.op_call_kwargs.get(key)
    if v is None:
        # Param was not passed (or passed explicitly as None) -> unit value for both dims.
        return 1, 1
    if isinstance(v, int):
        return v, v
    if len(v) == 1:
        return v[0], v[0]
    if len(v) == 4:
        # Only the two middle entries can be expressed by the 2-element keras param, so a
        # non-unit value in EITHER outer position makes the op non-convertible.
        # NOTE(review): taking v[1:3] assumes the spatial dims are the middle entries
        # (NHWC ordering) - confirm the behaviour wanted for data_format='NCHW'.
        if v[0] > 1 or v[-1] > 1:
            return None
        # Return a tuple, consistent with all other branches and with the annotation.
        return tuple(v[1:3])
    return tuple(v)


class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@

class ConvFuncSubstitutionsTest(BaseKerasFeatureNetworkTest):

# Constructor: forwards unit_test to the base-class constructor and passes
# input_shape=(32, 32, 3) explicitly (presumably H, W, C - confirm against the base class).
def __init__(self, unit_test):
super().__init__(unit_test, input_shape=(32, 32, 3))

def get_tpc(self):
tp = generate_test_tp_model({'enable_weights_quantization': False,
'enable_activation_quantization': False})
Expand Down Expand Up @@ -67,6 +70,18 @@ def create_networks(self):
x = tf.nn.convolution(x, np.random.random((3, 3, 2, 4)).astype(np.float32),
[2, 1], padding='SAME')
x = tf.nn.bias_add(x, np.random.random((4,)).astype(np.float32))

# default values and various formats
x = tf.nn.conv2d(x, np.random.random((3, 3, 4, 8)), 1, 'VALID')
x = tf.nn.conv2d(x, np.random.random((3, 3, 8, 16)), strides=[1], padding='SAME', dilations=1)
x = tf.nn.conv2d(x, np.random.random((3, 3, 16, 8)), strides=[1, 1], padding='VALID', dilations=[1])
x = tf.nn.conv2d(x, filters=np.random.random((3, 3, 8, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])

x = tf.nn.convolution(x, np.random.random((3, 3, 4, 16)).astype(np.float32))
x = tf.nn.convolution(x, np.random.random((3, 3, 16, 32)).astype(np.float32), strides=[1], padding='SAME', dilations=1)
x = tf.nn.convolution(x, np.random.random((3, 3, 32, 8)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1])
x = tf.nn.convolution(x, filters=np.random.random((3, 3, 8, 4)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1, 1])

return tf.keras.Model(inputs=_in, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
Expand All @@ -75,7 +90,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
cs = cosine_similarity(out_float.numpy(), out_quant.numpy())
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')

self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 4,
self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 12,
"Not all conv functions were substituted.")
self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, DepthwiseConv2D)) == 2,
"Not all dw-conv functions were substituted.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
# FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it
# looked at correct layers it hardly checks anything.
self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
Expand All @@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
# Constructor: only forwards unit_test to the parent-class constructor.
def __init__(self, unit_test):
super().__init__(unit_test)
Expand Down Expand Up @@ -107,9 +110,35 @@ def create_networks(self):

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
    """
    Run the parent-class checks, then verify the Conv2D layers left in the quantized model.

    Args:
        quantized_model: model produced by the test run.
        float_model: the original float model.
        input_x: input passed through to the parent-class comparison.
        quantization_info: passed through to the parent-class comparison.
    """
    super().compare(quantized_model, float_model, input_x, quantization_info)
    # A single isinstance-based pass covers every Conv2D layer (subclasses included), so a
    # separate exact-type loop asserting the same bias condition is not needed.
    convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
    self.unit_test.assertTrue(len(convs) == 1,
                              msg=f'expected exactly one Conv2D layer, got {len(convs)}')
    for layer in convs:
        # Two weights are expected per layer; per the message, the second one is the bias.
        self.unit_test.assertTrue(len(layer.weights) == 2, msg='fail Bias should appear in weights!!')


class FuncConvCollapsingTest(FourConv2DCollapsingTest):
    """
    Linear-collapsing test whose network is built from functional tf.nn conv ops
    instead of Conv2D layers.
    """

    def create_networks(self):
        """
        Build conv2d -> convolution -> relu -> convolution -> conv2d using tf.nn functions.

        Returns:
            A keras Model wrapping the functional ops.
        """
        # Covers functional-conv -> Conv2D substitution combined with linear collapsing.
        # The ops mostly rely on default arguments: in that case the tf layer does not carry
        # these attributes and they have to be added explicitly to the node's attributes
        # dict, a path the plain substitution test does not cover.
        height, width, channels = self.get_input_shapes()[0][1:]
        net_in = layers.Input(shape=(height, width, channels))
        out = tf.nn.conv2d(net_in, tf.random.uniform((3, 3, channels, 16)), 1, 'SAME')
        out = tf.nn.convolution(out, tf.random.uniform((1, 1, 16, 8)))
        out = tf.nn.relu(out)
        out = tf.nn.convolution(out, tf.random.uniform((3, 3, 8, 32)))
        out = tf.nn.conv2d(out, tf.random.uniform((1, 1, 32, 4)), 1, 'VALID')
        return tf.keras.models.Model(inputs=net_in, outputs=out)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        """
        Check that two Conv2D layers remain and that both models produce matching outputs.

        Args:
            quantized_model: model produced by the test run.
            float_model: the original float model.
            input_x: input fed to both models.
            quantization_info: unused here.
        """
        # Presumably each pair of convs on either side of the relu collapses into one Conv2D.
        conv_layers = list(filter(lambda l: isinstance(l, layers.Conv2D), quantized_model.layers))
        self.unit_test.assertTrue(len(conv_layers) == 2)

        float_out = float_model.predict(input_x)
        quant_out = quantized_model.predict(input_x)
        self.unit_test.assertTrue(float_out.shape == quant_out.shape, msg=f'out shape is not as expected!')
        similarity = cosine_similarity(float_out, quant_out)
        self.unit_test.assertTrue(np.isclose(similarity, 1), msg=f'fail cosine similarity check:{similarity}')


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
FuncConvCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
Expand Down Expand Up @@ -605,6 +606,7 @@ def test_linear_collapsing(self):
FourConv2DCollapsingTest(self).run_test()
SixConv2DCollapsingTest(self).run_test()
Op2DAddConstCollapsingTest(self).run_test()
FuncConvCollapsingTest(self).run_test()

def test_const_quantization(self):
c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)
Expand Down

0 comments on commit efd3310

Please sign in to comment.