fix tf.nn.{conv2d,convolution} substitution #1275
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from typing import Optional, Tuple

 import numpy as np
 import tensorflow as tf

@@ -30,7 +31,7 @@
 from model_compression_toolkit.constants import REUSE, REUSE_GROUP
 from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
     KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
-    ACTIVATION, LINEAR
+    ACTIVATION, LINEAR, DATA_FORMAT, GROUPS


 def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,42 +137,68 @@ def substitute(self,
         conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
         if len(conv_func_node.op_call_args) > 0:
             Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover
-        if STRIDES in conv_func_node.op_call_kwargs:
-            strides = conv_func_node.op_call_kwargs[STRIDES]
-            if len(strides) == 4:
-                if strides[0] > 1 or strides[3] > 1:
-                    # Non-standard strides -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[STRIDES] = strides[1:3]
-            else:
-                conv_fw_attr[STRIDES] = strides
+        strides = self._parse_tf_2d_kwarg(conv_func_node, STRIDES)
+        if strides is None:
+            # Non-standard strides -> skip substitution.
+            return graph
+        conv_fw_attr[STRIDES] = strides
+
         if PADDING in conv_func_node.op_call_kwargs:
             padding = conv_func_node.op_call_kwargs[PADDING]
             if not isinstance(padding, str):
                 # Non-standard padding; the layer only supports 'valid' or 'same' -> skip substitution.
                 return graph  # pragma: no cover
             conv_fw_attr[PADDING] = padding
-        if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
-            dilations = conv_func_node.op_call_kwargs[DILATIONS]
-            if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
-                if dilations[0] > 1 or dilations[3] > 1:
-                    # Non-standard dilations -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[DILATION_RATE] = dilations[1:3]
-            else:
-                conv_fw_attr[DILATION_RATE] = dilations
+        dilations = self._parse_tf_2d_kwarg(conv_func_node, DILATIONS)
+        if dilations is None:
+            # Non-standard dilations -> skip substitution.
+            return graph
+        conv_fw_attr[DILATION_RATE] = dilations
+
         if b is None:
             conv_fw_attr[USE_BIAS] = False
         else:
             weights[BIAS] = b
+
+        data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
+        conv_fw_attr['data_format'] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]
+
+        conv_fw_attr[GROUPS] = 1
+
         _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
         conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                              weights, Conv2D, **_reuse_params)

         replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
         return graph

+    def _parse_tf_2d_kwarg(self, node, key) -> Optional[Tuple[int, int]]:
+        """
+        Extract the strides/dilations param from a tf node and convert it to Keras format (suitable for Conv2D).
+
+        Args:
+            node: the functional node.
+            key: the param key to extract.
+
+        Returns:
+            The parsed value as a 2-tuple, or None if non-standard.
+        """
+        v = node.op_call_kwargs.get(key)
+        if v is None:
+            return 1, 1
Reviewer (on the `return 1, 1` default): this way you assume the defaults. why not return None?
Author: That's intentional. None wouldn't do, we need to fill in an explicit default. This method is specific to tf stride & dilation.
+        if isinstance(v, int):
+            return v, v
+        if len(v) == 1:
+            return v[0], v[0]
+        if len(v) == 4:
+            if v[0] > 1 or v[-1] > 1:
+                return None
+            return tuple(v[1:3])
+        return tuple(v)

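For reference, a standalone sketch of the normalization `_parse_tf_2d_kwarg` performs on TF-style strides/dilations arguments (the function name and the asserts are illustrative, not part of the PR):

from typing import Optional, Sequence, Tuple, Union

def normalize_tf_2d_arg(v: Union[None, int, Sequence[int]]) -> Optional[Tuple[int, int]]:
    # Normalize a tf.nn.conv2d / tf.nn.convolution strides/dilations argument
    # to a Keras-style (height, width) pair; None means Conv2D can't express it.
    if v is None:
        return 1, 1        # TF default: 1 in every dimension
    if isinstance(v, int):
        return v, v        # a scalar applies to both spatial dims
    if len(v) == 1:
        return v[0], v[0]  # a length-1 list behaves like a scalar
    if len(v) == 4:        # full NHWC spec: [batch, height, width, channel]
        if v[0] > 1 or v[-1] > 1:
            return None    # batch/channel stride is not representable in Conv2D
        return v[1], v[2]
    return tuple(v)        # already a (height, width) pair

assert normalize_tf_2d_arg(None) == (1, 1)
assert normalize_tf_2d_arg(2) == (2, 2)
assert normalize_tf_2d_arg([1, 2, 2, 1]) == (2, 2)
assert normalize_tf_2d_arg([2, 1, 1, 1]) is None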

 class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
     """
In the tests file:
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         y = float_model.predict(input_x)
         y_hat = quantized_model.predict(input_x)
         self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        # FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it
+        # looked at correct layers it hardly checks anything.
         self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
         cs = cosine_similarity(y, y_hat)
         self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')

Reviewer (on the FIXME): then why not remove it?
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
             if type(layer) == layers.Conv2D:
                 self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


 class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
     def __init__(self, unit_test):
         super().__init__(unit_test)
@@ -107,9 +110,44 @@ def create_networks(self):

     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
         super().compare(quantized_model, float_model, input_x, quantization_info)
-        for layer in quantized_model.layers:
-            if type(layer) == layers.Conv2D:
-                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
+        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
+        self.unit_test.assertTrue(len(convs) == 1)
+        for layer in convs:
+            self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')

+class FuncConv2DCollapsingTest(FourConv2DCollapsingTest):

Reviewer: These tests seem redundant, as there's no difference between them and the ones for the Conv2D layer. Why not just test the substitution of these layers to Conv2D?
+    def create_networks(self):
+        # tests the combination of functional conv to Conv2D substitution with linear collapsing
+        h, w, c = self.get_input_shapes()[0][1:]
+        inputs = layers.Input(shape=(h, w, c))
+        x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 128)), 1, 'VALID')
+        x = tf.nn.conv2d(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
+        x = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
+        y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])
+        return tf.keras.models.Model(inputs=inputs, outputs=y)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
+        self.unit_test.assertTrue(len(convs) == 1)
+
+        y = float_model.predict(input_x)
+        y_hat = quantized_model.predict(input_x)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        cs = cosine_similarity(y, y_hat)
+        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')

+class FuncConvolutionCollapsingTest(FuncConv2DCollapsingTest):
+    def create_networks(self):
+        # tests the combination of functional conv to Conv2D substitution with linear collapsing
+        h, w, c = self.get_input_shapes()[0][1:]
+        inputs = layers.Input(shape=(h, w, c))
+        x = tf.nn.convolution(inputs, tf.random.uniform((3, 3, c, 128)))
+        x = tf.nn.convolution(x, filters=tf.random.uniform((1, 1, 128, 64)), strides=[1], padding='SAME', dilations=1)
+        x = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 64)), strides=[1, 1], padding='VALID', dilations=[1])
+        y = tf.nn.convolution(x, tf.random.uniform((1, 1, 64, 4)), strides=[1, 1], padding='VALID', dilations=[1, 1])
+        return tf.keras.models.Model(inputs=inputs, outputs=y)


 class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
Reviewer: use constants from the Keras constants module.
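A minimal, self-contained sketch of what that suggestion could look like for the data_format/groups attributes, assuming the DATA_FORMAT and GROUPS constants this PR imports hold the strings shown below (the stub dicts stand in for the real node kwargs and are illustrative only):

# Hypothetical illustration: key conv_fw_attr by the shared constants instead
# of bare string literals like 'data_format'.
DATA_FORMAT = 'data_format'   # assumed value of the imported constant
GROUPS = 'groups'             # assumed value of the imported constant

op_call_kwargs = {DATA_FORMAT: 'NHWC'}   # stand-in for conv_func_node.op_call_kwargs
conv_fw_attr = {}

# Read the TF data_format kwarg and translate it to the Keras spelling.
data_format = op_call_kwargs.get(DATA_FORMAT, 'NHWC')
conv_fw_attr[DATA_FORMAT] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]
conv_fw_attr[GROUPS] = 1

assert conv_fw_attr == {'data_format': 'channels_last', 'groups': 1}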