Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MarkDequantization transformation #27406

Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
dffcd9b
Make KeepConstPrecision attribute copyable
itikhono Nov 5, 2024
1a50a8d
Merge remote-tracking branch 'upstream/master' into itikhono/bug_fix/…
itikhono Nov 8, 2024
52e2e30
Merge remote-tracking branch 'upstream/master' into itikhono/bug_fix/…
itikhono Nov 12, 2024
2143e7e
Merge remote-tracking branch 'upstream/master' into itikhono/bug_fix/…
itikhono Nov 13, 2024
7555a9e
update mark dequantization transformation
itikhono Nov 13, 2024
5f237eb
add transformation callback
itikhono Nov 13, 2024
0dfabd9
try to fix a warning
itikhono Nov 13, 2024
865fd4c
revert KeepConstPrecision change
itikhono Nov 13, 2024
7cd9c24
Merge branch 'itikhono/bug_fix/keep_const_precision_attr' of https://…
itikhono Nov 14, 2024
fa6b0ec
align the current behavior with the previous implementation
itikhono Nov 15, 2024
1202ab2
fix tests
itikhono Nov 15, 2024
c93f7a7
add precision check
itikhono Nov 17, 2024
e7d7c5f
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Nov 17, 2024
189153f
fix issue on gpu, docs, refactoring
itikhono Nov 19, 2024
48b5694
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Nov 19, 2024
06f1c22
remove the dq model pass, leave the separate matchers only
itikhono Nov 19, 2024
1c7a72e
fix Opattern::op::Or logic
itikhono Nov 21, 2024
416d610
fixed the marking on gpu
itikhono Nov 21, 2024
7f2e1de
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Nov 26, 2024
553f2b6
resolve review comments
itikhono Nov 30, 2024
570e84f
Merge branch 'itikhono/bug_fix/keep_const_precision_attr' of
itikhono Nov 30, 2024
6853a63
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Nov 30, 2024
d2786b7
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Dec 11, 2024
3944b9e
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Dec 13, 2024
6d6749b
codestyle
itikhono Dec 13, 2024
ad10166
Merge branch 'master' into itikhono/bug_fix/keep_const_precision_attr
itikhono Dec 16, 2024
b4b3513
fix a warning
itikhono Dec 16, 2024
1fc379b
Merge branch 'itikhono/bug_fix/keep_const_precision_attr' of https://…
itikhono Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ auto defaultPrecisions =
useLpt ? ov::pass::low_precision::precision_set::get_int8_support() : std::vector<ov::element::Type>{};
if (useLpt) {
// disable constant folding on dequantization subgraphs so they can be processed by LPT
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(defaultPrecisions);
manager.register_pass<ov::pass::MarkDequantization>(defaultPrecisions);
}

// OpenVINO common transformations happen here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,44 @@
#include "transformations/rt_info/keep_const_precision.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "transformations/convert_precision.hpp"

using namespace ov;

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
TEST_F(TransformationTestsF, KeepConstPrecision) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also compare nodes' rt_info in this test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done
but it looks like the rt_info comparison doesn't work, the same with other tests, I will double check

{
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{27}, 1);

const auto target_shape = std::make_shared<opset10::Constant>(ov::element::i64, ov::Shape{3}, 3);
auto reshape = std::make_shared<opset10::Reshape>(lp_const, target_shape, false);

auto second_convert = std::make_shared<opset10::Convert>(reshape, element::f32);
auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127});
auto subtract = std::make_shared<opset10::Subtract>(second_convert, zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
auto stub_op = std::make_shared<opset10::Relu>(multiply);
model = std::make_shared<Model>(stub_op, ParameterVector{});
}

manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u4});
manager.register_pass<pass::ConstantFolding>();
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u4});
manager.register_pass<pass::ConvertPrecision>(ov::element::u4, ov::element::u8, type_to_fuse_map{}, false, false);

{
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{3, 3, 3}, 1);
auto second_convert = std::make_shared<opset10::Convert>(lp_const, element::f32);
auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127});
auto subtract = std::make_shared<opset10::Subtract>(second_convert, zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
auto stub_op = std::make_shared<opset10::Relu>(multiply);
model_ref = std::make_shared<Model>(stub_op, ParameterVector{});
}
}

TEST_F(TransformationTestsF, MarkDequantizationTransformation) {
// Input graph:
//
// Parameter
Expand All @@ -37,7 +71,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// All 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero points are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -82,7 +116,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -138,7 +173,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPoint) {
// Input graph:
//
// Parameter
Expand All @@ -158,7 +193,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -197,7 +232,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -242,7 +278,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPointFP16) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPointFP16) {
// Input graph:
//
// Parameter
Expand All @@ -262,7 +298,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -305,9 +341,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::DisableDecompressionConvertConstantFolding>();
manager.register_pass<pass::ConstantFolding>();
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});

{
auto parameter = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
Expand Down Expand Up @@ -355,7 +390,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstantWeights) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNotConstantWeights) {
// Input graph:
//
// Parameter
Expand All @@ -378,7 +413,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero point nodes are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -426,7 +461,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -481,7 +517,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubConst) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationFoldSubConst) {
// Input graph: After transformation:
//
// Constant Constant Constant
Expand All @@ -495,7 +531,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons
// | / \ /
// Multiply Multiply
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' node before weights is marked with 'DisableConstantFolding' attribute
// but Convert before Dequantization Sub const isn't because fold_subtract_const is set to true
Expand All @@ -512,7 +548,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons
model = std::make_shared<ov::Model>(ov::OutputVector{multiply});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::ConstantFolding>();

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,77 @@

#pragma once

#include <utility>

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
/**
* @ingroup ov_transformation_common_api
*
* @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes
* with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked
* with the disable_const_folding attribute.
*
* If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze
* nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding.
*
* Dequantization subgraph may have two forms: with and without Subtract.
* ZeroPoints and Scale might be present as subgraphs and include Convert ops.
*
* Input ZeroPoints
* │ │
* ▼ ▼
* Convert (opt) Reshape/Unsqueeze
* │ │
* ▼ ▼ Scale Input Scale
* Subtract │ │ │
* │ ▼ ▼ ▼
* │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantization", "0");
explicit MarkDequantization(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

/**
* @ingroup ov_transformation_common_api
* @brief MarkDequantizationSubgraph marks dequantization subgraph, that is:
* Convert->Subtract(optional)->Multiply
* in two ways:
* - first Convert is marked with DisableConstantFolding attribute, also if Subtract is present
* and its second input is a Convert - that Convert is marked with DisableConstantFolding as well,
* - Subtract and Multiply are marked with 'DequantizationNode' attribute
*
* @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants
* they might be marked with keep_const_precision attribute.
*
* Dequantization subgraph may have two forms: with and without Subtract.
*
* Input
* │
* ▼
* Convert ZeroPoints
* │ │
* ▼ ▼ Input
* Subtract │
* │ ▼
* │ Scale Convert Scale
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass {
class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantizationSubgraph", "0");
MarkDequantizationSubgraph(const element::TypeVector& precisions,
const bool fold_subtract_const = false,
const bool disable_fold_multiply_const = false);
OPENVINO_RTTI("KeepConstsPrecision", "0");
explicit KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,27 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
manager.set_per_pass_validation(false);
using namespace ov::pass;
REGISTER_PASS(manager, InitNodeInfo)
if (m_low_precision_enabled) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}
REGISTER_PASS(manager, EliminateConvert)
itikhono marked this conversation as resolved.
Show resolved Hide resolved
if (!m_use_shapes) {
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
}

if (m_low_precision_enabled) {
manager.register_pass<ov::pass::MarkDequantization>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}

// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults
// should be performed before first ConstantFolding call.
// should be performed before first !ConstantFolding! call.
// The passes can deteach graph branches where zero dimesion is calculated.
// Zero dimensions in shape causes creation empty tensors, which are incorrect during CF.
// In particular, if zero dim tensor is consumed in body of MultiSubGraphOp
// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults should be called together.
using namespace ov::pass;
REGISTER_PASS(manager, EliminateConvert)
REGISTER_PASS(manager, EliminateScatterUpdate)
REGISTER_PASS(manager, RemoveConcatZeroDimInput)
REGISTER_PASS(manager, EliminateLoopInputsOutputs);
REGISTER_PASS(manager, Validate)

// todo: ticket 96960
// the order EliminateDuplicateTIInputs and RemoveMultiSubGraphOpDanglingParamsResults is important
// it looks like we need to combine these transformations into one.
Expand Down
Loading
Loading