-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CPU] Introduce FullyConnectedQuantized op and bias fusing
- Loading branch information
1 parent
9969f9f
commit 24e0a1b
Showing
40 changed files
with
975 additions
and
190 deletions.
There are no files selected for viewing
48 changes: 48 additions & 0 deletions
48
src/common/transformations/include/ov_ops/fully_connected.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API FullyConnected : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("FullyConnected", "ie_internal_opset"); | ||
|
||
FullyConnected() = default; | ||
|
||
FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::Output<Node>& bias, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
virtual std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const; | ||
|
||
ov::element::Type get_output_type() const { | ||
return m_output_type; | ||
} | ||
|
||
protected: | ||
ov::element::Type m_output_type; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
68 changes: 68 additions & 0 deletions
68
src/common/transformations/include/ov_ops/fully_connected_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "ov_ops/fully_connected.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API FullyConnectedQuantized : public ov::op::internal::FullyConnected { | ||
public: | ||
OPENVINO_OP("FullyConnectedQuantized", "gpu_opset"); | ||
|
||
FullyConnectedQuantized() = default; | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& weight_scales, | ||
const ov::Output<Node>& weight_zero_points, | ||
const ov::Output<Node>& input_scales, | ||
const ov::Output<Node>& input_zero_points, | ||
const ov::Output<Node>& output_scales, | ||
const ov::Output<Node>& output_zero_points, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& weight_scales, | ||
const ov::Output<Node>& weight_zero_points, | ||
const ov::Output<Node>& input_scales, | ||
const ov::Output<Node>& input_zero_points, | ||
const ov::Output<Node>& output_scales, | ||
const ov::Output<Node>& output_zero_points, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& output_scales, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& output_scales, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final; | ||
|
||
ov::element::Type get_output_type() const { | ||
return m_output_type; | ||
} | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API Placeholder : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("Placeholder", "ie_internal_opset"); | ||
|
||
Placeholder(); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
22 changes: 22 additions & 0 deletions
22
...common/transformations/include/transformations/op_conversions/convert_fc_to_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedQuantized; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
class ov::pass::ConvertFullyConnectedToFullyConnectedQuantized : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedQuantized", "0"); | ||
ConvertFullyConnectedToFullyConnectedQuantized(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/fully_connected.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "matmul_shape_inference.hpp" | ||
#include "ov_ops/placeholder.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
FullyConnected::FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::Output<Node>& bias, | ||
const ov::element::Type output_type) | ||
: Op({A, B, bias}), | ||
m_output_type(output_type) { | ||
validate_and_infer_types(); | ||
} | ||
|
||
FullyConnected::FullyConnected(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::element::Type output_type) | ||
: FullyConnected(A, B, std::make_shared<Placeholder>(), output_type) {} | ||
|
||
std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
|
||
return std::make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type); | ||
} | ||
|
||
std::shared_ptr<Node> FullyConnected::fuse_bias(const ov::Output<Node>& bias) const { | ||
return std::make_shared<FullyConnected>(input_value(0), input_value(1), bias, m_output_type); | ||
} | ||
|
||
void FullyConnected::validate_and_infer_types() { | ||
const auto input_size = get_input_size(); | ||
NODE_VALIDATION_CHECK(this, | ||
input_size >= 3, | ||
"Number of inputs is incorrect. Current value is: ", | ||
input_size, | ||
", expected at least 3."); | ||
|
||
ov::op::v0::MatMul op; | ||
op.set_transpose_a(false); | ||
op.set_transpose_b(true); | ||
|
||
auto out_shapes = | ||
ov::op::v0::shape_infer(&op, | ||
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)}); | ||
|
||
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; | ||
set_output_type(0, output_type, out_shapes[0]); | ||
} | ||
|
||
bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) { | ||
visitor.on_attribute("output_type", m_output_type); | ||
return true; | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
141 changes: 141 additions & 0 deletions
141
src/common/transformations/src/ov_ops/fully_connected_quantized.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/fully_connected_quantized.hpp" | ||
|
||
#include "matmul_shape_inference.hpp" | ||
#include "ov_ops/placeholder.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& weight_scales, | ||
const ov::Output<Node>& weight_zero_points, | ||
const ov::Output<Node>& input_scales, | ||
const ov::Output<Node>& input_zero_points, | ||
const ov::Output<Node>& output_scales, | ||
const ov::Output<Node>& output_zero_points, | ||
const ov::element::Type output_type) | ||
: FullyConnected(X, W, bias, output_type) { | ||
set_argument(3, weight_scales); | ||
set_argument(4, weight_zero_points); | ||
set_argument(5, input_scales); | ||
set_argument(6, input_zero_points); | ||
set_argument(7, output_scales); | ||
set_argument(8, output_zero_points); | ||
validate_and_infer_types(); | ||
} | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& weight_scales, | ||
const ov::Output<Node>& weight_zero_points, | ||
const ov::Output<Node>& input_scales, | ||
const ov::Output<Node>& input_zero_points, | ||
const ov::Output<Node>& output_scales, | ||
const ov::Output<Node>& output_zero_points, | ||
const ov::element::Type output_type) | ||
: FullyConnectedQuantized(X, | ||
W, | ||
std::make_shared<Placeholder>(), | ||
weight_scales, | ||
weight_zero_points, | ||
input_scales, | ||
input_zero_points, | ||
output_scales, | ||
output_zero_points, | ||
output_type) {} | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& bias, | ||
const ov::Output<Node>& output_scales, | ||
const ov::element::Type output_type) | ||
: FullyConnectedQuantized(X, | ||
W, | ||
bias, | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
output_scales, | ||
std::make_shared<Placeholder>(), | ||
output_type) {} | ||
|
||
FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X, | ||
const ov::Output<Node>& W, | ||
const ov::Output<Node>& output_scales, | ||
const ov::element::Type output_type) | ||
: FullyConnectedQuantized(X, | ||
W, | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
std::make_shared<Placeholder>(), | ||
output_scales, | ||
std::make_shared<Placeholder>(), | ||
output_type) {} | ||
|
||
std::shared_ptr<ov::Node> FullyConnectedQuantized::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
|
||
return std::make_shared<FullyConnectedQuantized>(new_args.at(0), | ||
new_args.at(1), | ||
new_args.at(2), | ||
new_args.at(3), | ||
new_args.at(4), | ||
new_args.at(5), | ||
new_args.at(6), | ||
new_args.at(7), | ||
new_args.at(8), | ||
m_output_type); | ||
} | ||
|
||
std::shared_ptr<Node> FullyConnectedQuantized::fuse_bias(const ov::Output<Node>& bias) const { | ||
return std::make_shared<FullyConnectedQuantized>(input_value(0), | ||
input_value(1), | ||
bias, | ||
input_value(3), | ||
input_value(4), | ||
input_value(5), | ||
input_value(6), | ||
input_value(7), | ||
input_value(8), | ||
m_output_type); | ||
} | ||
|
||
// @todo finalize validate_and_infer_types | ||
void FullyConnectedQuantized::validate_and_infer_types() { | ||
const auto input_size = get_input_size(); | ||
NODE_VALIDATION_CHECK(this, | ||
input_size >= 3, | ||
"Number of inputs is incorrect. Current value is: ", | ||
input_size, | ||
", expected at least 3."); | ||
|
||
ov::op::v0::MatMul op; | ||
op.set_transpose_a(false); | ||
op.set_transpose_b(true); | ||
|
||
auto out_shapes = | ||
ov::op::v0::shape_infer(&op, | ||
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)}); | ||
|
||
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; | ||
set_output_type(0, output_type, out_shapes[0]); | ||
} | ||
|
||
bool FullyConnectedQuantized::visit_attributes(ov::AttributeVisitor& visitor) { | ||
visitor.on_attribute("output_type", m_output_type); | ||
return true; | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
Oops, something went wrong.