Skip to content

Commit

Permalink
[CPU] Introduce FullyConnectedQuantized op and bias fusing
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky committed Aug 28, 2024
1 parent 993f117 commit 983cfbe
Show file tree
Hide file tree
Showing 34 changed files with 744 additions and 132 deletions.
46 changes: 46 additions & 0 deletions src/common/transformations/include/ov_ops/fully_connected.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

class TRANSFORMATIONS_API FullyConnected : public ov::op::Op {
public:
OPENVINO_OP("FullyConnected", "ie_internal_opset");

FullyConnected() = default;

FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& bias,
const ov::element::Type output_type = ov::element::undefined);

FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

ov::element::Type get_output_type() const {
return m_output_type;
}

protected:
ov::element::Type m_output_type;
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/op/op.hpp"
#include "ov_ops/fully_connected.hpp"

namespace ov {
namespace op {
namespace internal {

class TRANSFORMATIONS_API FullyConnectedQuantized : public ov::op::internal::FullyConnected {
public:
OPENVINO_OP("FullyConnectedQuantized", "gpu_opset");

FullyConnectedQuantized() = default;

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& dequantization_scales,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& dequantization_scales,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

ov::element::Type get_output_type() const {
return m_output_type;
}

protected:
ov::element::Type m_output_type;
};

} // namespace internal
} // namespace op
} // namespace ov
27 changes: 27 additions & 0 deletions src/common/transformations/include/ov_ops/placeholder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

class TRANSFORMATIONS_API Placeholder : public ov::op::Op {
public:
OPENVINO_OP("Placeholder", "ie_internal_opset");

Placeholder();

bool visit_attributes(ov::AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedQuantized;

} // namespace pass
} // namespace ov

class ov::pass::ConvertFullyConnectedToFullyConnectedQuantized : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedQuantized", "0");
ConvertFullyConnectedToFullyConnectedQuantized();
};
63 changes: 63 additions & 0 deletions src/common/transformations/src/ov_ops/fully_connected.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/fully_connected.hpp"

#include <memory>

#include "matmul_shape_inference.hpp"
#include "ov_ops/placeholder.hpp"

namespace ov {
namespace op {
namespace internal {

FullyConnected::FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& bias,
const ov::element::Type output_type)
: Op({A, B, bias}),
m_output_type(output_type) {
validate_and_infer_types();
}

FullyConnected::FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::element::Type output_type)
: FullyConnected(A, B, std::make_shared<Placeholder>()) {}

std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);

return std::make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
}

void FullyConnected::validate_and_infer_types() {
const auto input_size = get_input_size();
NODE_VALIDATION_CHECK(this,
input_size >= 3,
"Number of inputs is incorrect. Current value is: ",
input_size,
", expected at least 3.");

ov::op::v0::MatMul op;
op.set_transpose_a(false);
op.set_transpose_b(true);

auto out_shapes =
ov::op::v0::shape_infer(&op,
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)});

auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;
set_output_type(0, output_type, out_shapes[0]);
}

bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("output_type", m_output_type);
return true;
}

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/fully_connected_quantized.hpp"

#include "matmul_shape_inference.hpp"
#include "ov_ops/placeholder.hpp"

namespace ov {
namespace op {
namespace internal {

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& dequantization_scales,
const ov::element::Type output_type)
: FullyConnected({X, W, bias}),
m_output_type(output_type) {
set_argument(3, dequantization_scales);
validate_and_infer_types();
}

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& dequantization_scales,
const ov::element::Type output_type)
: FullyConnectedQuantized({X, W, std::make_shared<Placeholder>(), dequantization_scales}) {
validate_and_infer_types();
}

std::shared_ptr<ov::Node> FullyConnectedQuantized::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);

if (new_args.size() == 4) {
return std::make_shared<FullyConnectedQuantized>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_output_type);
}

return std::make_shared<FullyConnectedQuantized>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
}

void FullyConnectedQuantized::validate_and_infer_types() {
const auto input_size = get_input_size();
NODE_VALIDATION_CHECK(this,
input_size >= 3,
"Number of inputs is incorrect. Current value is: ",
input_size,
", expected at least 3.");

ov::op::v0::MatMul op;
op.set_transpose_a(false);
op.set_transpose_b(true);

auto out_shapes =
ov::op::v0::shape_infer(&op,
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)});

auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;
set_output_type(0, output_type, out_shapes[0]);
}

bool FullyConnectedQuantized::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("output_type", m_output_type);
return true;
}

} // namespace internal
} // namespace op
} // namespace ov
34 changes: 34 additions & 0 deletions src/common/transformations/src/ov_ops/placeholder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/placeholder.hpp"

#include "transformations/rt_info/fused_names_attribute.hpp"

namespace ov {
namespace op {
namespace internal {

Placeholder::Placeholder() : ov::op::Op() {
validate_and_infer_types();
set_friendly_name(get_name());
get_rt_info().emplace(FusedNames::get_type_info_static(), FusedNames{get_friendly_name()});
}

bool Placeholder::visit_attributes(ov::AttributeVisitor& visitor) {
return true;
}

void Placeholder::validate_and_infer_types() {
set_output_type(0, ov::element::undefined, ov::PartialShape{});
}

std::shared_ptr<Node> Placeholder::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<Placeholder>();
}

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_fc_to_quantized.hpp"

#include <memory>

#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/fully_connected.hpp"
#include "ov_ops/fully_connected_quantized.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::ConvertFullyConnectedToFullyConnectedQuantized::ConvertFullyConnectedToFullyConnectedQuantized() {
using namespace ov::pass::pattern;

auto quantized_weights = [](const ov::Output<ov::Node>& output) {
return output.get_element_type() == ov::element::i8;
};

auto quantized_activations = [](const ov::Output<ov::Node>& output) {
return output.get_element_type() == ov::element::u8 || output.get_element_type() == ov::element::i8;
};

auto activations_m = pattern::any_input(quantized_activations);
auto weights_m = wrap_type<ov::op::v0::Constant>(quantized_weights);
auto bias_m = wrap_type<ov::op::v0::Constant>();

auto fully_connected_m = wrap_type<ov::op::internal::FullyConnected>({activations_m, weights_m, bias_m});
auto dequantization_scales_m = wrap_type<ov::op::v0::Constant>();
auto multiply_m = wrap_type<ov::op::v1::Multiply>({fully_connected_m, dequantization_scales_m});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

OPENVINO_ASSERT(pattern_map.count(fully_connected_m));
OPENVINO_ASSERT(pattern_map.count(multiply_m));
OPENVINO_ASSERT(pattern_map.count(dequantization_scales_m));

std::vector<std::shared_ptr<ov::Node>> result_nodes = {};

auto fc = std::dynamic_pointer_cast<ov::op::internal::FullyConnected>(
pattern_map.at(fully_connected_m).get_node_shared_ptr());
auto activations = pattern_map.at(activations_m).get_node_shared_ptr();
auto weights = pattern_map.at(weights_m).get_node_shared_ptr();
auto bias = pattern_map.at(bias_m).get_node_shared_ptr();
auto multiply = pattern_map.at(multiply_m).get_node_shared_ptr();
auto dequantization_scales = pattern_map.at(dequantization_scales_m).get_node_shared_ptr();
const auto& fc_output_shape = fc->get_output_shape(0);
const auto& multiply_output_shape = multiply->get_output_shape(0);

if (fc_output_shape.back() != multiply_output_shape.back()) {
return false;
}

auto fc_quantized = std::make_shared<ov::op::internal::FullyConnectedQuantized>(activations,
weights,
bias,
dequantization_scales,
fc->get_output_type());

// result_nodes.push_back(new_gather_node);
fc_quantized->set_friendly_name(multiply->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
ov::replace_node(multiply, fc_quantized);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(multiply_m, "ConvertFullyConnectedToFullyConnectedQuantized");
this->register_matcher(m, callback);
}
Loading

0 comments on commit 983cfbe

Please sign in to comment.