Skip to content

Commit

Permalink
Update MarkDequantization transformation (openvinotoolkit#27406)
Browse files Browse the repository at this point in the history
### Details:

**Original issue description:**

> Inside MarkDequantizationSubgraph, we try to find the dequantization
subgraph and mark it with disable const folding and keep const precision
attributes.
> The dequantization subgraph has a relaxed structure and allows
"any_input" as the input of Convert operation.
> 
> from MarkDequantizationSubgraph:
> <img width="235" alt="{A09C9C08-03C1-4A48-8EAD-60D7EF76A4B5}"
src="https://github.com/user-attachments/assets/d07063c8-6872-4946-ae9a-397eb2c41d87">
> 
> in the current case this is
>  (any input: Constant->Reshape) -> (necessary Convert) 
> 
> Constant->Reshape will be ConstFolded. ConstFolding doesn't run data
copying in case of Reshape operation, so this is valid scenario.
> 
> MarkDequantizationSubgraph transformation marks Reshape op with
"KeepConstPrecision" attribute.
> But after ConstFolding, KeepConstPrecision attr won't be copied to
resulting constant because of
> <img width="148" alt="{FE3E166D-17BD-410F-A08F-4502FB9BB3D0}"
src="https://github.com/user-attachments/assets/3dac34b9-9a79-4359-b4f8-87e69633ffb8">
> and the whole dequantization subgraph will be const folded.

**Changes:**

- MarkDequantizationSubgraph logic was split into 2 transformations:
MarkDequantization, KeepConstPrecision


### Tickets:
 - *CVS-156576*
 - *CVS-156329*
  • Loading branch information
itikhono authored and 11happy committed Dec 23, 2024
1 parent 8771157 commit 8658129
Show file tree
Hide file tree
Showing 9 changed files with 299 additions and 130 deletions.
2 changes: 1 addition & 1 deletion docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ auto defaultPrecisions =
useLpt ? ov::pass::low_precision::precision_set::get_int8_support() : std::vector<ov::element::Type>{};
if (useLpt) {
// disable constant folding on dequantization subgraphs so they can be processed by LPT
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(defaultPrecisions);
manager.register_pass<ov::pass::MarkDequantization>(defaultPrecisions);
}

// OpenVINO common transformations happen here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,51 @@
#include "transformations/rt_info/keep_const_precision.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "transformations/convert_precision.hpp"

using namespace ov;

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
TEST_F(TransformationTestsF, KeepConstPrecision) {
{
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{27}, 1);

const auto target_shape = std::make_shared<opset10::Constant>(ov::element::i64, ov::Shape{3}, 3);
auto reshape = std::make_shared<opset10::Reshape>(lp_const, target_shape, false);

auto second_convert = std::make_shared<opset10::Convert>(reshape, element::f32);
auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127});
auto subtract = std::make_shared<opset10::Subtract>(second_convert, zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
auto stub_op = std::make_shared<opset10::Relu>(multiply);
model = std::make_shared<Model>(stub_op, ParameterVector{});
}

manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u4});
manager.register_pass<pass::ConstantFolding>();
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u4});
manager.register_pass<pass::ConvertPrecision>(ov::element::u4, ov::element::u8, type_to_fuse_map{}, false, false);

{
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{3, 3, 3}, 1);
auto second_convert = std::make_shared<opset10::Convert>(lp_const, element::f32);
auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127});
auto subtract = std::make_shared<opset10::Subtract>(second_convert, zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
auto stub_op = std::make_shared<opset10::Relu>(multiply);
model_ref = std::make_shared<Model>(stub_op, ParameterVector{});

mark_as_dequantization_node(subtract);
mark_as_dequantization_node(multiply);
enable_keep_const_precision(lp_const);
ov::pass::disable_constant_folding(second_convert);
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationTransformation) {
// Input graph:
//
// Parameter
Expand All @@ -37,7 +78,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// All 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero points are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -82,7 +123,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -138,7 +180,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) {
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPoint) {
// Input graph:
//
// Parameter
Expand All @@ -158,7 +200,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -197,7 +239,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -242,7 +285,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPointFP16) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPointFP16) {
// Input graph:
//
// Parameter
Expand All @@ -262,7 +305,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -305,9 +348,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::DisableDecompressionConvertConstantFolding>();
manager.register_pass<pass::ConstantFolding>();
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});

{
auto parameter = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
Expand Down Expand Up @@ -355,7 +397,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstantWeights) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNotConstantWeights) {
// Input graph:
//
// Parameter
Expand All @@ -378,7 +420,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
// \ /
// Convolution
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero point nodes are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -426,7 +468,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -481,7 +524,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubConst) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationFoldSubConst) {
// Input graph: After transformation:
//
// Constant Constant Constant
Expand All @@ -495,7 +538,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons
// | / \ /
// Multiply Multiply
//
// After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' node before weights is marked with 'DisableConstantFolding' attribute
// but Convert before Dequantization Sub const isn't because fold_subtract_const is set to true
Expand All @@ -512,7 +555,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons
model = std::make_shared<ov::Model>(ov::OutputVector{multiply});
}

manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::ConstantFolding>();

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,77 @@

#pragma once

#include <utility>

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
/**
* @ingroup ov_transformation_common_api
*
* @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes
* with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked
* with the disable_const_folding attribute.
*
* If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze
* nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding.
*
* Dequantization subgraph may have two forms: with and without Subtract.
* ZeroPoints and Scale might be present as subgraphs and include Convert ops.
*
* Input ZeroPoints
* │ │
* ▼ ▼
* Convert (opt) Reshape/Unsqueeze
* │ │
* ▼ ▼ Scale Input Scale
* Subtract │ │ │
* │ ▼ ▼ ▼
* │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantization", "0");
explicit MarkDequantization(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

/**
* @ingroup ov_transformation_common_api
* @brief MarkDequantizationSubgraph marks dequantization subgraph, that is:
* Convert->Subtract(optional)->Multiply
* in two ways:
* - first Convert is marked with DisableConstantFolding attribute, also if Subtract is present
* and its second input is a Convert - that Convert is marked with DisableConstantFolding as well,
* - Subtract and Multiply are marked with 'DequantizationNode' attribute
*
* @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants
* they might be marked with keep_const_precision attribute.
*
* Dequantization subgraph may have two forms: with and without Subtract.
*
* Input
* │
* ▼
* Convert ZeroPoints
* │ │
* ▼ ▼ Input
* Subtract │
* │ ▼
* │ Scale Convert Scale
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass {
class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantizationSubgraph", "0");
MarkDequantizationSubgraph(const element::TypeVector& precisions,
const bool fold_subtract_const = false,
const bool disable_fold_multiply_const = false);
OPENVINO_RTTI("KeepConstsPrecision", "0");
explicit KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
using namespace ov::pass;
REGISTER_PASS(manager, InitNodeInfo)
if (m_low_precision_enabled) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
manager.register_pass<ov::pass::MarkDequantization>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}
if (!m_use_shapes) {
Expand Down
Loading

0 comments on commit 8658129

Please sign in to comment.