-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CPU] Introduce FullyConnectedQuantized op and bias fusing
- Loading branch information
1 parent
993f117
commit 983cfbe
Showing
34 changed files
with
744 additions
and
132 deletions.
There are no files selected for viewing
46 changes: 46 additions & 0 deletions
46
src/common/transformations/include/ov_ops/fully_connected.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API FullyConnected : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("FullyConnected", "ie_internal_opset"); | ||
|
||
FullyConnected() = default; | ||
|
||
FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::Output<Node>& bias, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
ov::element::Type get_output_type() const { | ||
return m_output_type; | ||
} | ||
|
||
protected: | ||
ov::element::Type m_output_type; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
48 changes: 48 additions & 0 deletions
48
src/common/transformations/include/ov_ops/fully_connected_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "ov_ops/fully_connected.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API FullyConnectedQuantized : public ov::op::internal::FullyConnected { | ||
public: | ||
OPENVINO_OP("FullyConnectedQuantized", "gpu_opset"); | ||
|
||
FullyConnectedQuantized() = default; | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& dequantization_scales, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& dequantization_scales, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
ov::element::Type get_output_type() const { | ||
return m_output_type; | ||
} | ||
|
||
protected: | ||
ov::element::Type m_output_type; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API Placeholder : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("Placeholder", "ie_internal_opset"); | ||
|
||
Placeholder(); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
22 changes: 22 additions & 0 deletions
22
...common/transformations/include/transformations/op_conversions/convert_fc_to_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedQuantized; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
class ov::pass::ConvertFullyConnectedToFullyConnectedQuantized : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedQuantized", "0"); | ||
ConvertFullyConnectedToFullyConnectedQuantized(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/fully_connected.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "matmul_shape_inference.hpp" | ||
#include "ov_ops/placeholder.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
FullyConnected::FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::Output<Node>& bias, | ||
const ov::element::Type output_type) | ||
: Op({A, B, bias}), | ||
m_output_type(output_type) { | ||
validate_and_infer_types(); | ||
} | ||
|
||
FullyConnected::FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::element::Type output_type) | ||
: FullyConnected(A, B, std::make_shared<Placeholder>()) {} | ||
|
||
std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
|
||
return std::make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type); | ||
} | ||
|
||
void FullyConnected::validate_and_infer_types() { | ||
const auto input_size = get_input_size(); | ||
NODE_VALIDATION_CHECK(this, | ||
input_size >= 3, | ||
"Number of inputs is incorrect. Current value is: ", | ||
input_size, | ||
", expected at least 3."); | ||
|
||
ov::op::v0::MatMul op; | ||
op.set_transpose_a(false); | ||
op.set_transpose_b(true); | ||
|
||
auto out_shapes = | ||
ov::op::v0::shape_infer(&op, | ||
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)}); | ||
|
||
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; | ||
set_output_type(0, output_type, out_shapes[0]); | ||
} | ||
|
||
bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) { | ||
visitor.on_attribute("output_type", m_output_type); | ||
return true; | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
74 changes: 74 additions & 0 deletions
74
src/common/transformations/src/ov_ops/fully_connected_quantized.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/fully_connected_quantized.hpp" | ||
|
||
#include "matmul_shape_inference.hpp" | ||
#include "ov_ops/placeholder.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& dequantization_scales, | ||
const ov::element::Type output_type) | ||
: FullyConnected({X, W, bias}), | ||
m_output_type(output_type) { | ||
set_argument(3, dequantization_scales); | ||
validate_and_infer_types(); | ||
} | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& dequantization_scales, | ||
const ov::element::Type output_type) | ||
: FullyConnectedQuantized({X, W, std::make_shared<Placeholder>(), dequantization_scales}) { | ||
validate_and_infer_types(); | ||
} | ||
|
||
std::shared_ptr<ov::Node> FullyConnectedQuantized::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
|
||
if (new_args.size() == 4) { | ||
return std::make_shared<FullyConnectedQuantized>(new_args.at(0), | ||
new_args.at(1), | ||
new_args.at(2), | ||
new_args.at(3), | ||
m_output_type); | ||
} | ||
|
||
return std::make_shared<FullyConnectedQuantized>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type); | ||
} | ||
|
||
void FullyConnectedQuantized::validate_and_infer_types() { | ||
const auto input_size = get_input_size(); | ||
NODE_VALIDATION_CHECK(this, | ||
input_size >= 3, | ||
"Number of inputs is incorrect. Current value is: ", | ||
input_size, | ||
", expected at least 3."); | ||
|
||
ov::op::v0::MatMul op; | ||
op.set_transpose_a(false); | ||
op.set_transpose_b(true); | ||
|
||
auto out_shapes = | ||
ov::op::v0::shape_infer(&op, | ||
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)}); | ||
|
||
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; | ||
set_output_type(0, output_type, out_shapes[0]); | ||
} | ||
|
||
bool FullyConnectedQuantized::visit_attributes(ov::AttributeVisitor& visitor) { | ||
visitor.on_attribute("output_type", m_output_type); | ||
return true; | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/placeholder.hpp" | ||
|
||
#include "transformations/rt_info/fused_names_attribute.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
Placeholder::Placeholder() : ov::op::Op() { | ||
validate_and_infer_types(); | ||
set_friendly_name(get_name()); | ||
get_rt_info().emplace(FusedNames::get_type_info_static(), FusedNames{get_friendly_name()}); | ||
} | ||
|
||
bool Placeholder::visit_attributes(ov::AttributeVisitor& visitor) { | ||
return true; | ||
} | ||
|
||
void Placeholder::validate_and_infer_types() { | ||
set_output_type(0, ov::element::undefined, ov::PartialShape{}); | ||
} | ||
|
||
std::shared_ptr<Node> Placeholder::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
return std::make_shared<Placeholder>(); | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
75 changes: 75 additions & 0 deletions
75
src/common/transformations/src/transformations/op_conversions/convert_fc_to_quantized.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/convert_fc_to_quantized.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/pass/pattern/op/pattern.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "ov_ops/fully_connected.hpp" | ||
#include "ov_ops/fully_connected_quantized.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
ov::pass::ConvertFullyConnectedToFullyConnectedQuantized::ConvertFullyConnectedToFullyConnectedQuantized() { | ||
using namespace ov::pass::pattern; | ||
|
||
auto quantized_weights = [](const ov::Output<ov::Node>& output) { | ||
return output.get_element_type() == ov::element::i8; | ||
}; | ||
|
||
auto quantized_activations = [](const ov::Output<ov::Node>& output) { | ||
return output.get_element_type() == ov::element::u8 || output.get_element_type() == ov::element::i8; | ||
}; | ||
|
||
auto activations_m = pattern::any_input(quantized_activations); | ||
auto weights_m = wrap_type<ov::op::v0::Constant>(quantized_weights); | ||
auto bias_m = wrap_type<ov::op::v0::Constant>(); | ||
|
||
auto fully_connected_m = wrap_type<ov::op::internal::FullyConnected>({activations_m, weights_m, bias_m}); | ||
auto dequantization_scales_m = wrap_type<ov::op::v0::Constant>(); | ||
auto multiply_m = wrap_type<ov::op::v1::Multiply>({fully_connected_m, dequantization_scales_m}); | ||
|
||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
|
||
OPENVINO_ASSERT(pattern_map.count(fully_connected_m)); | ||
OPENVINO_ASSERT(pattern_map.count(multiply_m)); | ||
OPENVINO_ASSERT(pattern_map.count(dequantization_scales_m)); | ||
|
||
std::vector<std::shared_ptr<ov::Node>> result_nodes = {}; | ||
|
||
auto fc = std::dynamic_pointer_cast<ov::op::internal::FullyConnected>( | ||
pattern_map.at(fully_connected_m).get_node_shared_ptr()); | ||
auto activations = pattern_map.at(activations_m).get_node_shared_ptr(); | ||
auto weights = pattern_map.at(weights_m).get_node_shared_ptr(); | ||
auto bias = pattern_map.at(bias_m).get_node_shared_ptr(); | ||
auto multiply = pattern_map.at(multiply_m).get_node_shared_ptr(); | ||
auto dequantization_scales = pattern_map.at(dequantization_scales_m).get_node_shared_ptr(); | ||
const auto& fc_output_shape = fc->get_output_shape(0); | ||
const auto& multiply_output_shape = multiply->get_output_shape(0); | ||
|
||
if (fc_output_shape.back() != multiply_output_shape.back()) { | ||
return false; | ||
} | ||
|
||
auto fc_quantized = std::make_shared<ov::op::internal::FullyConnectedQuantized>(activations, | ||
weights, | ||
bias, | ||
dequantization_scales, | ||
fc->get_output_type()); | ||
|
||
// result_nodes.push_back(new_gather_node); | ||
fc_quantized->set_friendly_name(multiply->get_friendly_name()); | ||
ov::copy_runtime_info(m.get_matched_nodes(), result_nodes); | ||
ov::replace_node(multiply, fc_quantized); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(multiply_m, "ConvertFullyConnectedToFullyConnectedQuantized"); | ||
this->register_matcher(m, callback); | ||
} |
Oops, something went wrong.