fix tf.nn.{conv2d,convolution} substitution #1275
Conversation
Add PR description
if b is None:
    conv_fw_attr[USE_BIAS] = False
else:
    weights[BIAS] = b

data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
conv_fw_attr['data_format'] = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}[data_format]
Use the constants from the Keras constants module instead of inline string literals.
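The data_format mapping above could be written with named constants, as the reviewer suggests. A minimal standalone sketch (the constant names below are illustrative, not the project's actual Keras constants module):

```python
# Illustrative constant names; the project's actual constants module may differ.
DATA_FORMAT = 'data_format'
CHANNELS_LAST = 'channels_last'
CHANNELS_FIRST = 'channels_first'

# tf-style layout string -> Keras data_format value
TF_TO_KERAS_DATA_FORMAT = {
    'NHWC': CHANNELS_LAST,
    'NCHW': CHANNELS_FIRST,
}

def convert_data_format(tf_data_format: str = 'NHWC') -> str:
    """Map a tf data_format string ('NHWC'/'NCHW') to its Keras equivalent."""
    try:
        return TF_TO_KERAS_DATA_FORMAT[tf_data_format]
    except KeyError:
        raise ValueError(f'Unsupported data_format: {tf_data_format}')
```

Keeping the mapping in one named dict means any future layout string only needs to be added in a single place.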
""" | ||
v = node.op_call_kwargs.get(key) | ||
if v is None: | ||
return 1, 1 |
This way you assume the defaults. Why not return None?
That's intentional. None wouldn't do; we need to fill in an explicit default. This method is specific to tf stride & dilation.
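The design choice explained in this reply could be sketched as a standalone helper. This is a hypothetical version (name and exact signature are illustrative), assuming tf's convention that a missing strides/dilations kwarg means 1 in every spatial dimension:

```python
def normalize_stride_or_dilation(op_call_kwargs: dict, key: str) -> tuple:
    """Return an explicit (h, w) pair for a tf stride/dilation kwarg.

    tf.nn.conv2d accepts an int, an (h, w) pair, or a length-4 list
    (one value per NHWC dimension). When the kwarg is absent, tf uses 1
    in every dimension, so we return the explicit default (1, 1) rather
    than None: the downstream Keras layer needs a concrete value.
    """
    v = op_call_kwargs.get(key)
    if v is None:
        return 1, 1                # tf's implicit default, made explicit
    if isinstance(v, int):
        return v, v                # scalar applies to both spatial dims
    if len(v) == 4:                # NHWC-style [1, h, w, 1]
        return v[1], v[2]
    return tuple(v)                # already an (h, w) pair
```

Returning None here would just push the "fill in the default" decision to every caller, which is why the explicit (1, 1) is preferable.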
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
# FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it
Then why not remove it?
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class FuncConv2DCollapsingTest(FourConv2DCollapsingTest):
These tests seem redundant, as there's no difference between them and the ones for the Conv2D layer. Why not just test the substitution of these layers to Conv2D?
Pull Request Description:
Fixed the tf conv substitution to handle attributes with default values that were not passed explicitly, and to convert tf-specific formats to Keras-compatible values.
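Putting the pieces discussed in the review together, the fix could look roughly like the following sketch (function and key names are hypothetical, not the PR's actual code): collect the tf.nn.conv2d call kwargs into Keras-compatible Conv2D attributes, filling in tf's implicit defaults explicitly.

```python
def build_keras_conv_attrs(op_call_kwargs: dict, bias=None) -> dict:
    """Sketch: translate tf.nn.conv2d kwargs into Keras Conv2D attributes."""
    def _pair(key):
        # Normalize an absent/int/4-long stride or dilation kwarg to (h, w).
        v = op_call_kwargs.get(key)
        if v is None:
            return 1, 1            # tf's implicit default, made explicit
        if isinstance(v, int):
            return v, v
        return (v[1], v[2]) if len(v) == 4 else tuple(v)

    data_format = op_call_kwargs.get('data_format', 'NHWC')
    return {
        'use_bias': bias is not None,  # bias kwarg absent -> no bias term
        'data_format': {'NHWC': 'channels_last',
                        'NCHW': 'channels_first'}[data_format],
        'strides': _pair('strides'),
        'dilation_rate': _pair('dilations'),
    }
```

The key point of the fix is visible here: every attribute the tf call left implicit (data_format, strides, dilations, bias) is materialized into an explicit, Keras-compatible value before the substitution builds the Conv2D layer.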
Checklist before requesting a review: