diff --git a/.github/workflows/job_cxx_unit_tests.yml b/.github/workflows/job_cxx_unit_tests.yml index 8fab17043b7465..f0c18233a692c8 100644 --- a/.github/workflows/job_cxx_unit_tests.yml +++ b/.github/workflows/job_cxx_unit_tests.yml @@ -195,6 +195,12 @@ jobs: ${{ env.SETUPVARS_COMMAND }} ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTests.xml + - name: CPU plugin unit tests (vectorized) + if: fromJSON(inputs.affected-components).CPU.test + run: | + ${{ env.SETUPVARS_COMMAND }} + ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests_vectorized --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTestsVectorized.xml + - name: ov_subgraphs_dumper_tests tests run: | ${{ env.SETUPVARS_COMMAND }} diff --git a/CMakeLists.txt b/CMakeLists.txt index e9e8d3724d9ac5..3ae556a3b39a59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,4 +185,4 @@ endif() # provides a callback function to describe each component in repo include(cmake/packaging/packaging.cmake) -ov_cpack(${OV_CPACK_COMPONENTS_ALL}) \ No newline at end of file +ov_cpack(${OV_CPACK_COMPONENTS_ALL}) diff --git a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp index 641893cdd267a2..0e1c2387af43c2 100644 --- a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp +++ b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp @@ -132,14 +132,20 @@ void regmodule_offline_transformations(py::module m) { m_offline_transformations.def( "paged_attention_transformation", - [](std::shared_ptr model, bool use_block_indices_inputs, bool use_score_outputs) { + [](std::shared_ptr model, + bool use_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation) { ov::pass::Manager manager; - manager.register_pass(use_block_indices_inputs, use_score_outputs); + manager.register_pass(use_block_indices_inputs, + use_score_outputs, + allow_cache_rotation); manager.run_passes(model); }, py::arg("model"), py::arg("use_block_indices_inputs") = false, - py::arg("use_score_outputs") = false); + py::arg("use_score_outputs") = false, + py::arg("allow_cache_rotation") = false); m_offline_transformations.def( "stateful_to_stateless_transformation", diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp index feab06ccc0cd5d..851ba55648f499 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp @@ -24,8 +24,11 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass { ParameterVector& parameters_to_remove, int& layer_index, ov::Output max_context_len, - ParameterVector& block_indices_inputs, + ParameterVector& block_indices_inputs_for_each_layer, ResultVector& score_results, - bool use_block_indices, - bool use_score_outputs); + bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation, + ParameterVector& rotation_coefficients_inputs_for_each_layer, + ParameterVector& rotated_block_indices_inputs_for_each_layer); }; \ No newline at end of file diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp 
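A minimal C++ sketch of how the new third argument is expected to reach the pass itself, mirroring the Python binding change above; the include paths, `ov::pass::Manager` usage, and the constructor argument order come from the headers touched by this patch, while the wrapper function itself is purely illustrative:

    #include "openvino/core/model.hpp"
    #include "openvino/pass/manager.hpp"
    #include "openvino/pass/sdpa_to_paged_attention.hpp"

    void convert_to_paged_attention_with_rotation(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // Same argument order as the updated constructor:
        // (use_per_layer_block_indices_inputs, use_score_outputs, allow_cache_rotation)
        manager.register_pass<ov::pass::SDPAToPagedAttention>(false, false, /*allow_cache_rotation=*/true);
        manager.run_passes(model);
    }
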
b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp index 28e7cd90019b34..e47ec4731a17d6 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -64,16 +64,30 @@ static node_tuple kv_read_and_concat(ov::Output kv_current) { return node_tuple(kv_past_par, kv_current2, kv_current_reshaped, kv_concat); } +template +void insert_rotation_inputs_as(OutputVector& pa_arguments, size_t layer_index) { + auto rotation_coefficients = setName(std::make_shared(ov::element::f32, ov::PartialShape{-1}), + "rotation_coefficients." + std::to_string(layer_index - 1)); + auto rotated_block_indices = setName(std::make_shared(ov::element::i32, ov::PartialShape{-1}), + "rotated_block_indices." + std::to_string(layer_index - 1)); + + pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients); + pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices); +} + ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters, ParameterVector& model_remaining_params, const std::shared_ptr& sliding_window, ParameterVector& parameters_to_remove, int& layer_index, Output max_context_len, - ParameterVector& block_indices_inputs, + ParameterVector& block_indices_inputs_for_each_layer, ResultVector& score_results, - bool use_block_indices_inputs, - bool use_score_outputs) { + bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation, + ParameterVector& rotation_coefficients_inputs_for_each_layer, + ParameterVector& rotated_block_indices_inputs_for_each_layer) { MATCHER_SCOPE(StateManagementPattern); auto k_current = pattern::any_input(); @@ -176,9 +190,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par &model_remaining_params, &sliding_window, ¶meters_to_remove, - &block_indices_inputs, + &block_indices_inputs_for_each_layer, &score_results, - &layer_index](ov::pass::pattern::Matcher& m) { + &layer_index, + &rotation_coefficients_inputs_for_each_layer, + &rotated_block_indices_inputs_for_each_layer](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto real_q = pattern_map.at(q); @@ -374,11 +390,26 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par max_context_len.get_node_shared_ptr()}; pa_arguments.insert(pa_arguments.end(), additional_params.begin(), additional_params.end()); - if (use_block_indices_inputs) { + if (use_per_layer_block_indices_inputs) { auto block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices." + std::to_string(layer_index - 1)); pa_arguments.insert(pa_arguments.begin() + 7, block_indices); - block_indices_inputs.push_back(block_indices); + block_indices_inputs_for_each_layer.push_back(block_indices); + } + + OPENVINO_ASSERT(pa_arguments.size() == 13); + + if (allow_cache_rotation) { + auto rotation_coefficients = setName(std::make_shared(element::f32, PartialShape{-1}), + "rotation_coefficients." + std::to_string(layer_index - 1)); + auto rotated_block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), + "rotated_block_indices." 
+ std::to_string(layer_index - 1)); + + pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients); + pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices); + + rotation_coefficients_inputs_for_each_layer.push_back(rotation_coefficients); + rotated_block_indices_inputs_for_each_layer.push_back(rotated_block_indices); } auto paged_attention = std::make_shared(pa_arguments); @@ -435,4 +466,4 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par auto m = std::make_shared(sdpa_variants, matcher_name); register_matcher(m, callback); -} \ No newline at end of file +} diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp index a0dd403818b462..ce7eb78d079632 100644 --- a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp +++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp @@ -19,12 +19,15 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass { public: OPENVINO_RTTI("SDPAToPagedAttention"); - SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); + SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false, + bool use_score_outputs = false, + bool allow_cache_rotation = false); bool run_on_model(const std::shared_ptr& model) override; private: - bool m_use_block_indices_inputs; + bool m_use_per_layer_block_indices_inputs; bool m_use_score_outputs; + bool m_allow_cache_rotation; }; } // namespace pass } // namespace ov diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp index cdcb66e86ee33e..3c82d86817b51c 100644 --- a/src/core/src/op/paged_attention.cpp +++ b/src/core/src/op/paged_attention.cpp @@ -19,8 +19,8 @@ void PagedAttentionExtension::validate_and_infer_types() { OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types); NODE_VALIDATION_CHECK(this, - get_input_size() == 13, - "PagedAttensionExtension expects 13 inputs, but it has ", + get_input_size() == 13 || get_input_size() == 15, + "PagedAttensionExtension expects 13 or 15 inputs, but it has ", get_input_size()); NODE_VALIDATION_CHECK( @@ -147,6 +147,32 @@ void PagedAttentionExtension::validate_and_infer_types() { get_input_element_type(12), "."); + if (get_input_size() == 15) { + NODE_VALIDATION_CHECK( + this, + get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1, + "Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ", + get_input_partial_shape(13).rank().get_length(), + "."); + NODE_VALIDATION_CHECK(this, + get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::f32, + "Element type of `rotation_coefficients` input should be f32, but it is ", + get_input_element_type(13), + "."); + + NODE_VALIDATION_CHECK( + this, + get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 1, + "Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ", + get_input_partial_shape(14).rank().get_length(), + "."); + NODE_VALIDATION_CHECK(this, + get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32, + "Element type of `rotated_block_indices` input should be i32, but it is ", + get_input_element_type(14), + "."); + } + // value head_size may be not same with key auto out_ps = get_input_partial_shape(0); const auto& key_ps = get_input_partial_shape(1); diff --git 
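When the transformation runs with allow_cache_rotation enabled, each attention layer gains two extra Parameters named after the scheme used by setName above ("rotation_coefficients.<layer>" and "rotated_block_indices.<layer>"). A hedged sketch of how a caller might locate them on the transformed model; only the naming scheme, the element types, and the per-block coefficient count come from this patch, the lookup loop itself is an assumption:

    #include <string>
    #include "openvino/core/model.hpp"

    void list_rotation_inputs(const std::shared_ptr<ov::Model>& model) {
        for (const auto& param : model->get_parameters()) {
            const std::string& name = param->get_friendly_name();
            if (name.rfind("rotated_block_indices.", 0) == 0) {
                // i32, shape [-1]: indices of the cache blocks whose K entries must be rotated
            } else if (name.rfind("rotation_coefficients.", 0) == 0) {
                // f32, shape [-1]: packed cos/sin values, block_size * embedding_size per rotated block
            }
        }
    }
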
a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 872e4539eda8df..aff95feca421e8 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -20,9 +20,12 @@ using namespace ov::op; -ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inputs, bool use_score_outputs) - : m_use_block_indices_inputs(use_block_indices_inputs), - m_use_score_outputs(use_score_outputs) {} +ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation) + : m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs), + m_use_score_outputs(use_score_outputs), + m_allow_cache_rotation(allow_cache_rotation) {} static std::shared_ptr setName(std::shared_ptr node, const char* name) { // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a @@ -46,7 +49,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(element::i32, PartialShape{-1}), "subsequence_begins"), setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices_begins"), }; - if (!m_use_block_indices_inputs) { + if (!m_use_per_layer_block_indices_inputs) { auto block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices"); model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices); } @@ -94,7 +97,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr position_ids; @@ -123,10 +128,13 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptroutput(0), - block_indices_inputs, + block_indices_inputs_for_each_layer, score_results, - m_use_block_indices_inputs, - m_use_score_outputs); + m_use_per_layer_block_indices_inputs, + m_use_score_outputs, + m_allow_cache_rotation, + rotation_coefficients_inputs_for_each_layer, + rotated_block_indices_inputs_for_each_layer); manager.register_pass(prev_max_seq_len, batch_dim); manager.register_pass(max_context_len); manager.register_pass(unsqueezed_position_ids->output(0)); @@ -174,14 +182,19 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrremove_parameter(parameter); } - if (m_use_block_indices_inputs) { - model->add_parameters(block_indices_inputs); + if (m_use_per_layer_block_indices_inputs) { + model->add_parameters(block_indices_inputs_for_each_layer); } if (m_use_score_outputs) { model->add_results(score_results); } + if (m_allow_cache_rotation) { + model->add_parameters(rotation_coefficients_inputs_for_each_layer); + model->add_parameters(rotated_block_indices_inputs_for_each_layer); + } + model->add_parameters(kv_parameters); model->add_parameters(model_remaining_params); model->add_parameters({std::move(max_context_len)}); diff --git a/src/core/tests/type_prop/paged_attention.cpp b/src/core/tests/type_prop/paged_attention.cpp new file mode 100644 index 00000000000000..b1114b71ad8c8c --- /dev/null +++ b/src/core/tests/type_prop/paged_attention.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/paged_attention.hpp" + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "common_test_utils/type_prop.hpp" +#include "openvino/openvino.hpp" +#include "openvino/opsets/opset13.hpp" + +using namespace ov; +using namespace testing; + +TEST(type_prop, paged_attention_static_13_inputs) { + const auto 
query = std::make_shared(element::f32, Shape{3, 4}); + const auto key = std::make_shared(element::f32, Shape{3, 4}); + const auto value = std::make_shared(element::f32, Shape{3, 4}); + const auto key_cache = std::make_shared(element::f32, Shape{6, 2, 5, 4}); + const auto value_cache = std::make_shared(element::f32, Shape{6, 2, 5, 4}); + const auto past_lens = std::make_shared(element::i32, Shape{5}); + const auto subsequence_begins = std::make_shared(element::i32, Shape{5}); + const auto block_indices = std::make_shared(element::i32, Shape{15}); + const auto block_indices_begins = std::make_shared(element::i32, Shape{8}); + const auto scale = std::make_shared(element::f32, Shape{}); + const auto sliding_window = std::make_shared(element::i32, Shape{}); + const auto alibi_slopes = std::make_shared(element::f32, Shape{9}); + const auto max_context_len = std::make_shared(element::i32, Shape{}); + + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len}; + const auto op = std::make_shared(args); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (Shape{3, 4})); +} + +TEST(type_prop, paged_attention_static_15_inputs) { + const auto query = std::make_shared(element::f32, Shape{3, 4}); + const auto key = std::make_shared(element::f32, Shape{3, 4}); + const auto value = std::make_shared(element::f32, Shape{3, 4}); + const auto key_cache = std::make_shared(element::f32, Shape{6, 2, 5, 4}); + const auto value_cache = std::make_shared(element::f32, Shape{6, 2, 5, 4}); + const auto past_lens = std::make_shared(element::i32, Shape{5}); + const auto subsequence_begins = std::make_shared(element::i32, Shape{5}); + const auto block_indices = std::make_shared(element::i32, Shape{15}); + const auto block_indices_begins = std::make_shared(element::i32, Shape{8}); + const auto scale = std::make_shared(element::f32, Shape{}); + const auto sliding_window = std::make_shared(element::i32, Shape{}); + const auto alibi_slopes = std::make_shared(element::f32, Shape{9}); + const auto max_context_len = std::make_shared(element::i32, Shape{}); + + const auto rotation_coefficients = std::make_shared(element::f32, Shape{12}); + const auto rotated_block_indices = std::make_shared(element::i32, Shape{3}); + + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotation_coefficients, + rotated_block_indices}; + + const auto op = std::make_shared(args); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4})); +} diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 4e013a004d29f9..fbe02036b0a268 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -9,6 +9,7 @@ #include #include "openvino/core/type/element_type.hpp" #include "utils/plain_tensor.hpp" +#include "common.hpp" namespace ov { namespace Extensions { @@ -53,4 +54,4 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } // namespace XARCH } // namespace Cpu } // namespace 
Extensions -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp new file mode 100644 index 00000000000000..28f7d9c9ec0ca0 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp @@ -0,0 +1,234 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "common.hpp" +#include "openvino/openvino.hpp" + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + +#if defined(HAVE_AVX512F) +template +inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + bool is_underutilizing) { + using namespace ov::Extensions::Cpu::XARCH; + + auto result_x = _mm512_setzero_ps(); + auto result_y = _mm512_setzero_ps(); + + auto coeffts_cos = _mm512_undefined_ps(); + auto coeffts_sin = _mm512_undefined_ps(); + + auto cache_values_x = _mm512_undefined_ps(); + auto cache_values_y = _mm512_undefined_ps(); + + if (!is_underutilizing) { + coeffts_cos = mm512_uni_loadu_ps(current_rotation_coeffts_cos_ptr); + coeffts_sin = mm512_uni_loadu_ps(current_rotation_coeffts_sin_ptr); + + cache_values_x = mm512_uni_loadu_ps(current_x_values_ptr); + cache_values_y = mm512_uni_loadu_ps(current_y_values_ptr); + } else { + coeffts_cos = mm512_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); + coeffts_sin = mm512_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); + + cache_values_x = mm512_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); + cache_values_y = mm512_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); + } + + result_x = _mm512_fmadd_ps(cache_values_x, coeffts_cos, result_x); + result_x = _mm512_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + + result_y = _mm512_fmadd_ps(cache_values_x, coeffts_sin, result_y); + result_y = _mm512_fmadd_ps(cache_values_y, coeffts_cos, result_y); + + if (!is_underutilizing) { + mm512_uni_storeu_ps(current_x_values_ptr, result_x); + mm512_uni_storeu_ps(current_y_values_ptr, result_y); + } else { + mm512_uni_storeu_tail_ps(current_x_values_ptr, result_x, num_vectorized_elements_per_iteration); + mm512_uni_storeu_tail_ps(current_y_values_ptr, result_y, num_vectorized_elements_per_iteration); + } +} +#endif + +#if defined(HAVE_AVX2) +template +inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + size_t is_underutilizing) { + using namespace ov::Extensions::Cpu::XARCH; + + auto result_x = _mm256_setzero_ps(); + auto result_y = _mm256_setzero_ps(); + + auto coeffts_cos = _mm256_undefined_ps(); + auto coeffts_sin = _mm256_undefined_ps(); + + auto cache_values_x = _mm256_undefined_ps(); + auto cache_values_y = _mm256_undefined_ps(); + + if (!is_underutilizing) { + coeffts_cos = mm256_uni_loadu_ps(current_rotation_coeffts_cos_ptr); + coeffts_sin = mm256_uni_loadu_ps(current_rotation_coeffts_sin_ptr); + + cache_values_x = mm256_uni_loadu_ps(current_x_values_ptr); + cache_values_y = 
mm256_uni_loadu_ps(current_y_values_ptr); + } else { + coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); + coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); + + cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); + cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); + } + + result_x = _mm256_fmadd_ps(cache_values_x, coeffts_cos, result_x); + result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + + result_y = _mm256_fmadd_ps(cache_values_x, coeffts_sin, result_y); + result_y = _mm256_fmadd_ps(cache_values_y, coeffts_cos, result_y); + + if (!is_underutilizing) { + mm256_uni_storeu_ps(current_x_values_ptr, result_x); + mm256_uni_storeu_ps(current_y_values_ptr, result_y); + } else { + mm256_uni_storeu_tail_ps(current_x_values_ptr, result_x, num_vectorized_elements_per_iteration); + mm256_uni_storeu_tail_ps(current_y_values_ptr, result_y, num_vectorized_elements_per_iteration); + } +} +#endif + +template +inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { +#if !defined(HAVE_AVX2) && !defined(HAVE_AVX512F) + OPENVINO_THROW("host CPU must support either AVX2 or AVX512 instructions"); +#else + bool is_underutilizing = false; + +# if defined(HAVE_AVX512F) + constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512; +# else // HAVE_AVX2 + constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2; +# endif // defined(HAVE_AVX512F) + + size_t num_processed_elements_per_iteration = + 2 * vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each + // elt is expanded to f32 on load + size_t num_iterations = embedding_size / num_processed_elements_per_iteration; + + if (embedding_size >= num_processed_elements_per_iteration) { + OPENVINO_ASSERT(!(num_processed_elements_per_iteration % vec_len_in_f32_elts)); + } else { + is_underutilizing = true; + OPENVINO_ASSERT(!(embedding_size % 2)); + num_processed_elements_per_iteration = embedding_size; + num_iterations = 1; + } + + CT* current_cache_element_ptr = cache_block_ptr; + + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + // the rotation coefficients are taken to be the same for all heads + float* current_rotation_coeffts_ptr = block_rotation_coefficients_ptr; + for (size_t tok_idx = 0; tok_idx < block_size; + tok_idx++, current_cache_element_ptr += embedding_size, current_rotation_coeffts_ptr += embedding_size) { + CT* current_x_values_ptr = current_cache_element_ptr; + CT* current_y_values_ptr = current_cache_element_ptr + embedding_size / 2; + + float* current_rotation_coeffts_cos_ptr = current_rotation_coeffts_ptr; + float* current_rotation_coeffts_sin_ptr = current_rotation_coeffts_ptr + embedding_size / 2; + + for (size_t iter_idx = 0; iter_idx < num_iterations; iter_idx++, + current_x_values_ptr += vec_len_in_f32_elts, + current_y_values_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_cos_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_sin_ptr += vec_len_in_f32_elts) { +# if defined(HAVE_AVX512F) + rotate_kv_cache_chunk_avx512(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + 
current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_underutilizing); +# else // HAVE_AVX2 + rotate_kv_cache_chunk_avx2(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_underutilizing); +# endif // defined(HAVE_AVX512F) + } + } + } +#endif // !defined(HAVE_AVX512F) && !defined(HAVE_AVX2F) +} + +template +inline static void rotate_kv_cache_block_sw(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++) { + size_t token_offset = embedding_size * tok_idx; + CT* token_embedding_data_start_in_cache = + cache_block_ptr + head_idx * embedding_size * block_size + embedding_size * tok_idx; + float* token_data_start_in_rotation_coefficients = block_rotation_coefficients_ptr + token_offset; + for (size_t embedding_pair_idx = 0; embedding_pair_idx < embedding_size / 2; embedding_pair_idx++) { + // NB: below is the llama-style rotation (x-like values are in the first half of the embedding vector, + // y-like values are in the second half), which is different from the original RoFormer style (x- and y- + // values are interleaved), but still preserves the relative positional encoding property + CT* cache_value_0_ptr = token_embedding_data_start_in_cache + embedding_pair_idx; + CT* cache_value_1_ptr = cache_value_0_ptr + (embedding_size / 2); + + float rotation_value_cos = token_data_start_in_rotation_coefficients[embedding_pair_idx]; + float rotation_value_sin = + token_data_start_in_rotation_coefficients[embedding_pair_idx + (embedding_size / 2)]; + + CT cache_value_0 = *cache_value_0_ptr; + CT cache_value_1 = *cache_value_1_ptr; + + *cache_value_0_ptr = cache_value_0 * rotation_value_cos - cache_value_1 * rotation_value_sin; + *cache_value_1_ptr = cache_value_0 * rotation_value_sin + cache_value_1 * rotation_value_cos; + } + } + } +} + +template +inline static void rotate_kv_cache_block(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { +#if defined(HAVE_AVX512F) || defined(HAVE_AVX2) + rotate_kv_cache_block_hw(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); +#else + rotate_kv_cache_block_sw(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); +#endif // defined(HAVE_AVX512F) || defined(HAVE_AVX2) +} + +template <> +inline void rotate_kv_cache_block(uint8_t* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + OPENVINO_THROW("cache rotation is not implemented for INT8"); +} diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index 2956c8a6a6b5b8..4c363a29c0db02 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -132,20 +132,34 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); } #endif +#if defined(HAVE_AVX2) +#endif + #ifdef HAVE_AVX2 + inline __m128i get_8bit_tail_mask_for_16bit_elts(size_t num_16bit_tail_elts) { + // num_tail_elts may take from 0 to 8 + static int8_t masks[9][16] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}}; + return _mm_loadu_si128(reinterpret_cast<__m128i*>(masks[num_16bit_tail_elts])); + } inline __m256i get_mask(int N7) { - static __m256i mask[] = { - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), - }; - return _mm256_loadu_si256(&mask[N7]); + static int32_t masks[9][8] = {{0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1}}; + return _mm256_loadu_si256(reinterpret_cast<__m256i*>(masks[N7])); } // load addr to __m256 reg @@ -189,7 +203,7 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); _mm256_storeu_ps(a, v); } - inline void mm256_uni_storeu_ps(ov::bfloat16 *addr, __m256 xps) { + inline __m128i __convert_avx2_packed_float_to_packed_ov_bfloat16(__m256 xps) { __m256i xpi32 = _mm256_castps_si256(xps); __m256i nan = _mm256_set1_epi32(0xffff); __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(xps, xps, _CMP_ORD_Q)); @@ -202,6 +216,11 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); x = _mm256_packus_epi32(x, x); x = _mm256_permute4x64_epi64(x, 0xd8); __m128i bf16_o = _mm256_extractf128_si256(x, 0); + return bf16_o; + } + + inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { + __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(xps); _mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o); } @@ -212,10 +231,22 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); // store __m256 to addr inline void mm256_uni_storeu_tail_ps(float *addr, __m256 v, size_t count) { - const auto mask = get_mask(count); + auto mask = get_mask(count); return _mm256_maskstore_ps(addr, mask, v); } + inline void mm256_uni_storeu_tail_ps(ov::float16* addr, __m256 v, size_t count) { + auto mask = get_8bit_tail_mask_for_16bit_elts(count); + __m128i vec_f16 = _mm256_cvtps_ph(v, 0); + return _mm_maskmoveu_si128(vec_f16, mask, reinterpret_cast(addr)); + } + + inline void mm256_uni_storeu_tail_ps(ov::bfloat16* addr, __m256 v, size_t count) { + auto mask = get_8bit_tail_mask_for_16bit_elts(count); + __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(v); + return _mm_maskmoveu_si128(bf16_o, mask, reinterpret_cast(addr)); + } + inline void hsum(__m256& x) { __m256 y; // x: 0 1 2 3 4 5 6 7 y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp 
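For reference, the per-token operation implemented by both the scalar and the AVX kernels in cache_rotation.hpp above reduces to a plain 2D rotation over (first-half, second-half) embedding pairs. A minimal standalone version over one token, assuming the same llama-style layout as rotate_kv_cache_block_sw (x-like values in the first half of the embedding, y-like values in the second half, cos/sin coefficients split the same way):

    #include <cstddef>

    inline void rotate_one_token(float* embedding, const float* coefficients, size_t embedding_size) {
        for (size_t i = 0; i < embedding_size / 2; i++) {
            float x = embedding[i];
            float y = embedding[i + embedding_size / 2];
            float c = coefficients[i];
            float s = coefficients[i + embedding_size / 2];
            embedding[i] = x * c - y * s;                       // x' = x*cos - y*sin
            embedding[i + embedding_size / 2] = x * s + y * c;  // y' = x*sin + y*cos
        }
    }
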
index bef34881ca41bc..8a3f8f8d5dfdba 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -13,19 +13,20 @@ # include #endif -#include "openvino/core/type/bfloat16.hpp" -#include "openvino/core/type/float16.hpp" -#include "openvino/core/parallel.hpp" +#include "attn_memcpy.hpp" +#include "attn_quant.hpp" +#include "attn_quant_kernel.hpp" +#include "cache_rotation.hpp" +#include "common.hpp" #include "executor_pa.hpp" #include "executor_pa_common.hpp" -#include "common.hpp" -#include "attn_quant_kernel.hpp" +#include "nodes/kernels/x64/brgemm_kernel.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" #include "softmax_kernel.hpp" #include "transpose_kernel.hpp" #include "utils/plain_tensor.hpp" -#include "attn_memcpy.hpp" -#include "attn_quant.hpp" -#include "nodes/kernels/x64/brgemm_kernel.hpp" namespace ov { namespace Extensions { @@ -769,6 +770,27 @@ static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_ OPENVINO_THROW("pack_32NxK: should not be called."); } +template +void rotate_kv_cache(PlainTensor& key_cache, + const PlainTensor& rotation_coefficients, + const PlainTensor& rotated_block_indices) { + size_t num_rotated_blocks = rotated_block_indices.size(0); + size_t num_blocks_in_total = key_cache.size(0); + size_t block_size = key_cache.size(2); + int32_t* rotated_block_indices_data = rotated_block_indices.ptr(); + size_t num_heads = key_cache.size(1); // H; + size_t embedding_size = key_cache.size(3); // S; + size_t head_chunk_size = block_size * embedding_size; + + for (size_t i = 0; i < num_rotated_blocks; i++) { + size_t rotated_block_index = *(rotated_block_indices_data + i); + OPENVINO_ASSERT(rotated_block_index < num_blocks_in_total); + float* rotation_coefficient_block_data = rotation_coefficients.ptr() + i * head_chunk_size; + KVCACHE_TYPE* cache_block_ptr = key_cache.ptr(rotated_block_index); + rotate_kv_cache_block(cache_block_ptr, rotation_coefficient_block_data, num_heads, block_size, embedding_size); + } +} + template struct MHAHelper { // initialize once @@ -1137,16 +1159,19 @@ struct MHAHelper { cvt_copy(output_emb.ptr(pq, h * _SV), _output.ptr(ithr, pq, h), _SV); } - // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small batch tokens. - // It will assume NO mixture execution of first and second token. - // all tensors such as query... have batch dimension which is DIFFERENT from above + // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small + // batch tokens. It will assume NO mixture execution of first and second token. all tensors such as query... 
have + // batch dimension which is DIFFERENT from above // query: [B, H, L, S] - // present_*: [block_number, H, 32, S] + // key_cache: [block_number, H, _block_size, S] + // value_cache: [block_number, H, _block_size, Sv] // output_emb: [B, L, H * S] + // rotated_block_indices: [num_rotated_blocks] + // rotation_coefficients: [num_rotated_blocks, _block_size, S] // 3 loops along batch, head, kv cache length dimensions void exec_loop_bhl(const PlainTensor& query, - const PlainTensor& present_key, - const PlainTensor& present_value, + PlainTensor& key_cache, + PlainTensor& value_cache, const PlainTensor& output_emb, const PlainTensor& output_score, size_t max_context_len, @@ -1154,7 +1179,9 @@ struct MHAHelper { const PlainTensor& subsequence_begins, const PlainTensor& block_indices, const PlainTensor& block_indices_begins, - const PlainTensor& alibi_slopes) { + const PlainTensor& alibi_slopes, + const PlainTensor& rotation_coefficients, + const PlainTensor& rotated_block_indices) { auto B = past_lens.size(0); auto q_len = query.size(2); auto kv_len_in_blocks = div_up(max_context_len, _block_size); @@ -1162,6 +1189,10 @@ struct MHAHelper { // aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing _weight_bhl.resize({B, _H, q_len, rnd_up(max_context_len, std::max(_block_size, size_t{16}))}); + if (rotation_coefficients) { + rotate_kv_cache(key_cache, rotation_coefficients, rotated_block_indices); + } + parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pk_in_blocks, size_t hk) { auto context_len = static_cast(past_lens.ptr()[b]) + 1; // kv_len must be valid @@ -1172,16 +1203,20 @@ struct MHAHelper { _gemv->tile_config(); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { - (*_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk); + (*_gemv)(query.ptr(b, h, pq), + key_cache.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk); } } _gemv->tile_release(); } else { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { - dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk)); + dot_product_block(query.ptr(b, h, pq), + key_cache.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk, + _S, + std::min(_block_size, context_len - pk)); } } } @@ -1236,7 +1271,7 @@ struct MHAHelper { // kv_len must be valid if (pv < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; - auto* v = present_value.ptr(block_number, hk); + auto* v = value_cache.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), @@ -1362,7 +1397,7 @@ struct MHA { // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, - const PlainTensor& k_cache, + PlainTensor& k_cache, const PlainTensor& v_cache, const PlainTensor& output_emb, const PlainTensor& output_score, @@ -1371,7 +1406,9 @@ struct MHA { const PlainTensor& subsequence_begins, const PlainTensor& block_indices, const PlainTensor& block_indices_begins, - const PlainTensor& alibi_slopes) { + const PlainTensor& alibi_slopes, + const PlainTensor& rotation_coefficients, + const PlainTensor& rotated_block_indices) { auto Hk = v_cache.m_dims[1]; constexpr 
bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); @@ -1379,6 +1416,10 @@ struct MHA { auto attn_work_count = _workitems.attn_work_size(); auto reorder_work_count = _workitems.reorder_work_size(); + if (rotation_coefficients) { + rotate_kv_cache(k_cache, rotation_coefficients, rotated_block_indices); + } + // buffer for transpose and repack _helper.init_reorder_buffers(_workitems.get_reorder_max_batch_size(), div_up(_workitems.get_reorder_max_kv_len(), _helper._block_size)); @@ -1493,7 +1534,9 @@ struct MHA { const PlainTensor& subsequence_begins, const PlainTensor& block_indices, const PlainTensor& block_indices_begins, - const PlainTensor& alibi_slopes) { + const PlainTensor& alibi_slopes, + const PlainTensor& rotation_coefficients, + const PlainTensor& rotated_block_indices) { _workitems.reset(query, past_lens, subsequence_begins, _helper._block_size); if (output_score) _helper.init_score_buffers(past_lens, subsequence_begins); @@ -1501,11 +1544,33 @@ struct MHA { auto nthr = static_cast(parallel_get_max_threads()); if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) { - exec_loop_mixed(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, - block_indices, block_indices_begins, alibi_slopes); + exec_loop_mixed(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } else { - _helper.exec_loop_bhl(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, - block_indices, block_indices_begins, alibi_slopes); + _helper.exec_loop_bhl(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } } }; @@ -1518,9 +1583,25 @@ struct AttentionExecutor : public PagedAttentionExecutor { AttentionExecutor() : _kernel(_helper) {} - void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache, - PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins, - float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, PlainTensor& output_emb, PlainTensor& output_score) { + void init(const std::vector& inputs, + const std::vector& outputs, + PlainTensor& q, + PlainTensor& k, + PlainTensor& v, + PlainTensor& k_cache, + PlainTensor& v_cache, + PlainTensor& past_lens, + PlainTensor& subsequence_begins, + PlainTensor& block_indices, + PlainTensor& block_indices_begins, + float& scale, + size_t& sliding_window, + PlainTensor& alibi_slopes, + size_t& max_context_len, + PlainTensor& rotation_coefficients, + PlainTensor& rotated_block_indices, + PlainTensor& output_emb, + PlainTensor& output_score) { q.reset(inputs[ID_Q]); // [B_token, H * S] k.reset(inputs[ID_K]); v.reset(inputs[ID_V]); @@ -1535,6 +1616,16 @@ struct AttentionExecutor : public PagedAttentionExecutor { if (!inputs[ID_ALIBI_SLOPES]->getShape().hasZeroDims()) alibi_slopes.reset(inputs[ID_ALIBI_SLOPES]); max_context_len = static_cast(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs()); + + size_t inputs_size = inputs.size(); + if (inputs_size > ID_ROTATION_COEFFICIENTS) { + 
OPENVINO_ASSERT(inputs_size >= ID_ROTATED_BLOCK_INDICES); + if (!inputs[ID_ROTATION_COEFFICIENTS]->getShape().hasZeroDims()) + rotation_coefficients.reset(inputs[ID_ROTATION_COEFFICIENTS]); + if (!inputs[ID_ROTATED_BLOCK_INDICES]->getShape().hasZeroDims()) + rotated_block_indices.reset(inputs[ID_ROTATED_BLOCK_INDICES]); + } + output_emb.reset(outputs[0]); if (outputs.size() == 2) output_score.reset(outputs[1]); @@ -1576,6 +1667,12 @@ struct AttentionExecutor : public PagedAttentionExecutor { if (alibi_slopes) { alibi_slopes.assert_dims({H}); } + + if (rotated_block_indices) { + // Only K entries are needed to be rotated, since position is encoded at the Q^T @ (effective_RoPE_matrix) @ + // K matrix multiplication + rotation_coefficients.assert_dims({S * rotated_block_indices.size(0) * block_size}); + } output_emb.assert_dims({B_token, H * SV}); output_emb = output_emb.reshape({B_token, 1, H * SV}); @@ -1617,15 +1714,45 @@ struct AttentionExecutor : public PagedAttentionExecutor { size_t sliding_window; PlainTensor alibi_slopes; size_t max_context_len; + PlainTensor rotation_coefficients; + PlainTensor rotated_block_indices; PlainTensor output_emb; PlainTensor output_score; - init(inputs, outputs, q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, - scale, sliding_window, alibi_slopes, max_context_len, output_emb, output_score); + init(inputs, + outputs, + q, + k, + v, + k_cache, + v_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotation_coefficients, + rotated_block_indices, + output_emb, + output_score); concat_pastkv(k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins); - _kernel(q, k_cache, v_cache, output_emb, output_score, max_context_len, past_lens, subsequence_begins, block_indices, - block_indices_begins, alibi_slopes); + _kernel(q, + k_cache, + v_cache, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } }; #endif @@ -1677,4 +1804,4 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ } // namespace XARCH } // namespace Cpu } // namespace Extensions -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp index bc21457a3285b4..b2724a2bf79569 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp @@ -33,6 +33,8 @@ struct PagedAttentionExecutor { static const size_t ID_SLIDING_WINDOW = 10; // [] static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float static const size_t ID_MAX_CONTEXT_LEN = 12; // [] + static const size_t ID_ROTATION_COEFFICIENTS = 13; // [num_rotated_blocks * block_size || 0], float + static const size_t ID_ROTATED_BLOCK_INDICES = 14; // [num_rotated_blocks], float virtual void execute(const std::vector& inputs, const std::vector outputs) = 0; virtual ~PagedAttentionExecutor() = default; }; diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index b9666388490f74..97be1eea79632c 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -80,7 
+80,8 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { config.inConfs[PagedAttentionExecutor::ID_V].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getInputShapeAtPort(PagedAttentionExecutor::ID_V))); - OPENVINO_ASSERT(orgInputNumber == 13, "The input number of PagedAttention should be 13."); + OPENVINO_ASSERT(orgInputNumber == 15 || orgInputNumber == 13, + "The input number of PagedAttention should be 13 or 15."); // kvcache, float, [] auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -112,6 +113,19 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { config.inConfs[PagedAttentionExecutor::ID_MAX_CONTEXT_LEN].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_MAX_CONTEXT_LEN))); + if (orgInputNumber == 15) { + // rotation_coefficients, float, [num_rotated_blocks * block_size || 0] + config.inConfs[PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::f32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS))); + // rotated_block_indices, int, [num_rotated_blocks || 0] + config.inConfs[PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::i32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES))); + } + config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getOutputShapeAtPort(0))); config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( diff --git a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt index 63441b504735b0..81645f4fc87553 100644 --- a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 # +add_subdirectory(vectorized) + set(TARGET_NAME ov_cpu_unit_tests) if(BUILD_SHARED_LIBS) @@ -52,6 +54,8 @@ ov_add_test_target( $/include EXCLUDED_SOURCE_PATHS ${EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST} + ${CMAKE_CURRENT_SOURCE_DIR}/vectorized + OBJECT_FILES ${OBJ_LIB} LINK_LIBRARIES @@ -78,6 +82,7 @@ if (ENABLE_SNIPPETS_LIBXSMM_TPP) target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $) endif() + # LTO set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO}) diff --git a/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt new file mode 100644 index 00000000000000..f6cd449dbd2f72 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET_NAME ov_cpu_unit_tests_vectorized) + +if(BUILD_SHARED_LIBS) + set (OBJ_LIB $) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + ov_add_compiler_flags(/wd5051) +endif() + +if(NOT X86_64) + list(APPEND EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/paged_attn_cache_rotation.cpp) +endif() + + +ov_add_test_target( + NAME ${TARGET_NAME} + ROOT ${CMAKE_CURRENT_SOURCE_DIR} + INCLUDES + PUBLIC + $/src + $/src/nodes + $ + PRIVATE + $/include + EXCLUDED_SOURCE_PATHS + ${EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST} 
+ OBJECT_FILES + ${OBJ_LIB} + LINK_LIBRARIES + gtest + gtest_main + dnnl + gmock + openvino_runtime_s + unit_test_utils + snippets_test_utils + ADD_CPPLINT + LABELS + OV UNIT CPU +) + + +if (X86_64) + ov_avx2_optimization_flags(avx2_flags) + ov_avx512_optimization_flags(avx512_flags) + message("VSHAMPOR: added optimization flags") + + target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags};${avx512_flags}") + target_compile_definitions(${TARGET_NAME} PRIVATE HAVE_AVX2 HAVE_AVX512F) +endif() + + +if (WIN32) + # Prevents defining min/max as macros + target_compile_definitions(${TARGET_NAME} PRIVATE NOMINMAX) +endif() + +target_include_directories(${TARGET_NAME} SYSTEM PRIVATE + $) + +target_include_directories(${TARGET_NAME} SYSTEM PRIVATE + $/src/common + $/src/cpu + $/include) diff --git a/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp b/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp new file mode 100644 index 00000000000000..980bf6b0de18d8 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp @@ -0,0 +1,484 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include +#include +#include + +// the includes in the block below are necessary in order for the common.hpp header to be +// instantiated correctly +#include +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif +#include "kernels/scaled_attn/common.hpp" +#include "nodes/kernels/scaled_attn/cache_rotation.hpp" +#include "perf_count.h" +#include "utils/plain_tensor.hpp" + +using namespace ov::intel_cpu; + +template +using Rank2Matrix = std::vector>; + +template +using Rank3Matrix = std::vector>>; + +// Expected layout: [block_size, embedding_size] +template +std::shared_ptr get_block_memory(size_t block_size, size_t embedding_size, const Rank2Matrix& init_values) { + auto mem = std::shared_ptr(new T[block_size * embedding_size]); + if (!init_values.empty()) { + assert(init_values.size() == block_size); + assert(init_values[0].size() == embedding_size); + for (size_t i = 0; i < block_size; i++) { + for (size_t j = 0; j < embedding_size; j++) { + mem[i * embedding_size + j] = init_values[i][j]; + } + } + } + return mem; +} + +// Expected layout: [num_heads, block_size, embedding_size] +template +std::shared_ptr get_block_memory(size_t num_heads, + size_t block_size, + size_t embedding_size, + const Rank3Matrix& init_values) { + auto mem = std::shared_ptr(new T[num_heads * block_size * embedding_size]); + if (!init_values.empty()) { + assert(init_values.size() == num_heads); + assert(init_values[0].size() == block_size); + assert(init_values[0][0].size() == embedding_size); + for (size_t i = 0; i < num_heads; i++) { + for (size_t j = 0; j < block_size; j++) { + for (size_t k = 0; k < embedding_size; k++) { + mem[i * embedding_size * block_size + j * embedding_size + k] = init_values[i][j][k]; + } + } + } + } + return mem; +} + +template +Rank3Matrix get_matrix_from_mem(std::shared_ptr mem_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + Rank3Matrix retval(num_heads); + for (size_t i = 0; i < num_heads; i++) { + retval[i].resize(block_size); + for (size_t j = 0; j < block_size; j++) { + retval[i][j].resize(embedding_size); + } + } + for (size_t i = 0; i < num_heads; i++) { + for (size_t j = 0; j < block_size; j++) { + for (size_t k = 0; k < embedding_size; k++) { + retval[i][j][k] = mem_ptr[block_size * embedding_size * i + 
embedding_size * j + k]; + } + } + } + return retval; +} + +template +void compare_with_tolerance(const Rank3Matrix& test_data, const Rank3Matrix& ref_data, T abs_err) { + ASSERT_EQ(test_data.size(), ref_data.size()); + ASSERT_GT(test_data.size(), 0); + + ASSERT_EQ(test_data[0].size(), ref_data[0].size()); + ASSERT_GT(test_data[0].size(), 0); + + ASSERT_EQ(test_data[0][0].size(), ref_data[0][0].size()); + ASSERT_GT(test_data[0][0].size(), 0); + + for (size_t i = 0; i < test_data.size(); i++) { + for (size_t j = 0; j < test_data[0].size(); j++) { + for (size_t k = 0; k < test_data[0][0].size(); k++) { + T diff = test_data[i][j][k] - ref_data[i][j][k]; + if ((diff > abs_err) || (diff < -abs_err)) { + ADD_FAILURE() << std::setprecision(8) << "diff " << diff << " exceeding atol " << abs_err + << " at idx [" << i << ";" << j << ";" << k << "] --- test " << test_data[i][j][k] + << ", ref " << ref_data[i][j][k]; + } + } + } + } +} + +template +static T get_tolerance() { + return T{}; +} + +template <> +float get_tolerance() { + return 1e-6; +} + +template <> +ov::float16 get_tolerance() { + return ov::float16{5e-3}; +} + +template <> +ov::bfloat16 get_tolerance() { + return ov::bfloat16{4e-2}; +} + +template +class CacheRotationKernelTest : public ::testing::Test { +public: + void SetUp() override { + Rank3Matrix values_before_rotation = { + { + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, + }, + { + {-2.0, -2.0, -2.0, -2.0}, + {2.0, 2.0, 2.0, 2.0}, + {-1.0, 2.0, -3.0, 4.0}, + {2.0, 2.0, 2.0, 2.0}, + }, + }; + cache_mem_ptr = get_block_memory(num_heads, block_size, embedding_size, values_before_rotation); + + Rank2Matrix rotation_values = { + {0.5, 0.70710678, 0.86602540, -0.70710678}, + {0.86602540, 1.0, 0.5, 0.0}, + {-0.70710678, 0.0, 0.70710678, 1.0}, + {0.0, 0.6, -1.0, -0.8}, + }; + + rotation_coefficients_mem_ptr = get_block_memory(block_size, embedding_size, rotation_values); + } + size_t num_heads = 2; + size_t block_size = 4; + size_t embedding_size = 4; + std::shared_ptr cache_mem_ptr; + std::shared_ptr rotation_coefficients_mem_ptr; + Rank3Matrix ref_values_after_rotation = { + { + {-0.36602540, 1.41421356, 1.36602540, 0.00000000}, + {0.36602540, 1.00000000, 1.36602540, 1.00000000}, + {-1.41421356, -1.00000000, 0.00000000, 1.00000000}, + {1.00000000, 1.40000000, -1.00000000, -0.20000000}, + }, + { + {0.73205081, -2.82842712, -2.73205081, 0.00000000}, + {0.73205081, 2.00000000, 2.73205081, 2.00000000}, + {2.82842712, -4.00000000, 1.41421356, 2.00000000}, + {2.00000000, 2.80000000, -2.00000000, -0.40000000}, + }, + }; + + void test_block_hw_vs_sw(size_t num_heads, size_t embedding_size, size_t block_size) { + auto cache_block_mem_sw = get_block_memory(num_heads, block_size, embedding_size, Rank3Matrix{}); + auto rotation_coeffts_block_mem = get_block_memory(block_size, embedding_size, Rank2Matrix{}); + + std::mt19937 engine; + engine.seed(0); + std::uniform_real_distribution rng(-2.0, 2.0); + + auto raw_mem_ptr_sw = cache_block_mem_sw.get(); + auto raw_rotation_coefficients_mem_ptr = rotation_coeffts_block_mem.get(); + + auto generate_fn = [&]() { + return TypeParam(rng(engine)); + }; + + std::generate(raw_mem_ptr_sw, raw_mem_ptr_sw + num_heads * block_size * embedding_size, generate_fn); + // coeffts are now not strictly sine-cosine pairs, but it does not matter for the kernels + std::generate(raw_rotation_coefficients_mem_ptr, + raw_rotation_coefficients_mem_ptr + block_size * embedding_size, + generate_fn); + + auto cache_block_mem_hw = 
get_block_memory(num_heads, block_size, embedding_size, Rank3Matrix{}); + auto raw_mem_ptr_hw = cache_block_mem_hw.get(); + std::copy(raw_mem_ptr_sw, raw_mem_ptr_sw + num_heads * block_size * embedding_size, raw_mem_ptr_hw); + + ov::intel_cpu::PerfCount counter; + { + ov::intel_cpu::PerfHelper helper(counter); + rotate_kv_cache_block_hw(raw_mem_ptr_hw, + rotation_coeffts_block_mem.get(), + num_heads, + block_size, + embedding_size); + } + + { + ov::intel_cpu::PerfHelper helper(counter); + rotate_kv_cache_block_sw(raw_mem_ptr_sw, + rotation_coeffts_block_mem.get(), + num_heads, + block_size, + embedding_size); + } + + auto sw_values_after_rotation = get_matrix_from_mem(cache_block_mem_sw, num_heads, block_size, embedding_size); + auto hw_values_after_rotation = get_matrix_from_mem(cache_block_mem_hw, num_heads, block_size, embedding_size); + compare_with_tolerance(hw_values_after_rotation, sw_values_after_rotation, get_tolerance()); + } +}; + +using OV_FP_TYPES = ::testing::Types; + +TYPED_TEST_SUITE(CacheRotationKernelTest, OV_FP_TYPES); + +TYPED_TEST(CacheRotationKernelTest, SWBlockRotationGivesReferenceResults) { + auto raw_cache_mem_ptr = this->cache_mem_ptr.get(); + auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get(); + + rotate_kv_cache_block_sw(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); + + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size); + compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance()); +} + +enum class TargetInstructionSet { AVX2, AVX512 }; + +MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { + if (ref_container.size() < n || arg.size() < n) + return false; + if (ref_container.size() != arg.size()) + return false; + + bool is_ok = true; + for (size_t i = 0; i < n; i++) { + if (!::testing::ExplainMatchResult(::testing::FloatNear(static_cast(arg[i]), abs_err), + static_cast(ref_container[i]), + result_listener)) { + *result_listener << " for element at idx " << i << '\n'; + is_ok = false; + } + } + return is_ok; +} + +class CacheRotationHWKernelTest : public ::testing::TestWithParam> { +protected: + constexpr static size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16; + template + using MemChunk = std::array; + + template + void test_chunk_rotation_for_type() { + auto instruction_set = std::get<0>(GetParam()); + auto num_elements_to_process = std::get<1>(GetParam()); + + MemChunk chunk_x = {-0.76777814, + 0.97583583, + -0.23619731, + 0.19022397, + 0.56691264, + 0.64870757, + 0.63334306, + 1.97307894, + 0.72495168, + 1.22328697, + -0.6005607, + 0.17189973, + -0.92268487, + 0.40205632, + 0.85996431, + 1.70078315}; + + MemChunk chunk_y = {1.68812157, + -0.90722836, + 0.58474063, + -0.64561766, + 0.62651501, + 1.55990472, + 0.41571189, + 0.38366555, + 0.09841767, + 0.02218336, + -0.07657361, + 1.6062845, + -1.08282323, + -0.92034808, + -1.48428038, + 0.43501142}; + + MemChunk chunk_cos = {-0.87461971, + 0.95630476, + 0.08715574, + 0.8480481, + -0.9612617, + 0.27563736, + 0.97437006, + 0.66913061, + -0.89100652, + 0.98480775, + -0.7313537, + -0.2419219, + 0.10452846, + 0.70710678, + -0.32556815, + -0.2923717}; + + MemChunk chunk_sin = {-0.48480962, + -0.2923717, + 0.9961947, + 0.52991926, + 0.27563736, + -0.9612617, + -0.22495105, + 0.74314483, + 0.4539905, + -0.17364818, + -0.68199836, + -0.97029573, + -0.9945219, + -0.70710678, + -0.94551858, 
+ 0.95630476}; + + MemChunk ref_chunk_cos = chunk_cos; + MemChunk ref_chunk_sin = chunk_sin; + + MemChunk ref_chunk_x = {1.48993147, + 0.66794854, + -0.60310147, + 0.50344431, + -0.71764235, + 1.6782847, + 0.71062535, + 1.03512844, + -0.69061736, + 1.20855459, + 0.38699921, + 1.51698468, + -1.17333824, + -0.36648762, + -1.68339166, + -0.91326436}; + + MemChunk ref_chunk_y = {-1.10423816, + -1.15289358, + -0.184335, + -0.44671148, + -0.44598258, + -0.19360973, + 0.26258603, + 1.72300577, + 0.24143039, + -0.19057521, + 0.46558381, + -0.55538896, + 0.80444446, + -0.93508112, + -0.32987781, + 1.49928198}; + + // unprocessed elements should remain untouched + std::copy(chunk_x.begin() + num_elements_to_process, + chunk_x.end(), + ref_chunk_x.begin() + num_elements_to_process); + std::copy(chunk_y.begin() + num_elements_to_process, + chunk_y.end(), + ref_chunk_y.begin() + num_elements_to_process); + + switch (instruction_set) { + using namespace ov::Extensions::Cpu::XARCH; + case TargetInstructionSet::AVX2: + rotate_kv_cache_chunk_avx2(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2); + break; + case TargetInstructionSet::AVX512: + rotate_kv_cache_chunk_avx512(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512); + break; + default: + FAIL() << "unknown target instruction set"; + } + + std::string type_name = ov::element::from().to_string(); + + EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; + EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; + + EXPECT_EQ(chunk_cos, ref_chunk_cos) << ", element type is: " << type_name; + EXPECT_EQ(chunk_sin, ref_chunk_sin) << ", element type is: " << type_name; + } +}; + +TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) { + test_chunk_rotation_for_type(); + test_chunk_rotation_for_type(); + test_chunk_rotation_for_type(); +} + +auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo& info) { + size_t num_elts = std::get<1>(info.param); + switch (std::get<0>(info.param)) { + case TargetInstructionSet::AVX2: + return std::string("avx2-") + std::to_string(num_elts); + case TargetInstructionSet::AVX512: + return std::string("avx512-") + std::to_string(num_elts); + } + return std::string("unknown"); +}; + +INSTANTIATE_TEST_SUITE_P(AVX2, + CacheRotationHWKernelTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), + ::testing::Range(size_t(0), + ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), + TEST_STRUCT_TO_NAME_FN); +INSTANTIATE_TEST_SUITE_P(AVX512, + CacheRotationHWKernelTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), + ::testing::Range(size_t(0), + ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), + TEST_STRUCT_TO_NAME_FN); + +TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) { + auto raw_cache_mem_ptr = this->cache_mem_ptr.get(); + auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get(); + + rotate_kv_cache_block_hw(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); + + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem_ptr, 
this->num_heads, this->block_size, this->embedding_size);
+    compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance());
+}
+
+TYPED_TEST(CacheRotationKernelTest, HWBlockRotationIsSimilarToSW) {
+    // short case
+    this->test_block_hw_vs_sw(/* num_heads = */ 4, /* embedding_size = */ 64, /* block_size = */ 2);
+
+    // long case
+    this->test_block_hw_vs_sw(256, 1024, 32);
+}
diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
index f87f608597a6bb..c9743f6a4d8ffe 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
@@ -21,7 +21,9 @@ struct paged_attention : public primitive_base<paged_attention> {
     paged_attention(const primitive_id& id,
                     const std::vector<input_info>& inputs)
         : primitive_base(id, inputs) {
-        OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
+        OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15),
+                        "[GPU] Unexpected inputs number for PagedAttention primitive: ",
+                        inputs.size());
     }
 
     bool operator==(const primitive& rhs) const override {
@@ -34,6 +36,7 @@ struct paged_attention : public primitive_base<paged_attention> {
         ob << heads_num;
         ob << kv_heads_num;
         ob << has_alibi;
+        ob << has_rotation_coefficients;
     }
 
     void load(BinaryInputBuffer& ib) override {
@@ -42,6 +45,7 @@ struct paged_attention : public primitive_base<paged_attention> {
         ib >> heads_num;
         ib >> kv_heads_num;
         ib >> has_alibi;
+        ib >> has_rotation_coefficients;
     }
 
     optional_value<float> scale_val{};
@@ -49,5 +53,6 @@ struct paged_attention : public primitive_base<paged_attention> {
     size_t heads_num = 0;
     size_t kv_heads_num = 0;
     bool has_alibi = false;
+    bool has_rotation_coefficients = false;
 };
 }  // namespace cldnn
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
index 9cf1a252564934..95291a520ef7a8 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -155,6 +155,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             if (desc->has_alibi) {
                 args.inputs.push_back(instance.alibi_memory_ptr());
             }
+
+            if (desc->has_rotation_coefficients) {
+                args.inputs.push_back(instance.rotation_coefficients_memory_ptr());
+                args.inputs.push_back(instance.rotated_block_indices_memory_ptr());
+            }
         } else {
             args.inputs = { instance.past_lens_memory_ptr() };
diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
index a7918ba9c3719c..af653102f28da9 100644
--- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
+++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
@@ -62,6 +62,13 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<paged_attention>

prefill_network; protected: diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 787fd184f75b6a..c91adb61721832 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -74,6 +74,7 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) { paged_attention_info.add("kv_heads_num", desc->kv_heads_num); paged_attention_info.add("scale", desc->scale_val.value_or(1.0f)); paged_attention_info.add("has_alibi", desc->has_alibi); + paged_attention_info.add("has_rotation_coefficients", desc->has_rotation_coefficients); node_info->add("paged_attention primitive info", paged_attention_info); node_info->dump(primitive_description); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl index 00c43829d02ea7..b38959de0220ef 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl @@ -42,6 +42,10 @@ KERNEL(pa_sdpa_opt)( #endif #if HAS_ALIBI const __global ALIBI_INPUT_TYPE* alibi_slopes, +#endif +#if HAS_ROTATION_COEFFICIENTS + const __global INPUT8_TYPE* rotation_coefficients, + const __global INPUT9_TYPE* rotated_block_indices, #endif __global OUTPUT_TYPE* output, __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, @@ -58,7 +62,9 @@ KERNEL(pa_sdpa_opt)( // past_lens: [sequences_num] // subsequence_begins: [sequences_num + 1] // block_indices: [used_blocks_num] - // block_indices: [sequences_num + 1] + // block_indices_begins: [sequences_num + 1] + // rotation_coefficients: [num_rotated_blocks * PAGED_ATTENTION_BLOCK_SIZE] + // rotated_block_indices: [num_rotated_blocks ] // // Output shapes: // output: [sequences_num, HEADS_NUM * HEAD_SIZE] @@ -144,6 +150,10 @@ KERNEL(pa_sdpa_opt)( } #endif +#ifdef HAS_ROTATION_COEFFICIENTS + // TODO (vshampor): add cache block rotation at this spot +#endif + const uint blocks_num_per_partition = min(total_blocks_num - partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION, (uint)PAGED_ATTENTION_BLOCKS_PER_PARTITION); uint blocks_num = blocks_num_per_partition / SUBGROUPS_PER_WG; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 55f87e4189d9fe..e41d42401adae9 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -172,7 +172,7 @@ KERNEL(sdpa_opt)( #endif #if SUBGROUPS_PER_WG > SUBGROUP_SIZE - #error "sdpa_opt.cl: Number of subgroups per work group should be less than subgroup_size + #error "sdpa_opt.cl: Number of subgroups per work group should be less than subgroup_size" #endif const uint sgid = get_sub_group_id(); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp index 63c5e74160f652..0c34943e734d84 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp @@ -190,6 +190,9 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params& jit.Merge(MakeTypeJitConstants(params.inputs[alibi_input_idx].GetDType(), "ALIBI_INPUT")); } + if (params.conf.has_rotation_coefficients_input) + 
jit.AddConstant(MakeJitConstant("HAS_ROTATION_COEFFICIENTS", 1)); + if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS) jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1)); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 493bd0acedea32..61852e5eec33ce 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -99,6 +99,7 @@ struct sdpa_configuration { int64_t paged_attention_block_size = 0; bool has_const_scale_val = false; float scale_val = 0.f; + bool has_rotation_coefficients_input = false; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 7425b096b6d324..a013a8ad315dfc 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -48,6 +48,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared const size_t scale_idx = 9; const size_t alibi_idx = 11; + const size_t rotation_coefficients_idx = 13; std::shared_ptr scale_const = std::dynamic_pointer_cast(op->get_input_node_shared_ptr(scale_idx)); if (scale_const) { @@ -61,6 +62,11 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared OPENVINO_ASSERT(alibi_const != nullptr); prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; + std::shared_ptr rotation_coefficients_const = + std::dynamic_pointer_cast(op->get_input_node_shared_ptr(rotation_coefficients_idx)); + OPENVINO_ASSERT(rotation_coefficients_const != nullptr); + prim.has_rotation_coefficients = ov::shape_size(rotation_coefficients_const->get_output_shape(0)) > 0; + if (op->get_output_size() > 1) { const auto scores_output_idx = 1; const auto& users = op->get_output_target_inputs(scores_output_idx); diff --git a/src/plugins/intel_npu/thirdparty/level-zero-ext b/src/plugins/intel_npu/thirdparty/level-zero-ext index a63155ae4e64fe..a6487cc2c5da9a 160000 --- a/src/plugins/intel_npu/thirdparty/level-zero-ext +++ b/src/plugins/intel_npu/thirdparty/level-zero-ext @@ -1 +1 @@ -Subproject commit a63155ae4e64feaaa6931f4696c2e2e699063875 +Subproject commit a6487cc2c5da9aa13db9e005a320a1b6a0ee5919 diff --git a/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py b/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py index 72051783fa7422..785591aeddde29 100644 --- a/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py +++ b/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py @@ -106,7 +106,7 @@ def main(): # wrapping in try/catch block to continue printing models even if one has failed try: - paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction) + paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction, use_cache_eviction) except: print(f"Couldn't run SDPAToPA transformation on {model_id} and generate diffs.") continue @@ -117,10 +117,12 @@ def main(): after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1 print(f'\t"{model_id}" : {{', file=file) - for op in set(after_map.keys()) | set(before_map.keys()): + for op in 
sorted(set(after_map.keys()) | set(before_map.keys())): print(f'\t\t"{op}" : {after_map.get(op, 0) - before_map.get(op, 0)},', file=file) print('\t},', file=file) print('}', file=file) + print(f"output written to {OUTPUT_FILE}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py b/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py index 43ef49d9b5a226..56eb1181e29f9b 100644 --- a/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py +++ b/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py @@ -5,305 +5,298 @@ ref_diff_map = { "hf-internal-testing/tiny-random-LlamaForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CohereForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTJForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, - }, - "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { - "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, - "Parameter" : 11, - "ReadValue" : -8, - "Assign" : -8, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-MistralForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CodeGenForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/Mixtral-tiny" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM" : { + "Assign" : -5, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -5, - "Assign" : -5, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-Starcoder2ForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-BloomForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-gpt2" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-BlenderbotForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 8, 
"ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PegasusForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 8, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PhiForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-MptForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-StableLmForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PersimmonForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-FalconForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-tiny-model-private/tiny-random-OPTForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "katuni4ka/tiny-random-xverse" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2-13b" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquilachat" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquila2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen1.5-moe" : { + "Assign" : -8, "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, "Parameter" : 11, "ReadValue" : -8, - "Assign" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-codegen2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-olmo-hf" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, 
- "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-jais" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm2" : { + "Assign" : -8, "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, "Parameter" : 11, "ReadValue" : -8, + "ScaledDotProductAttention" : -4, + }, + "katuni4ka/tiny-random-minicpm" : { "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 11, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, - "katuni4ka/tiny-random-minicpm" : { - "ReadValue" : -8, - "ScaledDotProductAttention" : -4, - "Assign" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 11, - }, "katuni4ka/tiny-random-falcon-40b" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-dbrx" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/tiny-random-GemmaForCausalLM" : { + "Assign" : -2, "PagedAttentionExtension" : 1, - "ScaledDotProductAttention" : -1, "Parameter" : 5, "ReadValue" : -2, - "Assign" : -2, + "ScaledDotProductAttention" : -1, }, "fxmarty/tiny-dummy-qwen2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/really-tiny-falcon-testing" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "Xenova/tiny-random-Phi3ForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "facebook/opt-125m" : { + "Assign" : -24, "PagedAttentionExtension" : 12, - "ScaledDotProductAttention" : -12, "Parameter" : 28, "ReadValue" : -24, - "Assign" : -24, + "ScaledDotProductAttention" : -12, }, "facebook/opt-350m" : { + "Assign" : -48, "PagedAttentionExtension" : 24, - "ScaledDotProductAttention" : -24, "Parameter" : 52, "ReadValue" : -48, - "Assign" : -48, + "ScaledDotProductAttention" : -24, }, "katuni4ka/tiny-random-chatglm2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-glm4" : { + "Assign" : -12, "PagedAttentionExtension" : 6, - "ScaledDotProductAttention" : -6, "Parameter" : 15, "ReadValue" : -12, - "Assign" : -12, + "ScaledDotProductAttention" : -6, }, "katuni4ka/tiny-random-llava-next" : { "PagedAttentionExtension" : 2, @@ -338,305 +331,305 @@ ref_diff_map_cache_eviction = { "hf-internal-testing/tiny-random-LlamaForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, 
"hf-internal-testing/tiny-random-CohereForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTJForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, - }, - "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 14, - "Assign" : -8, + "PagedAttentionExtension" : 5, + "Parameter" : 27, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 27, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-MistralForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CodeGenForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 27, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/Mixtral-tiny" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -5, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -5, + "PagedAttentionExtension" : 5, + "Parameter" : 27, + "ReadValue" : -5, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-Starcoder2ForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-BloomForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-gpt2" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 27, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-BlenderbotForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 9, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PegasusForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 
9, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PhiForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-MptForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-StableLmForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PersimmonForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-FalconForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-tiny-model-private/tiny-random-OPTForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "katuni4ka/tiny-random-xverse" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2-13b" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquilachat" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquila2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen1.5-moe" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 14, "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 22, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-codegen2" : { 
- "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-olmo-hf" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-jais" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm2" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 14, "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 22, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-minicpm" : { - "ScaledDotProductAttention" : -4, - "Parameter" : 14, + "Assign" : -8, "PagedAttentionExtension" : 4, + "Parameter" : 22, "ReadValue" : -8, + "ScaledDotProductAttention" : -4, + }, + "katuni4ka/tiny-random-minicpm" : { "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 14, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-falcon-40b" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-dbrx" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/tiny-random-GemmaForCausalLM" : { - "ScaledDotProductAttention" : -1, - "ReadValue" : -2, - "PagedAttentionExtension" : 1, - "Parameter" : 5, "Assign" : -2, + "PagedAttentionExtension" : 1, + "Parameter" : 7, + "ReadValue" : -2, + "ScaledDotProductAttention" : -1, }, "fxmarty/tiny-dummy-qwen2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/really-tiny-falcon-testing" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "Xenova/tiny-random-Phi3ForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, 
+ "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "facebook/opt-125m" : { - "ScaledDotProductAttention" : -12, - "ReadValue" : -24, - "PagedAttentionExtension" : 12, - "Parameter" : 39, "Assign" : -24, + "PagedAttentionExtension" : 12, + "Parameter" : 63, + "ReadValue" : -24, + "ScaledDotProductAttention" : -12, }, "facebook/opt-350m" : { - "ScaledDotProductAttention" : -24, - "ReadValue" : -48, - "PagedAttentionExtension" : 24, - "Parameter" : 75, "Assign" : -48, + "PagedAttentionExtension" : 24, + "Parameter" : 123, + "ReadValue" : -48, + "ScaledDotProductAttention" : -24, }, "katuni4ka/tiny-random-chatglm2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 12, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-glm4" : { - "ScaledDotProductAttention" : -6, - "ReadValue" : -12, - "PagedAttentionExtension" : 6, - "Parameter" : 20, "Assign" : -12, + "PagedAttentionExtension" : 6, + "Parameter" : 32, + "ReadValue" : -12, + "ScaledDotProductAttention" : -6, }, "katuni4ka/tiny-random-llava-next" : { "Parameter" : 8, diff --git a/tests/model_hub_tests/transformation_tests/test_pa_transformation.py b/tests/model_hub_tests/transformation_tests/test_pa_transformation.py index 2bc6726dff030f..797820422094f9 100644 --- a/tests/model_hub_tests/transformation_tests/test_pa_transformation.py +++ b/tests/model_hub_tests/transformation_tests/test_pa_transformation.py @@ -17,13 +17,14 @@ def compare_diffs(ov_model: ov.Model, model_id: str, use_block_indices_inputs: bool, - use_score_outputs: bool): + use_score_outputs: bool, + allow_cache_rotation: bool): before_map = {} for op in ov_model.get_ordered_ops(): if op.get_type_name() in nodes_to_compare: before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1 - paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs) + paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs, allow_cache_rotation) after_map = {} for op in ov_model.get_ordered_ops(): @@ -51,32 +52,43 @@ def compare_diffs(ov_model: ov.Model, assert shape[-1].is_static, f"Dimension {len(shape) - 1} of input '{name}' in '{model_id}' is not static: {shape}" assert shape[-2].is_static, f"Dimension {len(shape) - 2} of input '{name}' in '{model_id}' is not static: {shape}" + interesting_input_patterns = {} + interesting_output_patterns = {} + + # Test for block_indices inputs and scores outputs to appear in the model if (use_block_indices_inputs): - block_indices_pattern = r'block_indices\.[0-9]+' - block_indices_counter = 0 - - model_inputs = ov_model.inputs - for input in model_inputs: - for name in list(input.get_names()): - if re.search(block_indices_pattern, name): - block_indices_counter += 1 - - assert block_indices_counter == resulting_map["PagedAttentionExtension"], \ - f"The number of block_indices inputs doesn't correspond to the expected value. 
Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}"
-
+        interesting_input_patterns["block_indices"] = r'block_indices\.[0-9]+'
+
     if (use_score_outputs):
-        score_pattern = r'scores\.[0-9]+'
-        score_outputs_counter = 0
+        interesting_output_patterns["scores"] = r'scores\.[0-9]+'
+
+    if (allow_cache_rotation):
+        interesting_input_patterns["rotation_coefficients"] = r'rotation_coefficients\.[0-9]+'
+        interesting_input_patterns["rotated_block_indices"] = r'rotated_block_indices\.[0-9]+'
+
+    input_counters = {k: 0 for k in interesting_input_patterns}
+    output_counters = {k: 0 for k in interesting_output_patterns}
+
+    for pattern_dict, counter_dict, io_set in zip([interesting_input_patterns, interesting_output_patterns],
+                                                  [input_counters, output_counters],
+                                                  [ov_model.inputs, ov_model.outputs]):
+        for input_id in counter_dict:
+            pattern = pattern_dict[input_id]
+            for model_io in io_set:
+                for name in list(model_io.get_names()):
+                    if re.search(pattern, name):
+                        counter_dict[input_id] += 1
+
+    for input_id, count in input_counters.items():
+        assert count == resulting_map["PagedAttentionExtension"], \
+            f"The number of {input_id} inputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {count}"
 
-        model_outputs = ov_model.outputs
-        for output in model_outputs:
-            for name in list(output.get_names()):
-                if re.search(score_pattern, name):
-                    score_outputs_counter += 1
+    for output_id, count in output_counters.items():
+        assert count == resulting_map["PagedAttentionExtension"], \
+            f"The number of {output_id} outputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {count}"
 
-        assert block_indices_counter == resulting_map["PagedAttentionExtension"], \
-            f"The number of scores outputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}"
 
 @retry(3, exceptions=(OSError,), delay=1)
 def run_pa(tmp_path,
@@ -99,7 +111,7 @@ def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device)
         pytest.skip(reason)
     elif mark == 'xfail':
         pytest.xfail(reason)
-    run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False)
+    run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False, False)
 
 @pytest.mark.precommit
 @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
@@ -121,7 +133,7 @@ def test_pa_vlm(tmp_path, model_name, model_link, mark, reason, ie_device):
         pytest.skip(reason)
     elif mark == 'xfail':
         pytest.xfail(reason)
-    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False)
+    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False, False)
 
 @pytest.mark.precommit
 @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit")))
@@ -132,4 +144,4 @@ def test_pa_vlm_use_cache_eviction(tmp_path, model_name, model_link, mark, reaso
         pytest.skip(reason)
     elif mark == 'xfail':
         pytest.xfail(reason)
-    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True)
\ No newline at end of file
+    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True, True)