diff --git a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp index 7fe53c0800178c..0e1c2387af43c2 100644 --- a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp +++ b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp @@ -132,9 +132,14 @@ void regmodule_offline_transformations(py::module m) { m_offline_transformations.def( "paged_attention_transformation", - [](std::shared_ptr model, bool use_block_indices_inputs, bool use_score_outputs, bool allow_cache_rotation) { + [](std::shared_ptr model, + bool use_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation) { ov::pass::Manager manager; - manager.register_pass(use_block_indices_inputs, use_score_outputs, allow_cache_rotation); + manager.register_pass(use_block_indices_inputs, + use_score_outputs, + allow_cache_rotation); manager.run_passes(model); }, py::arg("model"), diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp index 2e7327fc8a2f2b..e47ec4731a17d6 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -64,7 +64,7 @@ static node_tuple kv_read_and_concat(ov::Output kv_current) { return node_tuple(kv_past_par, kv_current2, kv_current_reshaped, kv_concat); } -template +template void insert_rotation_inputs_as(OutputVector& pa_arguments, size_t layer_index) { auto rotation_coefficients = setName(std::make_shared(ov::element::f32, ov::PartialShape{-1}), "rotation_coefficients." + std::to_string(layer_index - 1)); @@ -194,8 +194,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par &score_results, &layer_index, &rotation_coefficients_inputs_for_each_layer, - &rotated_block_indices_inputs_for_each_layer - ](ov::pass::pattern::Matcher& m) { + &rotated_block_indices_inputs_for_each_layer](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto real_q = pattern_map.at(q); @@ -400,8 +399,6 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par OPENVINO_ASSERT(pa_arguments.size() == 13); - - if (allow_cache_rotation) { auto rotation_coefficients = setName(std::make_shared(element::f32, PartialShape{-1}), "rotation_coefficients." 
+ std::to_string(layer_index - 1)); diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp index 9a2d43f617a30f..ce7eb78d079632 100644 --- a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp +++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp @@ -19,7 +19,8 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass { public: OPENVINO_RTTI("SDPAToPagedAttention"); - SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false, bool use_score_outputs = false, + SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false, + bool use_score_outputs = false, bool allow_cache_rotation = false); bool run_on_model(const std::shared_ptr& model) override; diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp index e19a8f0cc989db..3c82d86817b51c 100644 --- a/src/core/src/op/paged_attention.cpp +++ b/src/core/src/op/paged_attention.cpp @@ -149,12 +149,11 @@ void PagedAttentionExtension::validate_and_infer_types() { if (get_input_size() == 15) { NODE_VALIDATION_CHECK( - this, - get_input_partial_shape(13).rank().is_dynamic() || - get_input_partial_shape(13).rank().get_length() == 1, - "Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ", - get_input_partial_shape(13).rank().get_length(), - "."); + this, + get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1, + "Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ", + get_input_partial_shape(13).rank().get_length(), + "."); NODE_VALIDATION_CHECK(this, get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::f32, "Element type of `rotation_coefficients` input should be f32, but it is ", @@ -162,12 +161,11 @@ void PagedAttentionExtension::validate_and_infer_types() { "."); NODE_VALIDATION_CHECK( - this, - get_input_partial_shape(14).rank().is_dynamic() || - get_input_partial_shape(14).rank().get_length() == 1, - "Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ", - get_input_partial_shape(14).rank().get_length(), - "."); + this, + get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 1, + "Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ", + get_input_partial_shape(14).rank().get_length(), + "."); NODE_VALIDATION_CHECK(this, get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32, "Element type of `rotated_block_indices` input should be i32, but it is ", @@ -175,7 +173,6 @@ void PagedAttentionExtension::validate_and_infer_types() { "."); } - // value head_size may be not same with key auto out_ps = get_input_partial_shape(0); const auto& key_ps = get_input_partial_shape(1); diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 27a7c11d536dfc..dc01d0ba317ff6 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -20,7 +20,8 @@ using namespace ov::op; -ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, bool use_score_outputs, +ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, + bool use_score_outputs, bool allow_cache_rotation) : m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs), 
m_use_score_outputs(use_score_outputs), diff --git a/src/core/tests/type_prop/paged_attention.cpp b/src/core/tests/type_prop/paged_attention.cpp index ccd273c57132de..b1114b71ad8c8c 100644 --- a/src/core/tests/type_prop/paged_attention.cpp +++ b/src/core/tests/type_prop/paged_attention.cpp @@ -2,11 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/paged_attention.hpp" + #include #include "common_test_utils/test_assertions.hpp" #include "common_test_utils/type_prop.hpp" -#include "openvino/op/paged_attention.hpp" #include "openvino/openvino.hpp" #include "openvino/opsets/opset13.hpp" @@ -28,8 +29,19 @@ TEST(type_prop, paged_attention_static_13_inputs) { const auto alibi_slopes = std::make_shared(element::f32, Shape{9}); const auto max_context_len = std::make_shared(element::i32, Shape{}); - - ov::OutputVector args = {query, key, value, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, scale, sliding_window, alibi_slopes, max_context_len}; + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len}; const auto op = std::make_shared(args); EXPECT_EQ(op->get_output_element_type(0), element::f32); EXPECT_EQ(op->get_output_partial_shape(0), (Shape{3, 4})); @@ -53,10 +65,23 @@ TEST(type_prop, paged_attention_static_15_inputs) { const auto rotation_coefficients = std::make_shared(element::f32, Shape{12}); const auto rotated_block_indices = std::make_shared(element::i32, Shape{3}); - ov::OutputVector args = {query, key, value, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, scale, sliding_window, alibi_slopes, max_context_len, rotation_coefficients, rotated_block_indices}; + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotation_coefficients, + rotated_block_indices}; const auto op = std::make_shared(args); EXPECT_EQ(op->get_output_element_type(0), element::f32); EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4})); } - diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp index 3fe19ea15af05b..28f7d9c9ec0ca0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp @@ -3,16 +3,21 @@ // #pragma once -#include "openvino/openvino.hpp" #include "common.hpp" +#include "openvino/openvino.hpp" #if defined(HAVE_AVX2) || defined(HAVE_AVX512F) # include #endif #if defined(HAVE_AVX512F) -template -inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, size_t num_vectorized_elements_per_iteration, bool is_underutilizing) { +template +inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + bool is_underutilizing) { using namespace ov::Extensions::Cpu::XARCH; auto result_x = _mm512_setzero_ps(); @@ -30,8 +35,7 @@ inline static void rotate_kv_cache_chunk_avx512(CT* 
current_x_values_ptr, CT* cu cache_values_x = mm512_uni_loadu_ps(current_x_values_ptr); cache_values_y = mm512_uni_loadu_ps(current_y_values_ptr); - } - else { + } else { coeffts_cos = mm512_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); coeffts_sin = mm512_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); @@ -40,7 +44,7 @@ inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, CT* cu } result_x = _mm512_fmadd_ps(cache_values_x, coeffts_cos, result_x); - result_x = _mm512_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + result_x = _mm512_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add result_y = _mm512_fmadd_ps(cache_values_x, coeffts_sin, result_y); result_y = _mm512_fmadd_ps(cache_values_y, coeffts_cos, result_y); @@ -56,8 +60,13 @@ inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, CT* cu #endif #if defined(HAVE_AVX2) -template -inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, size_t num_vectorized_elements_per_iteration, size_t is_underutilizing) { +template +inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + size_t is_underutilizing) { using namespace ov::Extensions::Cpu::XARCH; auto result_x = _mm256_setzero_ps(); @@ -75,8 +84,7 @@ inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* curr cache_values_x = mm256_uni_loadu_ps(current_x_values_ptr); cache_values_y = mm256_uni_loadu_ps(current_y_values_ptr); - } - else { + } else { coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); @@ -85,7 +93,7 @@ inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* curr } result_x = _mm256_fmadd_ps(cache_values_x, coeffts_cos, result_x); - result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add result_y = _mm256_fmadd_ps(cache_values_x, coeffts_sin, result_y); result_y = _mm256_fmadd_ps(cache_values_y, coeffts_cos, result_y); @@ -100,26 +108,31 @@ inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* curr } #endif -template -inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_rotation_coefficients_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { +template +inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { #if !defined(HAVE_AVX2) && !defined(HAVE_AVX512F) OPENVINO_THROW("host CPU must support either AVX2 or AVX512 instructions"); #else bool is_underutilizing = false; -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512; -#else // HAVE_AVX2 +# else // HAVE_AVX2 constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2; -#endif // defined(HAVE_AVX512F) +# endif // 
defined(HAVE_AVX512F) - size_t num_processed_elements_per_iteration = 2 * vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load + size_t num_processed_elements_per_iteration = + 2 * vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each + // elt is expanded to f32 on load size_t num_iterations = embedding_size / num_processed_elements_per_iteration; if (embedding_size >= num_processed_elements_per_iteration) { OPENVINO_ASSERT(!(num_processed_elements_per_iteration % vec_len_in_f32_elts)); - } - else { + } else { is_underutilizing = true; OPENVINO_ASSERT(!(embedding_size % 2)); num_processed_elements_per_iteration = embedding_size; @@ -131,9 +144,8 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { // the rotation coefficients are taken to be the same for all heads float* current_rotation_coeffts_ptr = block_rotation_coefficients_ptr; - for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++, - current_cache_element_ptr += embedding_size, - current_rotation_coeffts_ptr += embedding_size) { + for (size_t tok_idx = 0; tok_idx < block_size; + tok_idx++, current_cache_element_ptr += embedding_size, current_rotation_coeffts_ptr += embedding_size) { CT* current_x_values_ptr = current_cache_element_ptr; CT* current_y_values_ptr = current_cache_element_ptr + embedding_size / 2; @@ -141,37 +153,53 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro float* current_rotation_coeffts_sin_ptr = current_rotation_coeffts_ptr + embedding_size / 2; for (size_t iter_idx = 0; iter_idx < num_iterations; iter_idx++, - current_x_values_ptr += vec_len_in_f32_elts, - current_y_values_ptr += vec_len_in_f32_elts, - current_rotation_coeffts_cos_ptr += vec_len_in_f32_elts, - current_rotation_coeffts_sin_ptr += vec_len_in_f32_elts) { -#if defined(HAVE_AVX512F) - rotate_kv_cache_chunk_avx512(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration / 2, is_underutilizing); -#else // HAVE_AVX2 - rotate_kv_cache_chunk_avx2(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration / 2, is_underutilizing); -#endif // defined(HAVE_AVX512F) + current_x_values_ptr += vec_len_in_f32_elts, + current_y_values_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_cos_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_sin_ptr += vec_len_in_f32_elts) { +# if defined(HAVE_AVX512F) + rotate_kv_cache_chunk_avx512(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_underutilizing); +# else // HAVE_AVX2 + rotate_kv_cache_chunk_avx2(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_underutilizing); +# endif // defined(HAVE_AVX512F) } } } -#endif // !defined(HAVE_AVX512F) && !defined(HAVE_AVX2F) +#endif // !defined(HAVE_AVX512F) && !defined(HAVE_AVX2F) } -template -inline static void rotate_kv_cache_block_sw(CT* cache_block_ptr, float* block_rotation_coefficients_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { +template +inline static void 
rotate_kv_cache_block_sw(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++) { size_t token_offset = embedding_size * tok_idx; - CT* token_embedding_data_start_in_cache = cache_block_ptr + head_idx * embedding_size * block_size + embedding_size * tok_idx; + CT* token_embedding_data_start_in_cache = + cache_block_ptr + head_idx * embedding_size * block_size + embedding_size * tok_idx; float* token_data_start_in_rotation_coefficients = block_rotation_coefficients_ptr + token_offset; for (size_t embedding_pair_idx = 0; embedding_pair_idx < embedding_size / 2; embedding_pair_idx++) { // NB: below is the llama-style rotation (x-like values are in the first half of the embedding vector, - // y-like values are in the second half), which is different from the original RoFormer style (x- and y- values - // are interleaved), but still preserves the relative positional encoding property + // y-like values are in the second half), which is different from the original RoFormer style (x- and y- + // values are interleaved), but still preserves the relative positional encoding property CT* cache_value_0_ptr = token_embedding_data_start_in_cache + embedding_pair_idx; CT* cache_value_1_ptr = cache_value_0_ptr + (embedding_size / 2); float rotation_value_cos = token_data_start_in_rotation_coefficients[embedding_pair_idx]; - float rotation_value_sin = token_data_start_in_rotation_coefficients[embedding_pair_idx + (embedding_size / 2)]; + float rotation_value_sin = + token_data_start_in_rotation_coefficients[embedding_pair_idx + (embedding_size / 2)]; CT cache_value_0 = *cache_value_0_ptr; CT cache_value_1 = *cache_value_1_ptr; @@ -183,16 +211,24 @@ inline static void rotate_kv_cache_block_sw(CT* cache_block_ptr, float* block_ro } } -template -inline static void rotate_kv_cache_block(CT* cache_block_ptr, float* block_rotation_coefficients_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { +template +inline static void rotate_kv_cache_block(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { #if defined(HAVE_AVX512F) || defined(HAVE_AVX2) rotate_kv_cache_block_hw(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); #else rotate_kv_cache_block_sw(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); -#endif // defined(HAVE_AVX512F) || defined(HAVE_AVX2) +#endif // defined(HAVE_AVX512F) || defined(HAVE_AVX2) } -template<> -inline void rotate_kv_cache_block(uint8_t* cache_block_ptr, float* block_rotation_coefficients_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { +template <> +inline void rotate_kv_cache_block(uint8_t* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { OPENVINO_THROW("cache rotation is not implemented for INT8"); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index d0607b9f952611..4c363a29c0db02 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -138,31 +138,27 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); #ifdef 
HAVE_AVX2 inline __m128i get_8bit_tail_mask_for_16bit_elts(size_t num_16bit_tail_elts) { // num_tail_elts may take from 0 to 8 - static int8_t masks[9][16] = { - { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 } - }; + static int8_t masks[9][16] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}}; return _mm_loadu_si128(reinterpret_cast<__m128i*>(masks[num_16bit_tail_elts])); } inline __m256i get_mask(int N7) { - static int32_t masks[9][8] = { - { 0, 0, 0, 0, 0, 0, 0, 0 }, - { -1, 0, 0, 0, 0, 0, 0, 0 }, - { -1, -1, 0, 0, 0, 0, 0, 0 }, - { -1, -1, -1, 0, 0, 0, 0, 0 }, - { -1, -1, -1, -1, 0, 0, 0, 0 }, - { -1, -1, -1, -1, -1, 0, 0, 0 }, - { -1, -1, -1, -1, -1, -1, 0, 0 }, - { -1, -1, -1, -1, -1, -1, -1, 0 }, - { -1, -1, -1, -1, -1, -1, -1, -1 } - }; + static int32_t masks[9][8] = {{0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1}}; return _mm256_loadu_si256(reinterpret_cast<__m256i*>(masks[N7])); } @@ -223,7 +219,7 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); return bf16_o; } - inline void mm256_uni_storeu_ps(ov::bfloat16 *addr, __m256 xps) { + inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(xps); _mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o); } @@ -239,19 +235,18 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); return _mm256_maskstore_ps(addr, mask, v); } - inline void mm256_uni_storeu_tail_ps(ov::float16 *addr, __m256 v, size_t count) { + inline void mm256_uni_storeu_tail_ps(ov::float16* addr, __m256 v, size_t count) { auto mask = get_8bit_tail_mask_for_16bit_elts(count); __m128i vec_f16 = _mm256_cvtps_ph(v, 0); return _mm_maskmoveu_si128(vec_f16, mask, reinterpret_cast(addr)); } - inline void mm256_uni_storeu_tail_ps(ov::bfloat16 *addr, __m256 v, size_t count) { + inline void mm256_uni_storeu_tail_ps(ov::bfloat16* addr, __m256 v, size_t count) { auto mask = get_8bit_tail_mask_for_16bit_elts(count); __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(v); return _mm_maskmoveu_si128(bf16_o, mask, reinterpret_cast(addr)); } - inline void hsum(__m256& x) { __m256 y; // x: 0 1 2 3 4 5 6 7 y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 diff --git 
a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 5ee351333c9ff9..8a3f8f8d5dfdba 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -13,20 +13,20 @@ # include #endif -#include "openvino/core/type/bfloat16.hpp" -#include "openvino/core/type/float16.hpp" -#include "openvino/core/parallel.hpp" +#include "attn_memcpy.hpp" +#include "attn_quant.hpp" +#include "attn_quant_kernel.hpp" +#include "cache_rotation.hpp" +#include "common.hpp" #include "executor_pa.hpp" #include "executor_pa_common.hpp" -#include "common.hpp" -#include "attn_quant_kernel.hpp" +#include "nodes/kernels/x64/brgemm_kernel.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" #include "softmax_kernel.hpp" #include "transpose_kernel.hpp" #include "utils/plain_tensor.hpp" -#include "attn_memcpy.hpp" -#include "attn_quant.hpp" -#include "nodes/kernels/x64/brgemm_kernel.hpp" -#include "cache_rotation.hpp" namespace ov { namespace Extensions { @@ -770,14 +770,16 @@ static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_ OPENVINO_THROW("pack_32NxK: should not be called."); } -template -void rotate_kv_cache(PlainTensor& key_cache, const PlainTensor& rotation_coefficients, const PlainTensor& rotated_block_indices) { +template +void rotate_kv_cache(PlainTensor& key_cache, + const PlainTensor& rotation_coefficients, + const PlainTensor& rotated_block_indices) { size_t num_rotated_blocks = rotated_block_indices.size(0); size_t num_blocks_in_total = key_cache.size(0); size_t block_size = key_cache.size(2); int32_t* rotated_block_indices_data = rotated_block_indices.ptr(); - size_t num_heads = key_cache.size(1); // H; - size_t embedding_size = key_cache.size(3); // S; + size_t num_heads = key_cache.size(1); // H; + size_t embedding_size = key_cache.size(3); // S; size_t head_chunk_size = block_size * embedding_size; for (size_t i = 0; i < num_rotated_blocks; i++) { @@ -1157,11 +1159,9 @@ struct MHAHelper { cvt_copy(output_emb.ptr(pq, h * _SV), _output.ptr(ithr, pq, h), _SV); } - - - // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small batch tokens. - // It will assume NO mixture execution of first and second token. - // all tensors such as query... have batch dimension which is DIFFERENT from above + // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small + // batch tokens. It will assume NO mixture execution of first and second token. all tensors such as query... 
have + // batch dimension which is DIFFERENT from above // query: [B, H, L, S] // key_cache: [block_number, H, _block_size, S] // value_cache: [block_number, H, _block_size, Sv] @@ -1203,16 +1203,20 @@ struct MHAHelper { _gemv->tile_config(); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { - (*_gemv)(query.ptr(b, h, pq), key_cache.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk); + (*_gemv)(query.ptr(b, h, pq), + key_cache.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk); } } _gemv->tile_release(); } else { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { - dot_product_block(query.ptr(b, h, pq), key_cache.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk)); + dot_product_block(query.ptr(b, h, pq), + key_cache.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk, + _S, + std::min(_block_size, context_len - pk)); } } } @@ -1289,7 +1293,6 @@ struct MHAHelper { } }; - template struct MHA { MHAHelper& _helper; @@ -1392,7 +1395,6 @@ struct MHA { MHA(MHAHelper& helper) : _helper(helper) {} - // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, PlainTensor& k_cache, @@ -1542,11 +1544,33 @@ struct MHA { auto nthr = static_cast(parallel_get_max_threads()); if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) { - exec_loop_mixed(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, - block_indices, block_indices_begins, alibi_slopes, rotation_coefficients, rotated_block_indices); + exec_loop_mixed(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } else { - _helper.exec_loop_bhl(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, - block_indices, block_indices_begins, alibi_slopes, rotation_coefficients, rotated_block_indices); + _helper.exec_loop_bhl(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } } }; @@ -1559,11 +1583,25 @@ struct AttentionExecutor : public PagedAttentionExecutor { AttentionExecutor() : _kernel(_helper) {} - void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache, - PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins, - float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, - PlainTensor& rotation_coefficients, PlainTensor& rotated_block_indices, - PlainTensor& output_emb, PlainTensor& output_score) { + void init(const std::vector& inputs, + const std::vector& outputs, + PlainTensor& q, + PlainTensor& k, + PlainTensor& v, + PlainTensor& k_cache, + PlainTensor& v_cache, + PlainTensor& past_lens, + PlainTensor& subsequence_begins, + PlainTensor& block_indices, + PlainTensor& block_indices_begins, + float& scale, + size_t& sliding_window, + PlainTensor& alibi_slopes, + size_t& max_context_len, + PlainTensor& rotation_coefficients, + PlainTensor& 
rotated_block_indices, + PlainTensor& output_emb, + PlainTensor& output_score) { q.reset(inputs[ID_Q]); // [B_token, H * S] k.reset(inputs[ID_K]); v.reset(inputs[ID_V]); @@ -1631,8 +1669,9 @@ struct AttentionExecutor : public PagedAttentionExecutor { } if (rotated_block_indices) { - // Only K entries are needed to be rotated, since position is encoded at the Q^T @ (effective_RoPE_matrix) @ K matrix multiplication - rotation_coefficients.assert_dims({ S * rotated_block_indices.size(0) * block_size }); + // Only K entries are needed to be rotated, since position is encoded at the Q^T @ (effective_RoPE_matrix) @ + // K matrix multiplication + rotation_coefficients.assert_dims({S * rotated_block_indices.size(0) * block_size}); } output_emb.assert_dims({B_token, H * SV}); output_emb = output_emb.reshape({B_token, 1, H * SV}); @@ -1680,16 +1719,40 @@ struct AttentionExecutor : public PagedAttentionExecutor { PlainTensor output_emb; PlainTensor output_score; - init(inputs, outputs, q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, - scale, sliding_window, alibi_slopes, max_context_len, - rotation_coefficients, rotated_block_indices, - output_emb, output_score); + init(inputs, + outputs, + q, + k, + v, + k_cache, + v_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotation_coefficients, + rotated_block_indices, + output_emb, + output_score); concat_pastkv(k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins); - _kernel(q, k_cache, v_cache, output_emb, output_score, max_context_len, past_lens, subsequence_begins, block_indices, - block_indices_begins, - alibi_slopes, - rotation_coefficients, rotated_block_indices); + _kernel(q, + k_cache, + v_cache, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes, + rotation_coefficients, + rotated_block_indices); } }; #endif diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 1d043cc175faac..97be1eea79632c 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -80,7 +80,8 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { config.inConfs[PagedAttentionExecutor::ID_V].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getInputShapeAtPort(PagedAttentionExecutor::ID_V))); - OPENVINO_ASSERT(orgInputNumber == 15 || orgInputNumber == 13, "The input number of PagedAttention should be 13 or 15."); + OPENVINO_ASSERT(orgInputNumber == 15 || orgInputNumber == 13, + "The input number of PagedAttention should be 13 or 15."); // kvcache, float, [] auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -114,14 +115,17 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { if (orgInputNumber == 15) { // rotation_coefficients, float, [num_rotated_blocks * block_size || 0] - config.inConfs[PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::f32, getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS))); + config.inConfs[PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS].setMemDesc( + 
creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::f32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATION_COEFFICIENTS))); // rotated_block_indices, int, [num_rotated_blocks || 0] - config.inConfs[PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES))); + config.inConfs[PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::i32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES))); } - config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getOutputShapeAtPort(0))); config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( diff --git a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp index eb9c76ad12e201..c800ea675d276c 100644 --- a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp +++ b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp @@ -2,13 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 // - -#include -#include #include +#include #include + #include #include +#include // the includes in the block below are necessary in order for the common.hpp header to be // instantiated correctly @@ -17,21 +17,20 @@ # include #endif #include "kernels/scaled_attn/common.hpp" - -#include "utils/plain_tensor.hpp" #include "nodes/kernels/scaled_attn/cache_rotation.hpp" #include "perf_count.h" +#include "utils/plain_tensor.hpp" using namespace ov::intel_cpu; -template +template using Rank2Matrix = std::vector>; -template +template using Rank3Matrix = std::vector>>; // Expected layout: [block_size, embedding_size] -template +template std::shared_ptr get_block_memory(size_t block_size, size_t embedding_size, const Rank2Matrix& init_values) { auto mem = std::shared_ptr(new T[block_size * embedding_size]); if (!init_values.empty()) { @@ -47,8 +46,11 @@ std::shared_ptr get_block_memory(size_t block_size, size_t embedding_size, } // Expected layout: [num_heads, block_size, embedding_size] -template -std::shared_ptr get_block_memory(size_t num_heads, size_t block_size, size_t embedding_size, const Rank3Matrix& init_values) { +template +std::shared_ptr get_block_memory(size_t num_heads, + size_t block_size, + size_t embedding_size, + const Rank3Matrix& init_values) { auto mem = std::shared_ptr(new T[num_heads * block_size * embedding_size]); if (!init_values.empty()) { assert(init_values.size() == num_heads); @@ -65,9 +67,11 @@ std::shared_ptr get_block_memory(size_t num_heads, size_t block_size, size_ return mem; } - -template -Rank3Matrix get_matrix_from_mem(std::shared_ptr mem_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { +template +Rank3Matrix get_matrix_from_mem(std::shared_ptr mem_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { Rank3Matrix retval(num_heads); for (size_t i = 0; i < num_heads; i++) { retval[i].resize(block_size); @@ -85,7 +89,7 @@ Rank3Matrix get_matrix_from_mem(std::shared_ptr mem_ptr, size_t num_head return retval; } -template +template void compare_with_tolerance(const Rank3Matrix& test_data, const Rank3Matrix& ref_data, T abs_err) { ASSERT_EQ(test_data.size(), ref_data.size()); ASSERT_GT(test_data.size(), 0); @@ -101,54 +105,61 @@ void compare_with_tolerance(const Rank3Matrix& test_data, const Rank3Matrix 
abs_err) || (diff < -abs_err)) { - ADD_FAILURE() << std::setprecision(8) << "diff " << diff << " exceeding atol " << abs_err << " at idx [" << i << ";" << j << ";" << k << "] --- test " << test_data[i][j][k] << ", ref " << ref_data[i][j][k]; + ADD_FAILURE() << std::setprecision(8) << "diff " << diff << " exceeding atol " << abs_err + << " at idx [" << i << ";" << j << ";" << k << "] --- test " << test_data[i][j][k] + << ", ref " << ref_data[i][j][k]; } } } } } - template static T get_tolerance() { return T{}; } template <> -float get_tolerance() { return 1e-6; }; +float get_tolerance() { + return 1e-6; +}; template <> -ov::float16 get_tolerance() { return ov::float16{5e-3}; }; +ov::float16 get_tolerance() { + return ov::float16{5e-3}; +}; template <> -ov::bfloat16 get_tolerance() { return ov::bfloat16{4e-2}; }; +ov::bfloat16 get_tolerance() { + return ov::bfloat16{4e-2}; +}; -template +template class CacheRotationKernelTest : public ::testing::Test { - public: +public: void SetUp() override { Rank3Matrix values_before_rotation = { { - { 1.0, 1.0, 1.0, 1.0 }, - { 1.0, 1.0, 1.0, 1.0 }, - { 1.0, 1.0, 1.0, 1.0 }, - { 1.0, 1.0, 1.0, 1.0 }, + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0}, }, { - { -2.0, -2.0, -2.0, -2.0 }, - { 2.0, 2.0, 2.0, 2.0 }, - { -1.0, 2.0, -3.0, 4.0 }, - { 2.0, 2.0, 2.0, 2.0 }, + {-2.0, -2.0, -2.0, -2.0}, + {2.0, 2.0, 2.0, 2.0}, + {-1.0, 2.0, -3.0, 4.0}, + {2.0, 2.0, 2.0, 2.0}, }, }; cache_mem_ptr = get_block_memory(num_heads, block_size, embedding_size, values_before_rotation); - Rank2Matrix rotation_values = { - { 0.5, 0.70710678, 0.86602540, -0.70710678 }, - { 0.86602540, 1.0, 0.5, 0.0 }, - { -0.70710678, 0.0, 0.70710678, 1.0}, - { 0.0, 0.6, -1.0, -0.8 }, - }; + Rank2Matrix rotation_values = { + {0.5, 0.70710678, 0.86602540, -0.70710678}, + {0.86602540, 1.0, 0.5, 0.0}, + {-0.70710678, 0.0, 0.70710678, 1.0}, + {0.0, 0.6, -1.0, -0.8}, + }; rotation_coefficients_mem_ptr = get_block_memory(block_size, embedding_size, rotation_values); } @@ -159,16 +170,16 @@ class CacheRotationKernelTest : public ::testing::Test { std::shared_ptr rotation_coefficients_mem_ptr; Rank3Matrix ref_values_after_rotation = { { - { -0.36602540, 1.41421356, 1.36602540, 0.00000000 }, - { 0.36602540, 1.00000000, 1.36602540, 1.00000000 }, - { -1.41421356, -1.00000000, 0.00000000, 1.00000000 }, - { 1.00000000, 1.40000000, -1.00000000, -0.20000000 }, + {-0.36602540, 1.41421356, 1.36602540, 0.00000000}, + {0.36602540, 1.00000000, 1.36602540, 1.00000000}, + {-1.41421356, -1.00000000, 0.00000000, 1.00000000}, + {1.00000000, 1.40000000, -1.00000000, -0.20000000}, }, { - { 0.73205081, -2.82842712, -2.73205081, 0.00000000 }, - { 0.73205081, 2.00000000, 2.73205081, 2.00000000 }, - { 2.82842712, -4.00000000, 1.41421356, 2.00000000 }, - { 2.00000000, 2.80000000, -2.00000000, -0.40000000 }, + {0.73205081, -2.82842712, -2.73205081, 0.00000000}, + {0.73205081, 2.00000000, 2.73205081, 2.00000000}, + {2.82842712, -4.00000000, 1.41421356, 2.00000000}, + {2.00000000, 2.80000000, -2.00000000, -0.40000000}, }, }; @@ -180,15 +191,18 @@ class CacheRotationKernelTest : public ::testing::Test { engine.seed(0); std::uniform_real_distribution rng(-2.0, 2.0); - auto raw_mem_ptr_sw = cache_block_mem_sw.get(); auto raw_rotation_coefficients_mem_ptr = rotation_coeffts_block_mem.get(); - auto generate_fn = [&]() { return TypeParam(rng(engine)); }; + auto generate_fn = [&]() { + return TypeParam(rng(engine)); + }; std::generate(raw_mem_ptr_sw, raw_mem_ptr_sw + num_heads * block_size * 
embedding_size, generate_fn); // coeffts are now not strictly sine-cosine pairs, but it does not matter for the kernels - std::generate(raw_rotation_coefficients_mem_ptr, raw_rotation_coefficients_mem_ptr + block_size * embedding_size, generate_fn); + std::generate(raw_rotation_coefficients_mem_ptr, + raw_rotation_coefficients_mem_ptr + block_size * embedding_size, + generate_fn); auto cache_block_mem_hw = get_block_memory(num_heads, block_size, embedding_size, Rank3Matrix{}); auto raw_mem_ptr_hw = cache_block_mem_hw.get(); @@ -197,12 +211,20 @@ class CacheRotationKernelTest : public ::testing::Test { ov::intel_cpu::PerfCount counter; { ov::intel_cpu::PerfHelper helper(counter); - rotate_kv_cache_block_hw(raw_mem_ptr_hw, rotation_coeffts_block_mem.get(), num_heads, block_size, embedding_size); + rotate_kv_cache_block_hw(raw_mem_ptr_hw, + rotation_coeffts_block_mem.get(), + num_heads, + block_size, + embedding_size); } { ov::intel_cpu::PerfHelper helper(counter); - rotate_kv_cache_block_sw(raw_mem_ptr_sw, rotation_coeffts_block_mem.get(), num_heads, block_size, embedding_size); + rotate_kv_cache_block_sw(raw_mem_ptr_sw, + rotation_coeffts_block_mem.get(), + num_heads, + block_size, + embedding_size); } auto sw_values_after_rotation = get_matrix_from_mem(cache_block_mem_sw, num_heads, block_size, embedding_size); @@ -219,26 +241,30 @@ TYPED_TEST(CacheRotationKernelTest, SWBlockRotationGivesReferenceResults) { auto raw_cache_mem_ptr = this->cache_mem_ptr.get(); auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get(); - rotate_kv_cache_block_sw(raw_cache_mem_ptr, raw_rotation_coefficients_mem_ptr, this->num_heads, this->block_size, this->embedding_size); + rotate_kv_cache_block_sw(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); - auto test_values_after_rotation = get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size); + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size); compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance()); } -enum class TargetInstructionSet { - AVX2, - AVX512 -}; +enum class TargetInstructionSet { AVX2, AVX512 }; MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { - if (ref_container.size() < n || arg.size() < n) return false; - if (ref_container.size() != arg.size()) return false; + if (ref_container.size() < n || arg.size() < n) + return false; + if (ref_container.size() != arg.size()) + return false; bool is_ok = true; - for (size_t i = 0; i < n; i++) - { - if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), float(ref_container[i]), result_listener)) - { + for (size_t i = 0; i < n; i++) { + if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), + float(ref_container[i]), + result_listener)) { *result_listener << " for element at idx " << i << '\n'; is_ok = false; } @@ -246,67 +272,164 @@ MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { return is_ok; } -class CacheRotationHWKernelTest: public ::testing::TestWithParam> { +class CacheRotationHWKernelTest : public ::testing::TestWithParam> { protected: constexpr static size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16; - template + template using MemChunk = std::array; - template + template void test_chunk_rotation_for_type() { auto instruction_set = std::get<0>(GetParam()); auto 
num_elements_to_process = std::get<1>(GetParam()); - - - MemChunk chunk_x = { -0.76777814, 0.97583583, -0.23619731, 0.19022397, 0.56691264, 0.64870757, 0.63334306, 1.97307894, - 0.72495168, 1.22328697, -0.6005607, 0.17189973, -0.92268487, 0.40205632, 0.85996431, 1.70078315}; - - MemChunk chunk_y = { 1.68812157, -0.90722836, 0.58474063, -0.64561766, 0.62651501, 1.55990472, 0.41571189, 0.38366555, - 0.09841767, 0.02218336, -0.07657361, 1.6062845, -1.08282323, -0.92034808, -1.48428038, 0.43501142}; - - MemChunk chunk_cos = { -0.87461971, 0.95630476, 0.08715574, 0.8480481, -0.9612617, 0.27563736, 0.97437006, 0.66913061, - -0.89100652, 0.98480775, -0.7313537, -0.2419219, 0.10452846, 0.70710678, -0.32556815, -0.2923717 }; - - MemChunk chunk_sin = { -0.48480962, -0.2923717, 0.9961947, 0.52991926, 0.27563736, -0.9612617, -0.22495105, 0.74314483, - 0.4539905, -0.17364818, -0.68199836, -0.97029573, -0.9945219, -0.70710678, -0.94551858, 0.95630476 }; + MemChunk chunk_x = {-0.76777814, + 0.97583583, + -0.23619731, + 0.19022397, + 0.56691264, + 0.64870757, + 0.63334306, + 1.97307894, + 0.72495168, + 1.22328697, + -0.6005607, + 0.17189973, + -0.92268487, + 0.40205632, + 0.85996431, + 1.70078315}; + + MemChunk chunk_y = {1.68812157, + -0.90722836, + 0.58474063, + -0.64561766, + 0.62651501, + 1.55990472, + 0.41571189, + 0.38366555, + 0.09841767, + 0.02218336, + -0.07657361, + 1.6062845, + -1.08282323, + -0.92034808, + -1.48428038, + 0.43501142}; + + MemChunk chunk_cos = {-0.87461971, + 0.95630476, + 0.08715574, + 0.8480481, + -0.9612617, + 0.27563736, + 0.97437006, + 0.66913061, + -0.89100652, + 0.98480775, + -0.7313537, + -0.2419219, + 0.10452846, + 0.70710678, + -0.32556815, + -0.2923717}; + + MemChunk chunk_sin = {-0.48480962, + -0.2923717, + 0.9961947, + 0.52991926, + 0.27563736, + -0.9612617, + -0.22495105, + 0.74314483, + 0.4539905, + -0.17364818, + -0.68199836, + -0.97029573, + -0.9945219, + -0.70710678, + -0.94551858, + 0.95630476}; MemChunk ref_chunk_cos = chunk_cos; MemChunk ref_chunk_sin = chunk_sin; - MemChunk ref_chunk_x = { 1.48993147, 0.66794854, -0.60310147, 0.50344431, -0.71764235, 1.6782847, 0.71062535, 1.03512844, - -0.69061736, 1.20855459, 0.38699921, 1.51698468, -1.17333824, -0.36648762, -1.68339166, -0.91326436 }; - - MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577, - 0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 }; + MemChunk ref_chunk_x = {1.48993147, + 0.66794854, + -0.60310147, + 0.50344431, + -0.71764235, + 1.6782847, + 0.71062535, + 1.03512844, + -0.69061736, + 1.20855459, + 0.38699921, + 1.51698468, + -1.17333824, + -0.36648762, + -1.68339166, + -0.91326436}; + + MemChunk ref_chunk_y = {-1.10423816, + -1.15289358, + -0.184335, + -0.44671148, + -0.44598258, + -0.19360973, + 0.26258603, + 1.72300577, + 0.24143039, + -0.19057521, + 0.46558381, + -0.55538896, + 0.80444446, + -0.93508112, + -0.32987781, + 1.49928198}; // unprocessed elements should remain untouched - std::copy(chunk_x.begin() + num_elements_to_process, chunk_x.end(), ref_chunk_x.begin() + num_elements_to_process); - std::copy(chunk_y.begin() + num_elements_to_process, chunk_y.end(), ref_chunk_y.begin() + num_elements_to_process); - - switch(instruction_set) { + std::copy(chunk_x.begin() + num_elements_to_process, + chunk_x.end(), + ref_chunk_x.begin() + num_elements_to_process); + std::copy(chunk_y.begin() + num_elements_to_process, + chunk_y.end(), + ref_chunk_y.begin() + 
num_elements_to_process); + + switch (instruction_set) { using namespace ov::Extensions::Cpu::XARCH; - case TargetInstructionSet::AVX2: - rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2); - break; - case TargetInstructionSet::AVX512: - rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512); - break; - default: - FAIL() << "unknown target instruction set"; + case TargetInstructionSet::AVX2: + rotate_kv_cache_chunk_avx2(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2); + break; + case TargetInstructionSet::AVX512: + rotate_kv_cache_chunk_avx512(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512); + break; + default: + FAIL() << "unknown target instruction set"; } std::string type_name = ov::element::from().to_string(); - EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, get_tolerance(), num_elements_to_process)) << ", element type is: " << type_name; - EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, get_tolerance(), num_elements_to_process)) << ", element type is: " << type_name; + EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; + EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; EXPECT_EQ(chunk_cos, ref_chunk_cos) << ", element type is: " << type_name; EXPECT_EQ(chunk_sin, ref_chunk_sin) << ", element type is: " << type_name; } }; - TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) { test_chunk_rotation_for_type(); test_chunk_rotation_for_type(); @@ -314,31 +437,44 @@ TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) { } auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo& info) { - size_t num_elts = std::get<1>(info.param); - switch(std::get<0>(info.param)) { - case TargetInstructionSet::AVX2: - return std::string("avx2-") + std::to_string(num_elts); - case TargetInstructionSet::AVX512: - return std::string("avx512-") + std::to_string(num_elts); - } - return std::string("unknown"); + size_t num_elts = std::get<1>(info.param); + switch (std::get<0>(info.param)) { + case TargetInstructionSet::AVX2: + return std::string("avx2-") + std::to_string(num_elts); + case TargetInstructionSet::AVX512: + return std::string("avx512-") + std::to_string(num_elts); + } + return std::string("unknown"); }; -INSTANTIATE_TEST_SUITE_P(AVX2, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), TEST_STRUCT_TO_NAME_FN); -INSTANTIATE_TEST_SUITE_P(AVX512, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), TEST_STRUCT_TO_NAME_FN); +INSTANTIATE_TEST_SUITE_P(AVX2, + CacheRotationHWKernelTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), + ::testing::Range(size_t(0), + 
ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), + TEST_STRUCT_TO_NAME_FN); +INSTANTIATE_TEST_SUITE_P(AVX512, + CacheRotationHWKernelTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), + ::testing::Range(size_t(0), + ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), + TEST_STRUCT_TO_NAME_FN); TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) { auto raw_cache_mem_ptr = this->cache_mem_ptr.get(); auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get(); - rotate_kv_cache_block_hw(raw_cache_mem_ptr, raw_rotation_coefficients_mem_ptr, this->num_heads, this->block_size, this->embedding_size); + rotate_kv_cache_block_hw(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); - auto test_values_after_rotation = get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size); + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size); compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance()); } - - TYPED_TEST(CacheRotationKernelTest, HWBlockRotationIsSimilarToSW) { // short case this->test_block_hw_vs_sw(/* num_heads = */ 4, /* embedding_size = */ 64, /* block_size = */ 2); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index a1efa456aec9c0..c9743f6a4d8ffe 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -21,7 +21,9 @@ struct paged_attention : public primitive_base { paged_attention(const primitive_id& id, const std::vector& inputs) : primitive_base(id, inputs) { - OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15), "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size()); + OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15), + "[GPU] Unexpected inputs number for PagedAttention primitive: ", + inputs.size()); } bool operator==(const primitive& rhs) const override { diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index 5d6b538bfb1994..af653102f28da9 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -62,9 +62,12 @@ class typed_primitive_inst : public typed_primitive_inst_base

prefill_network;
diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
index 5bb7c7dd38bd74..a013a8ad315dfc 100644
--- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -62,7 +62,8 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     OPENVINO_ASSERT(alibi_const != nullptr);
     prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;
 
-    std::shared_ptr<ov::op::v0::Constant> rotation_coefficients_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotation_coefficients_idx));
+    std::shared_ptr<ov::op::v0::Constant> rotation_coefficients_const =
+        std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotation_coefficients_idx));
     OPENVINO_ASSERT(rotation_coefficients_const != nullptr);
     prim.has_rotation_coefficients = ov::shape_size(rotation_coefficients_const->get_output_shape(0)) > 0;
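
Note: as a reading aid for the CPU kernel changes above, the transformation implemented by rotate_kv_cache_block_sw (and vectorized by the AVX2/AVX512 chunk kernels) reduces to the scalar form below. This is a minimal, self-contained sketch of the llama-style half-split rotation applied to one cached token embedding; the function and variable names here are illustrative only and are not part of the patch.

#include <cstddef>
#include <vector>

// Rotate one cached K embedding in place. The first half of `token` holds the x-like
// values and the second half the y-like values (llama-style layout, not the interleaved
// RoFormer layout); `coeffts` holds cos values in its first half and sin values in its
// second half, matching the per-token layout of the rotation_coefficients input.
inline void rotate_token_embedding(std::vector<float>& token, const std::vector<float>& coeffts) {
    const size_t embedding_size = token.size();
    const size_t half = embedding_size / 2;
    for (size_t i = 0; i < half; i++) {
        const float x = token[i];
        const float y = token[i + half];
        const float c = coeffts[i];
        const float s = coeffts[i + half];
        token[i] = c * x - s * y;         // x' = cos * x - sin * y (the fnmadd step in the AVX kernels)
        token[i + half] = s * x + c * y;  // y' = sin * x + cos * y
    }
}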