fix tf.nn.{conv2d,convolution} substitution
irenaby committed Nov 24, 2024
1 parent 432ae5b commit e4adf42
Showing 3 changed files with 91 additions and 23 deletions.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
@@ -30,7 +31,7 @@
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
    KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
    ACTIVATION, LINEAR
    ACTIVATION, LINEAR, DATA_FORMAT, GROUPS


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,42 +137,68 @@ def substitute(self,
        conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
        if len(conv_func_node.op_call_args) > 0:
            Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")    # pragma: no cover
        if STRIDES in conv_func_node.op_call_kwargs:
            strides = conv_func_node.op_call_kwargs[STRIDES]
            if len(strides) == 4:
                if strides[0] > 1 or strides[3] > 1:
                    # Non-standard strides -> skip substitution.
                    return graph    # pragma: no cover
                conv_fw_attr[STRIDES] = strides[1:3]
            else:
                conv_fw_attr[STRIDES] = strides

        strides = self._parse_tf_2d_kwarg(conv_func_node, STRIDES)
        if strides is None:
            # Non-standard strides -> skip substitution.
            return graph
        conv_fw_attr[STRIDES] = strides

        if PADDING in conv_func_node.op_call_kwargs:
            padding = conv_func_node.op_call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Non-standard padding; the layer only supports 'valid' or 'same' -> skip substitution.
                return graph    # pragma: no cover
            conv_fw_attr[PADDING] = padding
        if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
            dilations = conv_func_node.op_call_kwargs[DILATIONS]
            if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
                if dilations[0] > 1 or dilations[3] > 1:
                    # Non-standard dilations -> skip substitution.
                    return graph    # pragma: no cover
                conv_fw_attr[DILATION_RATE] = dilations[1:3]
            else:
                conv_fw_attr[DILATION_RATE] = dilations

        dilations = self._parse_tf_2d_kwarg(conv_func_node, DILATIONS)
        if dilations is None:
            # Non-standard dilations -> skip substitution.
            return graph
        conv_fw_attr[DILATION_RATE] = dilations

        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

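        # tf.nn convs spell the layout 'NHWC'/'NCHW'; keras Conv2D expects 'channels_last'/'channels_first'.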
        data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
        conv_fw_attr['data_format'] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]

        conv_fw_attr[GROUPS] = 1

        _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
        conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                             weights, Conv2D, **_reuse_params)

        replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
        return graph

    def _parse_tf_2d_kwarg(self, node, key) -> Optional[Tuple[int, int]]:
        """
        Extract the strides/dilations kwarg from a tf functional node and convert it to the
        2-tuple format expected by keras Conv2D.

        Args:
            node: functional conv node.
            key: kwarg key (STRIDES or DILATIONS).

        Returns:
            The parsed (height, width) value, or None if the value cannot be expressed by
            Conv2D (i.e. a stride/dilation over the batch or channel dimension).
        """
        v = node.op_call_kwargs.get(key)
        if v is None:
            return 1, 1
        if isinstance(v, int):
            return v, v
        if len(v) == 1:
            return v[0], v[0]
        if len(v) == 4:
            if v[0] > 1 or v[-1] > 1:
                return None
            return tuple(v[1:3])
        return tuple(v)
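
For illustration, this is how that normalization behaves on the kwarg spellings tf.nn.conv2d accepts (a standalone sketch replicating the helper above, not the toolkit's code; a plain dict stands in for the node's op_call_kwargs):

from typing import Optional, Tuple

def parse_tf_2d_kwarg(kwargs: dict, key: str) -> Optional[Tuple[int, int]]:
    # Standalone replica of _parse_tf_2d_kwarg, for illustration only.
    v = kwargs.get(key)
    if v is None:
        return 1, 1
    if isinstance(v, int):
        return v, v
    if len(v) == 1:
        return v[0], v[0]
    if len(v) == 4:
        if v[0] > 1 or v[-1] > 1:
            return None    # batch/channel stride or dilation -> Conv2D can't express it
        return tuple(v[1:3])
    return tuple(v)

assert parse_tf_2d_kwarg({}, 'strides') == (1, 1)                          # kwarg omitted
assert parse_tf_2d_kwarg({'strides': 2}, 'strides') == (2, 2)              # scalar
assert parse_tf_2d_kwarg({'strides': [2]}, 'strides') == (2, 2)            # length-1 list
assert parse_tf_2d_kwarg({'strides': [1, 2, 2, 1]}, 'strides') == (2, 2)   # NHWC, spatial kept
assert parse_tf_2d_kwarg({'strides': [2, 1, 1, 1]}, 'strides') is None     # batch stride -> skip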


class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
    """
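For context, the substitution above rewrites a functional tf.nn.conv2d call as an equivalent keras Conv2D layer, dropping the batch/channel entries of the NHWC strides and dilations. A minimal end-to-end sketch of the intended equivalence (standalone and illustrative, not the toolkit's code):

import numpy as np
import tensorflow as tf

x = tf.random.uniform((1, 8, 8, 3))
k = tf.random.uniform((3, 3, 3, 16))

# Functional form: NHWC strides carry batch and channel entries.
y_func = tf.nn.conv2d(x, k, strides=[1, 2, 2, 1], padding='SAME')

# Layer form: Conv2D takes only the spatial (H, W) entries.
conv = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(2, 2),
                              padding='same', use_bias=False)
conv.build(x.shape)
conv.set_weights([k.numpy()])
y_layer = conv(x)

assert np.allclose(y_func.numpy(), y_layer.numpy(), atol=1e-5)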
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
        y = float_model.predict(input_x)
        y_hat = quantized_model.predict(input_x)
        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
        # FIXME this doesn't test anything: the number of quantized convs in the network is exactly 0,
        #  and even if it looked at the correct layers it would hardly check anything.
        self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
        cs = cosine_similarity(y, y_hat)
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
            if type(layer) == layers.Conv2D:
                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
    def __init__(self, unit_test):
        super().__init__(unit_test)
@@ -107,9 +110,44 @@ def create_networks(self):

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        super().compare(quantized_model, float_model, input_x, quantization_info)
        for layer in quantized_model.layers:
            if type(layer) == layers.Conv2D:
                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
        self.unit_test.assertTrue(len(convs) == 1)
        for layer in convs:
            self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')

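For intuition, the len(convs) == 1 check holds because linear collapsing folds a chain of convs with no activation in between into a single conv; for 1x1 kernels the composed kernel is just a matmul over the channel axis. A minimal sketch of that identity, separate from the toolkit's implementation:

import numpy as np
import tensorflow as tf

x = tf.random.uniform((1, 8, 8, 3))
w1 = tf.random.uniform((1, 1, 3, 16))
w2 = tf.random.uniform((1, 1, 16, 4))

# Two bias-free 1x1 convs back to back...
y_two = tf.nn.conv2d(tf.nn.conv2d(x, w1, 1, 'VALID'), w2, 1, 'VALID')

# ...equal one conv whose kernel composes the two (a 1x1 conv is a per-pixel matmul).
w = tf.einsum('...ij,...jk->...ik', w1, w2)    # shape (1, 1, 3, 4)
y_one = tf.nn.conv2d(x, w, 1, 'VALID')

assert np.allclose(y_two.numpy(), y_one.numpy(), atol=1e-4)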

class FuncConv2DCollapsingTest(FourConv2DCollapsingTest):
    def create_networks(self):
        # tests the combination of functional conv to Conv2D substitution with linear collapsing
        h, w, c = self.get_input_shapes()[0][1:]
        inputs = layers.Input(shape=(h, w, c))
        x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 128)), 1, 'VALID')
        x = tf.nn.conv2d(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
        y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])
        return tf.keras.models.Model(inputs=inputs, outputs=y)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
        self.unit_test.assertTrue(len(convs) == 1)

        y = float_model.predict(input_x)
        y_hat = quantized_model.predict(input_x)
        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
        cs = cosine_similarity(y, y_hat)
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')


class FuncConvolutionCollapsingTest(FuncConv2DCollapsingTest):
    def create_networks(self):
        # tests the combination of functional conv to Conv2D substitution with linear collapsing
        h, w, c = self.get_input_shapes()[0][1:]
        inputs = layers.Input(shape=(h, w, c))
        x = tf.nn.convolution(inputs, tf.random.uniform((3, 3, c, 128)))
        x = tf.nn.convolution(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
        y = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='VALID', dilations=[1, 1])
        return tf.keras.models.Model(inputs=inputs, outputs=y)
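
Note that the first tf.nn.convolution call above passes neither strides nor padding; the API defaults are strides=None and padding='VALID', so this path exercises the helper's missing-kwarg branch, which yields strides of (1, 1).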


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
@@ -56,7 +56,8 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
    InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
    FuncConv2DCollapsingTest, FuncConvolutionCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
    LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
@@ -605,6 +606,8 @@ def test_linear_collapsing(self):
        FourConv2DCollapsingTest(self).run_test()
        SixConv2DCollapsingTest(self).run_test()
        Op2DAddConstCollapsingTest(self).run_test()
        FuncConv2DCollapsingTest(self).run_test()
        FuncConvolutionCollapsingTest(self).run_test()

    def test_const_quantization(self):
        c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)
