fix tf.nn.{conv2d,convolution} substitution #1275

Merged · 2 commits · Nov 24, 2024
Changes from 1 commit
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
@@ -30,7 +31,7 @@
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
ACTIVATION, LINEAR
ACTIVATION, LINEAR, DATA_FORMAT, GROUPS


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,42 +137,68 @@ def substitute(self,
conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
if len(conv_func_node.op_call_args) > 0:
Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.") # pragma: no cover
if STRIDES in conv_func_node.op_call_kwargs:
strides = conv_func_node.op_call_kwargs[STRIDES]
if len(strides) == 4:
if strides[0] > 1 or strides[3] > 1:
# Non-standard strides -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[STRIDES] = strides[1:3]
else:
conv_fw_attr[STRIDES] = strides

strides = self._parse_tf_2d_kwarg(conv_func_node, STRIDES)
if strides is None:
# Non-standard strides -> skip substitution.
return graph
conv_fw_attr[STRIDES] = strides

if PADDING in conv_func_node.op_call_kwargs:
padding = conv_func_node.op_call_kwargs[PADDING]
if not isinstance(padding, str):
# Non-standard padding; the layer only supports 'valid' or 'same' -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[PADDING] = padding
if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
dilations = conv_func_node.op_call_kwargs[DILATIONS]
if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
if dilations[0] > 1 or dilations[3] > 1:
# Non-standard dilations -> skip substitution.
return graph # pragma: no cover
conv_fw_attr[DILATION_RATE] = dilations[1:3]
else:
conv_fw_attr[DILATION_RATE] = dilations

dilations = self._parse_tf_2d_kwarg(conv_func_node, DILATIONS)
if dilations is None:
# Non-standard dilations -> skip substitution.
return graph
conv_fw_attr[DILATION_RATE] = dilations

if b is None:
conv_fw_attr[USE_BIAS] = False
else:
weights[BIAS] = b

data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
conv_fw_attr['data_format'] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]
Collaborator: use constants from the keras constants


conv_fw_attr[GROUPS] = 1
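
A sketch of the reviewer's suggestion for the data_format lines above (an assumption, not part of this commit: it presumes the DATA_FORMAT constant imported from the keras constants module equals the literal 'data_format'):

data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')  # assumes DATA_FORMAT == 'data_format'
conv_fw_attr[DATA_FORMAT] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]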

_reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
weights, Conv2D, **_reuse_params)

replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
return graph

def _parse_tf_2d_kwarg(self, node, key) -> Optional[Tuple[int, int]]:
"""
Extract stride/dilation param from tf node and convert it to keras format (suitable for Conv2D).

Args:
node: functional conv node to parse.
key: kwarg name to look up (e.g. STRIDES or DILATIONS).

Returns:
Parsed value or None if non-standard.
"""
v = node.op_call_kwargs.get(key)
if v is None:
return 1, 1
Collaborator: this way you assume the defaults. why not return None?

Collaborator (Author): That's intentional. None wouldn't do; we need to fill in an explicit default. This method is specific to tf stride & dilation.

if isinstance(v, int):
return v, v
if len(v) == 1:
return v[0], v[0]
if len(v) == 4:
if v[0] > 1 or v[-1] > 1:
# Non-standard batch/channel value -> skip substitution.
return None
else:
return tuple(v[1:3])
return tuple(v)
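
For reference, a sketch of the mapping this helper is expected to implement (illustrative values only; None means the substitution is skipped):

# missing kwarg -> (1, 1)  explicit tf default; Conv2D needs a concrete value
# 2             -> (2, 2)
# [2]           -> (2, 2)
# [2, 3]        -> (2, 3)
# [1, 2, 2, 1]  -> (2, 2)  spatial dims of an NHWC 4-tuple
# [2, 1, 1, 1]  -> None    non-1 batch/channel entry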


class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
"""
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
# FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it
# looked at correct layers it hardly checks anything.

Collaborator: then why not remove it?
self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
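
One way to make the FIXME'd assertion meaningful (a sketch only, not part of this PR; expected_convs is a hypothetical per-test attribute each subclass would set):

quantized_convs = [l for l in quantized_model.layers
                   if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]
# expected_convs is a hypothetical attribute, shown for illustration
self.unit_test.assertTrue(len(quantized_convs) == self.expected_convs,
                          msg=f'expected {self.expected_convs} quantized convs, got {len(quantized_convs)}')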
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
def __init__(self, unit_test):
super().__init__(unit_test)
@@ -107,9 +110,44 @@ def create_networks(self):

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
for layer in quantized_model.layers:
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')
convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
self.unit_test.assertTrue(len(convs) == 1)
for layer in convs:
self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')


class FuncConv2DCollapsingTest(FourConv2DCollapsingTest):
Collaborator: These tests seem redundant, as there's no difference between these tests and the ones for the Conv2D layer. Why not just test the substitution of these layers to Conv2D?

def create_networks(self):
# tests the combination of functional conv to Conv2D substitution with linear collapsing
h, w, c = self.get_input_shapes()[0][1:]
inputs = layers.Input(shape=(h, w, c))
x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 128)), 1, 'VALID')
x = tf.nn.conv2d(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
x = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])
return tf.keras.models.Model(inputs=inputs, outputs=y)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
self.unit_test.assertTrue(len(convs) == 1)

y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')


class FuncConvolutionCollapsingTest(FuncConv2DCollapsingTest):
def create_networks(self):
# tests the combination of functional conv to Conv2D substitution with linear collapsing
h, w, c = self.get_input_shapes()[0][1:]
inputs = layers.Input(shape=(h, w, c))
x = tf.nn.convolution(inputs, tf.random.uniform((3, 3, c, 128)))
x = tf.nn.convolution(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
x = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
y = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='VALID', dilations=[1, 1])
return tf.keras.models.Model(inputs=inputs, outputs=y)


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
@@ -56,7 +56,8 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
FuncConv2DCollapsingTest, FuncConvolutionCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
@@ -605,6 +606,8 @@ def test_linear_collapsing(self):
FourConv2DCollapsingTest(self).run_test()
SixConv2DCollapsingTest(self).run_test()
Op2DAddConstCollapsingTest(self).run_test()
FuncConv2DCollapsingTest(self).run_test()
FuncConvolutionCollapsingTest(self).run_test()

def test_const_quantization(self):
c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)