Commit: Add cache rotation inputs in transformations, CPU and GPU plugins
vshampor committed Nov 19, 2024
1 parent 5bde1ab commit a33f255
Showing 31 changed files with 1,569 additions and 381 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/job_cxx_unit_tests.yml
@@ -195,6 +195,12 @@ jobs:
          ${{ env.SETUPVARS_COMMAND }}
          ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTests.xml
+     - name: CPU plugin unit tests (vectorized)
+       if: fromJSON(inputs.affected-components).CPU.test
+       run: |
+         ${{ env.SETUPVARS_COMMAND }}
+         ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests_vectorized --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTestsVectorized.xml
      - name: ov_subgraphs_dumper_tests tests
        run: |
          ${{ env.SETUPVARS_COMMAND }}
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -185,4 +185,4 @@ endif()
# provides a callback function to describe each component in repo
include(cmake/packaging/packaging.cmake)

-ov_cpack(${OV_CPACK_COMPONENTS_ALL})
\ No newline at end of file
+ov_cpack(${OV_CPACK_COMPONENTS_ALL})

@@ -132,14 +132,20 @@ void regmodule_offline_transformations(py::module m) {

    m_offline_transformations.def(
        "paged_attention_transformation",
-        [](std::shared_ptr<ov::Model> model, bool use_block_indices_inputs, bool use_score_outputs) {
+        [](std::shared_ptr<ov::Model> model,
+           bool use_block_indices_inputs,
+           bool use_score_outputs,
+           bool allow_cache_rotation) {
            ov::pass::Manager manager;
-            manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs);
+            manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs,
+                                                                  use_score_outputs,
+                                                                  allow_cache_rotation);
            manager.run_passes(model);
        },
        py::arg("model"),
        py::arg("use_block_indices_inputs") = false,
-        py::arg("use_score_outputs") = false);
+        py::arg("use_score_outputs") = false,
+        py::arg("allow_cache_rotation") = false);

m_offline_transformations.def(
"stateful_to_stateless_transformation",

@@ -24,8 +24,11 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
                           ParameterVector& parameters_to_remove,
                           int& layer_index,
                           ov::Output<Node> max_context_len,
-                          ParameterVector& block_indices_inputs,
+                          ParameterVector& block_indices_inputs_for_each_layer,
                           ResultVector& score_results,
-                          bool use_block_indices,
-                          bool use_score_outputs);
+                          bool use_per_layer_block_indices_inputs,
+                          bool use_score_outputs,
+                          bool allow_cache_rotation,
+                          ParameterVector& rotation_coefficients_inputs_for_each_layer,
+                          ParameterVector& rotated_block_indices_inputs_for_each_layer);
};

@@ -64,16 +64,30 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
    return node_tuple(kv_past_par, kv_current2, kv_current_reshaped, kv_concat);
}

+template <class T>
+void insert_rotation_inputs_as(OutputVector& pa_arguments, size_t layer_index) {
+    auto rotation_coefficients = setName(std::make_shared<T>(ov::element::f32, ov::PartialShape{-1}),
+                                         "rotation_coefficients." + std::to_string(layer_index - 1));
+    auto rotated_block_indices = setName(std::make_shared<T>(ov::element::i32, ov::PartialShape{-1}),
+                                         "rotated_block_indices." + std::to_string(layer_index - 1));
+
+    pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients);
+    pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices);
+}
+
ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters,
                                                         ParameterVector& model_remaining_params,
                                                         const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
                                                         ParameterVector& parameters_to_remove,
                                                         int& layer_index,
                                                         Output<Node> max_context_len,
-                                                        ParameterVector& block_indices_inputs,
+                                                        ParameterVector& block_indices_inputs_for_each_layer,
                                                         ResultVector& score_results,
-                                                        bool use_block_indices_inputs,
-                                                        bool use_score_outputs) {
+                                                        bool use_per_layer_block_indices_inputs,
+                                                        bool use_score_outputs,
+                                                        bool allow_cache_rotation,
+                                                        ParameterVector& rotation_coefficients_inputs_for_each_layer,
+                                                        ParameterVector& rotated_block_indices_inputs_for_each_layer) {
MATCHER_SCOPE(StateManagementPattern);

auto k_current = pattern::any_input();

@@ -176,9 +190,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
                            &model_remaining_params,
                            &sliding_window,
                            &parameters_to_remove,
-                           &block_indices_inputs,
+                           &block_indices_inputs_for_each_layer,
                            &score_results,
-                           &layer_index](ov::pass::pattern::Matcher& m) {
+                           &layer_index,
+                           &rotation_coefficients_inputs_for_each_layer,
+                           &rotated_block_indices_inputs_for_each_layer](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto real_q = pattern_map.at(q);


@@ -374,11 +390,26 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
                                       max_context_len.get_node_shared_ptr()};
        pa_arguments.insert(pa_arguments.end(), additional_params.begin(), additional_params.end());

-        if (use_block_indices_inputs) {
+        if (use_per_layer_block_indices_inputs) {
            auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
                                         "block_indices." + std::to_string(layer_index - 1));
            pa_arguments.insert(pa_arguments.begin() + 7, block_indices);
-            block_indices_inputs.push_back(block_indices);
+            block_indices_inputs_for_each_layer.push_back(block_indices);
        }

+        OPENVINO_ASSERT(pa_arguments.size() == 13);
+
+        if (allow_cache_rotation) {
+            auto rotation_coefficients = setName(std::make_shared<v0::Parameter>(element::f32, PartialShape{-1}),
+                                                 "rotation_coefficients." + std::to_string(layer_index - 1));
+            auto rotated_block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
+                                                 "rotated_block_indices." + std::to_string(layer_index - 1));
+
+            pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients);
+            pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices);
+
+            rotation_coefficients_inputs_for_each_layer.push_back(rotation_coefficients);
+            rotated_block_indices_inputs_for_each_layer.push_back(rotated_block_indices);
+        }

auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);

@@ -435,4 +466,4 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par

    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa_variants, matcher_name);
    register_matcher(m, callback);
-}
\ No newline at end of file
+}
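
Note: the rotation inputs registered here are attached to the model as plain Parameters named "rotation_coefficients.<N>" and "rotated_block_indices.<N>", where N is the zero-based layer index (see setName calls above). A minimal sketch of how downstream code could locate them by name; the helper itself is illustrative and not part of this commit:

#include <memory>
#include <string>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"

// Illustrative helper (assumption, not from this change): collect the
// cache-rotation inputs that StateManagementPattern created for one layer.
ov::ParameterVector find_rotation_inputs(const std::shared_ptr<ov::Model>& model, size_t layer_idx) {
    const std::string coeff_name = "rotation_coefficients." + std::to_string(layer_idx);
    const std::string indices_name = "rotated_block_indices." + std::to_string(layer_idx);
    ov::ParameterVector found;
    for (const auto& param : model->get_parameters()) {
        // setName() above assigns this as the friendly name (and tensor name).
        const auto& name = param->get_friendly_name();
        if (name == coeff_name || name == indices_name) {
            found.push_back(param);
        }
    }
    return found;
}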
7 changes: 5 additions & 2 deletions src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
@@ -19,12 +19,15 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass {
public:
    OPENVINO_RTTI("SDPAToPagedAttention");

-    SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false);
+    SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false,
+                         bool use_score_outputs = false,
+                         bool allow_cache_rotation = false);
    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

private:
-    bool m_use_block_indices_inputs;
+    bool m_use_per_layer_block_indices_inputs;
    bool m_use_score_outputs;
+    bool m_allow_cache_rotation;
};
} // namespace pass
} // namespace ov
30 changes: 28 additions & 2 deletions src/core/src/op/paged_attention.cpp
@@ -19,8 +19,8 @@ void PagedAttentionExtension::validate_and_infer_types() {
    OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types);

    NODE_VALIDATION_CHECK(this,
-                          get_input_size() == 13,
-                          "PagedAttensionExtension expects 13 inputs, but it has ",
+                          get_input_size() == 13 || get_input_size() == 15,
+                          "PagedAttensionExtension expects 13 or 15 inputs, but it has ",
                          get_input_size());

NODE_VALIDATION_CHECK(

@@ -147,6 +147,32 @@ void PagedAttentionExtension::validate_and_infer_types() {
                          get_input_element_type(12),
                          ".");

+    if (get_input_size() == 15) {
+        NODE_VALIDATION_CHECK(
+            this,
+            get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1,
+            "Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ",
+            get_input_partial_shape(13).rank().get_length(),
+            ".");
+        NODE_VALIDATION_CHECK(this,
+                              get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::f32,
+                              "Element type of `rotation_coefficients` input should be f32, but it is ",
+                              get_input_element_type(13),
+                              ".");
+
+        NODE_VALIDATION_CHECK(
+            this,
+            get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 1,
+            "Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
+            get_input_partial_shape(14).rank().get_length(),
+            ".");
+        NODE_VALIDATION_CHECK(this,
+                              get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32,
+                              "Element type of `rotated_block_indices` input should be i32, but it is ",
+                              get_input_element_type(14),
+                              ".");
+    }
+
// value head_size may be not same with key
auto out_ps = get_input_partial_shape(0);
const auto& key_ps = get_input_partial_shape(1);
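
Note: for orientation, the sketch below shows how a node with the extended 15-input signature could be assembled once cache rotation is in play. Only inputs 13 (`rotation_coefficients`, rank-1 f32) and 14 (`rotated_block_indices`, rank-1 i32) are new here; the identity and order of the first 13 inputs are an assumption based on the surrounding validation code, and the helper name is hypothetical.

#include <memory>

#include "openvino/core/except.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/parameter.hpp"

// Hypothetical helper, not part of this commit: append the two cache-rotation
// inputs to an existing 13-input argument list (query, key, value, caches,
// sequence/block metadata, ..., max_context_len -- assumed prepared by the
// caller) and build the op.
std::shared_ptr<ov::op::PagedAttentionExtension> make_paged_attention_with_rotation(ov::OutputVector pa_arguments) {
    OPENVINO_ASSERT(pa_arguments.size() == 13);

    // The two new trailing inputs validated at indices 13 and 14 above:
    auto rotation_coefficients = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1});
    auto rotated_block_indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1});
    pa_arguments.emplace_back(rotation_coefficients);
    pa_arguments.emplace_back(rotated_block_indices);

    // validate_and_infer_types() now accepts either 13 or 15 inputs.
    return std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);
}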
33 changes: 23 additions & 10 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -20,9 +20,12 @@

using namespace ov::op;

-ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inputs, bool use_score_outputs)
-    : m_use_block_indices_inputs(use_block_indices_inputs),
-      m_use_score_outputs(use_score_outputs) {}
+ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs,
+                                                     bool use_score_outputs,
+                                                     bool allow_cache_rotation)
+    : m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs),
+      m_use_score_outputs(use_score_outputs),
+      m_allow_cache_rotation(allow_cache_rotation) {}

static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const char* name) {
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a

@@ -46,7 +49,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
        setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "subsequence_begins"),
        setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices_begins"),
    };
-    if (!m_use_block_indices_inputs) {
+    if (!m_use_per_layer_block_indices_inputs) {
        auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices");
        model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices);
    }

@@ -94,7 +97,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
    ParameterVector kv_parameters;
    ParameterVector parameters_to_remove;
    ResultVector results_to_remove;  // # used, but cannot really track all Results in stateless model
-    ParameterVector block_indices_inputs;
+    ParameterVector block_indices_inputs_for_each_layer;
+    ParameterVector rotation_coefficients_inputs_for_each_layer;
+    ParameterVector rotated_block_indices_inputs_for_each_layer;
ResultVector score_results;

std::shared_ptr<v0::Parameter> position_ids;

@@ -123,10 +128,13 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
                                                 parameters_to_remove,
                                                 layer_index,
                                                 max_context_len->output(0),
-                                                 block_indices_inputs,
+                                                 block_indices_inputs_for_each_layer,
                                                 score_results,
-                                                 m_use_block_indices_inputs,
-                                                 m_use_score_outputs);
+                                                 m_use_per_layer_block_indices_inputs,
+                                                 m_use_score_outputs,
+                                                 m_allow_cache_rotation,
+                                                 rotation_coefficients_inputs_for_each_layer,
+                                                 rotated_block_indices_inputs_for_each_layer);
manager.register_pass<PrevSequenceLengthPattern>(prev_max_seq_len, batch_dim);
manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
manager.register_pass<PositionIDsReplacer>(unsqueezed_position_ids->output(0));

@@ -174,14 +182,19 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
        model->remove_parameter(parameter);
    }

-    if (m_use_block_indices_inputs) {
-        model->add_parameters(block_indices_inputs);
+    if (m_use_per_layer_block_indices_inputs) {
+        model->add_parameters(block_indices_inputs_for_each_layer);
    }

    if (m_use_score_outputs) {
        model->add_results(score_results);
    }

+    if (m_allow_cache_rotation) {
+        model->add_parameters(rotation_coefficients_inputs_for_each_layer);
+        model->add_parameters(rotated_block_indices_inputs_for_each_layer);
+    }

model->add_parameters(kv_parameters);
model->add_parameters(model_remaining_params);
model->add_parameters({std::move(max_context_len)});
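
Note: end to end, the new flag is threaded from the public pass interface down to the per-layer PagedAttention inputs. A minimal usage sketch mirroring what the Python binding above registers; the wrapper function name is illustrative, and after the pass runs the model exposes per-layer "rotation_coefficients.<N>" (f32) and "rotated_block_indices.<N>" (i32) parameters alongside the usual PagedAttention inputs:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"

// Sketch: convert an SDPA-based model to PagedAttention form with cache
// rotation enabled, using the same Manager/register_pass/run_passes sequence
// as the Python binding in this commit.
void convert_with_cache_rotation(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>(/*use_per_layer_block_indices_inputs=*/false,
                                                          /*use_score_outputs=*/false,
                                                          /*allow_cache_rotation=*/true);
    manager.run_passes(model);
}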