[Snippets][CPU] Added external repacking via BrgemmCopyB #28179

Open · wants to merge 7 commits into base: master
21 changes: 21 additions & 0 deletions src/common/snippets/include/snippets/op/reshape.hpp
@@ -32,6 +32,27 @@ class Reshape : public ov::op::Op {
ov::PartialShape m_target_shape = {};
};

/**
* @interface ReshapeWithOrder
 * @brief ReshapeWithOrder reshapes the input tensor shape according to the required target order.
 *        The tensor data is not updated.
 * Note: The order is stored in the input PortDescriptor.
* @ingroup snippets
*/
class ReshapeWithOrder : public ov::op::Op {
public:
OPENVINO_OP("ReshapeWithOrder", "SnippetsOpset");
ReshapeWithOrder() = default;
ReshapeWithOrder(const Output<Node>& x, std::vector<size_t> order);

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> order);
};

} // namespace op
} // namespace snippets
} // namespace ov
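A minimal usage sketch (illustrative only, not part of this PR): it assumes utils::get_planar_pshape permutes the dimensions as shape[order[i]], so the op only changes the shape reported to consumers while the data stays in place.

#include <memory>
#include <vector>
#include "openvino/op/parameter.hpp"
#include "snippets/op/reshape.hpp"

// Hypothetical example: a 4D input whose shape is re-read in order {0, 2, 3, 1}.
void reshape_with_order_example() {
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16, 32, 64});
    auto reshape = std::make_shared<ov::snippets::op::ReshapeWithOrder>(param, std::vector<size_t>{0, 2, 3, 1});
    // Expected output partial shape: {1, 32, 64, 16}; the tensor data itself is not moved,
    // only the shape seen by consumers changes.
}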
@@ -82,5 +82,13 @@ class ReshapeShapeInfer : public IShapeInferSnippets {
explicit ReshapeShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

class ReshapeWithOrderShapeInfer : public IShapeInferSnippets {
std::vector<size_t> m_target_order {};
public:
explicit ReshapeWithOrderShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

} // namespace snippets
} // namespace ov
@@ -17,6 +17,7 @@ OV_OP(LoopEnd, ov::snippets::op)
OV_OP(Brgemm, ov::snippets::op)
OV_OP(BroadcastLoad, ov::snippets::op)
OV_OP(Reshape, ov::snippets::op)
OV_OP(ReshapeWithOrder, ov::snippets::op)

OV_OP(Store, ov::snippets::op)

15 changes: 14 additions & 1 deletion src/common/snippets/include/snippets/utils/utils.hpp
@@ -290,13 +290,26 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_child_shape_infer_seq(const std
std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr<ov::Node>& start_node);

/**
 * @brief Get stride of input/output dimension
 * @param expr_port target port that contains shape and layout info
 * @param idx index of the target dimension starting from the shape's end (default = 1)
 */
int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx = 1);
/**
* @brief Get stride of input dimension
* @param shape target shape
* @param layout target layout
* @param idx index of the target dimension starting from the shape's end (default = 1)
*/
int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);
/**
* @brief Get stride of output dimension
* @param shape target shape
* @param layout target layout
* @param idx index of the target dimension starting from the shape's end (default = 1)
*/
int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);

/**
* @brief Traverses path starting from "expr", and calls "func" for each expression.
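For reference, a small self-contained illustration of what a dimension stride means here (an assumption about the semantics, not the actual OpenVINO helpers): the stride of the dimension idx positions from the end of the planar shape is the product of all dimensions to its right; the helpers above additionally translate idx through the layout first.

#include <cstdint>
#include <vector>

// Dense stride of the dimension that is `idx` positions from the end of `shape`
// (idx = 1 means the second-to-last dimension). Assumes a fully dense layout and idx < shape.size().
int64_t dense_stride_from_end(const std::vector<size_t>& shape, size_t idx) {
    int64_t stride = 1;
    for (size_t i = shape.size() - idx; i < shape.size(); ++i)
        stride *= static_cast<int64_t>(shape[i]);
    return stride;
}
// Example: for shape {2, 3, 4, 5} and idx = 1 the target dimension is 4 and its stride is 5.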
1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
@@ -77,6 +77,7 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
std::dynamic_pointer_cast<op::Buffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op) ||
std::dynamic_pointer_cast<op::Reshape>(op) ||
std::dynamic_pointer_cast<op::ReshapeWithOrder>(op) ||
std::dynamic_pointer_cast<snippets::op::Store>(op)
#ifdef SNIPPETS_DEBUG_CAPS
|| std::dynamic_pointer_cast<op::PerfCountBeginBase>(op)
41 changes: 41 additions & 0 deletions src/common/snippets/src/op/reshape.cpp
@@ -11,6 +11,7 @@
namespace ov {
namespace snippets {
namespace op {

Reshape::Reshape(const Output<Node>& arg, ov::PartialShape target_shape)
: Op({arg}), m_target_shape(std::move(target_shape)) {
constructor_validate_and_infer_types();
@@ -38,6 +39,46 @@ const ov::PartialShape& Reshape::get_target_shape() const {
void Reshape::set_target_shape(ov::PartialShape shape) {
m_target_shape = std::move(shape);
}

ReshapeWithOrder::ReshapeWithOrder(const Output<Node>& arg, std::vector<size_t> order)
: Op({arg}) {
custom_constructor_validate_and_infer_types(std::move(order));
}

void ReshapeWithOrder::custom_constructor_validate_and_infer_types(std::vector<size_t> order) {
INTERNAL_OP_SCOPE(ReshapeWithOrder_constructor_validate_and_infer_types);

const auto& input_pshape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_pshape.rank().is_static() && input_pshape.size() == order.size(),
"Incompatible shape and order sizes");

// During the constructor call, ReshapeWithOrder doesn't yet know its port descriptors,
// so we use the explicit order passed as a parameter.
set_output_type(0, get_input_element_type(0), ov::snippets::utils::get_planar_pshape(input_pshape, order));
}

void ReshapeWithOrder::validate_and_infer_types() {
const auto& input_pshape = get_input_partial_shape(0);
const auto order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
OPENVINO_ASSERT(input_pshape.rank().is_static() && input_pshape.size() == order.size(),
"Incompatible shape and order sizes");
const auto output_pshape = utils::get_planar_pshape(get_input_partial_shape(0), order);
set_output_type(0, get_input_element_type(0), output_pshape);
}

std::shared_ptr<Node> ReshapeWithOrder::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(ReshapeWithOrder);
check_new_args_count(this, new_args);
const auto& order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
return std::make_shared<ReshapeWithOrder>(new_args.at(0), order);
}

bool ReshapeWithOrder::visit_attributes(AttributeVisitor& visitor) {
auto order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
visitor.on_attribute("target_order", order);
return true;
}

} // namespace op
} // namespace snippets
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/src/op/subgraph.cpp
@@ -96,6 +96,7 @@ auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bo

auto Subgraph::is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool {
return ov::is_type<snippets::op::Reshape>(op) ||
ov::is_type<snippets::op::ReshapeWithOrder>(op) ||
ov::is_type<snippets::op::RankNormalization>(op);
}

19 changes: 18 additions & 1 deletion src/common/snippets/src/runtime_configurator.cpp
@@ -118,7 +118,23 @@ void RuntimeConfigurator::init_data_info(const lowered::LinearIRCPtr& linear_ir)
// input->shape changing ops->load
PortDescriptorPtr desc = nullptr;
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(param);
const auto& mem_desc_expr = shape_infer_seq.empty() ? param : shape_infer_seq.back();
ExpressionPtr mem_desc_expr = param;
if (!shape_infer_seq.empty()) {
// If there is a ReshapeWithOrder, take its input descriptor, because the target order affects the shape
const auto& reordered_reshape_it = std::find_if(shape_infer_seq.cbegin(), shape_infer_seq.cend(),
[](const ExpressionPtr& expr) {
return ov::is_type<op::ReshapeWithOrder>(expr->get_node());
});
if (reordered_reshape_it != shape_infer_seq.cend()) {
const auto& reshape = *reordered_reshape_it;
const auto& etype = reshape->get_node()->get_output_element_type(0);
update_io_parameters(reshape->get_input_port_descriptor(0), etype);
continue;
}

mem_desc_expr = shape_infer_seq.back();
}

auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers();
for (const auto& child_input : consumer_inputs) {
const auto ma = std::dynamic_pointer_cast<snippets::modifier::MemoryAccess>(child_input.get_expr()->get_node());
@@ -127,6 +143,7 @@ void RuntimeConfigurator::init_data_info(const lowered::LinearIRCPtr& linear_ir)
break;
}
}
OPENVINO_ASSERT(desc, "Descriptor is missing!");
const auto& etype = mem_desc_expr->get_node()->get_output_element_type(0);
update_io_parameters(desc, etype);
}
@@ -245,5 +245,16 @@ Result ReshapeShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes)
return {{target_shape}, ShapeInferStatus::success};
}

ReshapeWithOrderShapeInfer::ReshapeWithOrderShapeInfer(const std::shared_ptr<Node>& n) {
const auto& reshape = as_type_ptr<ov::snippets::op::ReshapeWithOrder>(n);
OPENVINO_ASSERT(reshape, "Invalid node passed to ReshapeWithOrderShapeInfer.");
m_target_order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(reshape->input(0))->get_layout();
}

Result ReshapeWithOrderShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of shapes passed to ReshapeWithOrderShapeInfer");
return {{ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_target_order)}, ShapeInferStatus::success};
}

} // namespace snippets
} // namespace ov
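As an illustration of the shape inference above (assuming get_planar_vdims permutes the runtime dims as shape[order[i]]), the following hypothetical helper mirrors the expected result on concrete VectorDims.

#include <cstddef>
#include <vector>

// Hypothetical mirror of the planar-dims computation used by ReshapeWithOrderShapeInfer.
std::vector<size_t> planar_vdims(const std::vector<size_t>& shape, const std::vector<size_t>& order) {
    std::vector<size_t> result(shape.size());
    for (size_t i = 0; i < order.size(); ++i)
        result[i] = shape[order[i]];
    return result;
}
// Example: planar_vdims({2, 8, 24}, {0, 2, 1}) returns {2, 24, 8}.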
@@ -58,6 +58,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(op::KernelDynamic, EmptyShapeInfer),
SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Reshape, ReshapeShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReshapeWithOrder, ReshapeWithOrderShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReduceMax, ReduceShapeInfer),
17 changes: 12 additions & 5 deletions src/common/snippets/src/utils/utils.cpp
@@ -317,14 +317,21 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const st
}

int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx) {
size_t dim_idx = 0;
const auto& shape = expr_port.get_descriptor_ptr()->get_shape();
const auto& layout = expr_port.get_descriptor_ptr()->get_layout();
switch (expr_port.get_type()) {
case lowered::ExpressionPort::Input: dim_idx = utils::get_input_dim_idx(layout, idx); break;
case lowered::ExpressionPort::Output: dim_idx = utils::get_output_dim_idx(layout, idx); break;
default: OPENVINO_THROW("Unsupported expression port type!");
case lowered::ExpressionPort::Input: return get_dim_in_stride(shape, layout, idx);
case lowered::ExpressionPort::Output: return get_dim_out_stride(shape, layout, idx);
}
return get_stride(dim_idx, expr_port.get_descriptor_ptr()->get_shape());
OPENVINO_THROW("Unsupported expression port type!");
}

int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
return get_stride(utils::get_input_dim_idx(layout, idx), shape);
}

int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
return get_stride(utils::get_output_dim_idx(layout, idx), shape);
}

void visit_path(const lowered::ExpressionPtr& expr,
@@ -39,12 +39,45 @@ std::string CPURuntimeConfig::to_string() const {
}
#endif

CPURuntimeConfigurator::CPURuntimeConfigurator()
: ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()) {}
#ifndef OPENVINO_ARCH_ARM64

CPURuntimeConfig::RepackedInput::RepackedInput(std::shared_ptr<const BrgemmCopyBKernel> kernel,
CpuBlockedMemoryDescPtr desc,
VectorDims in_offsets,
VectorDims out_offsets)
: m_kernel(std::move(kernel)),
m_desc(std::move(desc)),
m_in_offsets(std::move(in_offsets)),
m_out_offsets(std::move(out_offsets)) {
OPENVINO_ASSERT(m_in_offsets.size() == m_out_offsets.size(), "Incorrect size of offsets");
OPENVINO_ASSERT(m_desc, "Descriptor is empty");
}

const CpuBlockedMemoryDescPtr& CPURuntimeConfig::RepackedInput::desc() const {
return m_desc;
}

const std::shared_ptr<const BrgemmCopyBKernel>& CPURuntimeConfig::RepackedInput::kernel() const {
return m_kernel;
}

const VectorDims& CPURuntimeConfig::RepackedInput::in_offsets() const {
return m_in_offsets;
}

const VectorDims& CPURuntimeConfig::RepackedInput::out_offsets() const {
return m_out_offsets;
}

#endif // OPENVINO_ARCH_ARM64

CPURuntimeConfigurator::CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache)
: ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()),
compiled_kernel_cache(std::move(cache)) {}

void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) {
RuntimeConfigurator::initialization(linear_ir);
#ifndef OPENVINO_ARCH_ARM64
#ifdef OPENVINO_ARCH_X86_64
RuntimeOptimizer::register_if_applicable<BrgemmCopyBLoopPortsAdjuster>(m_intermediate_optimizers, linear_ir, this);
RuntimeOptimizer::register_if_applicable<BrgemmExternalRepackingAdjuster>(m_final_optimizers, linear_ir, this);
#endif
@@ -5,6 +5,12 @@
#pragma once

#include "emitters/snippets/jit_snippets_call_args.hpp"

#ifndef OPENVINO_ARCH_ARM64
# include "emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp"
#endif

#include "cache/multi_cache.h"
#include "memory_desc/cpu_blocked_memory_desc.h"
#include "snippets/lowered/port_descriptor.hpp"
#include "snippets/runtime_configurator.hpp"
@@ -21,27 +27,61 @@ class CPURuntimeConfig : public ov::snippets::RuntimeConfig {
std::string to_string() const override;
#endif

#ifndef OPENVINO_ARCH_ARM64
struct RepackedInput {
RepackedInput() = default;
RepackedInput(std::shared_ptr<const BrgemmCopyBKernel> kernel,
CpuBlockedMemoryDescPtr desc,
VectorDims in_offsets,
VectorDims out_offsets);

const std::shared_ptr<const BrgemmCopyBKernel>& kernel() const;
const CpuBlockedMemoryDescPtr& desc() const;
const VectorDims& in_offsets() const;
const VectorDims& out_offsets() const;

private:
std::shared_ptr<const BrgemmCopyBKernel> m_kernel{nullptr};
CpuBlockedMemoryDescPtr m_desc{nullptr};
VectorDims m_in_offsets{};
VectorDims m_out_offsets{};
};
std::unordered_map<size_t, RepackedInput> repacked_inputs = {};

enum class RepackingImplType {
NONE, // no repacking outside the kernel
IN_PARALLEL, // executed in parallel_nt by each thread
SEPARATE, // executed separately from the kernel
};
RepackingImplType repacking_impl_type = RepackingImplType::NONE;
#endif // OPENVINO_ARCH_ARM64

std::vector<jit_snippets_call_args::loop_args_t> loop_args = {};
std::unordered_map<size_t, CpuBlockedMemoryDescPtr> m_in_requested_descs = {};
};

class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
public:
CPURuntimeConfigurator();
CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache = {});

/**
* @brief Calculate Loop parameters of Loop emitters and update these values in CPURuntimeConfig
* @param linear_ir LinearIR
*/
void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const;

const ov::intel_cpu::MultiCacheWeakPtr& get_cache() const {
return compiled_kernel_cache;
}

protected:
void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override;
void init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const override;
void initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;

static const size_t rank6D;

ov::intel_cpu::MultiCacheWeakPtr compiled_kernel_cache;
};

} // namespace intel_cpu
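A hedged sketch of how a caller might branch on the new repacking mode: the switch mirrors the RepackingImplType enum and repacked_inputs field defined above, while the surrounding function, the include path, and the comments about where repacking runs are illustrative assumptions.

#include "emitters/snippets/cpu_runtime_configurator.hpp"  // hypothetical include path for CPURuntimeConfig

// Illustrative only: dispatch on the repacking mode selected by the configurator.
void dispatch_repacking(const ov::intel_cpu::CPURuntimeConfig& config) {
    using Impl = ov::intel_cpu::CPURuntimeConfig::RepackingImplType;
    switch (config.repacking_impl_type) {
    case Impl::NONE:
        // Inputs are consumed by the snippets kernel as-is, no extra repacking pass.
        break;
    case Impl::IN_PARALLEL:
        // Each worker thread repacks its portion (via the stored BrgemmCopyB kernels)
        // right before executing the kernel body.
        break;
    case Impl::SEPARATE:
        // Repacking runs as a standalone step over config.repacked_inputs
        // before the snippets kernel is launched.
        break;
    }
}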
@@ -165,7 +165,7 @@ class jit_snippet : public dnnl::impl::cpu::x64::jit_generator {

intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa,
ov::intel_cpu::MultiCacheWeakPtr cache)
: TargetMachine(std::make_shared<CPURuntimeConfigurator>()),
: TargetMachine(std::make_shared<CPURuntimeConfigurator>(cache)),
h(new jit_snippet()),
isa(host_isa),
compiled_kernel_cache(std::move(cache)) {
Expand All @@ -177,6 +177,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[snippets::op::RankNormalization::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);
jitters[snippets::op::Reshape::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);
jitters[snippets::op::ReshapeWithOrder::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);

jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter);
jitters[snippets::op::LoadReshape::get_type_info_static()] =