-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add cache rotation inputs and CPU kernel implementation for cache rotation #27088
base: master
Are you sure you want to change the base?
Conversation
2a172b2
to
c071571
Compare
pa_arguments.insert(pa_arguments.begin() + 13, v0::Constant::create(element::f32, Shape{0}, {})); | ||
pa_arguments.insert(pa_arguments.begin() + 14, v0::Constant::create(element::i32, Shape{0}, {})); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you make these inputs really optional, these two lines are not required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/core/src/op/paged_attention.cpp
Outdated
get_input_partial_shape(13).rank().is_dynamic() || | ||
get_input_partial_shape(13).rank().get_length() == 0 || | ||
get_input_partial_shape(13).rank().get_length() == 1, | ||
"Input `rotation_coefficients` should either have an empty shape or rank 1, but it has rank ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Input `rotation_coefficients` should either have an empty shape or rank 1, but it has rank ", | |
"Input `rotation_coefficients` should either have rank 1 or omitted, but it has rank ", |
"Empty" shape means [0]
here, which have rank 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/core/src/op/paged_attention.cpp
Outdated
NODE_VALIDATION_CHECK( | ||
this, | ||
get_input_partial_shape(13).rank().is_dynamic() || | ||
get_input_partial_shape(13).rank().get_length() == 0 || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_input_partial_shape(13).rank().get_length() == 0 || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/core/src/op/paged_attention.cpp
Outdated
get_input_partial_shape(14).rank().get_length() == 0 || | ||
get_input_partial_shape(14).rank().get_length() == 1, | ||
"Input `rotated_block_indices` should either have an empty shape or rank 1 but it has rank ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same comment are applicable here as for input 13 above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -1576,6 +1591,11 @@ struct AttentionExecutor : public PagedAttentionExecutor { | |||
if (alibi_slopes) { | |||
alibi_slopes.assert_dims({H}); | |||
} | |||
|
|||
if (rotated_block_indices) { | |||
// Rotation, and cache eviction, is limited to cases when Q, K and V embedding sizes are equal, e.g. S == Sv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have cases where they are not: minicpm-3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed - realized that we don't need that limitation for cache rotation since we only rotate the K values
@@ -58,6 +59,10 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared | |||
OPENVINO_ASSERT(alibi_const != nullptr); | |||
prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; | |||
|
|||
std::shared_ptr<ov::op::v0::Constant> rotation_coefficients_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotation_coefficients_idx)); | |||
OPENVINO_ASSERT(rotation_coefficients_const != nullptr); | |||
prim.has_rotation_coefficients = ov::shape_size(alibi_const->get_output_shape(0)) > 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alibi_const
shouldn't be used here -- bad copy&paste?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed, thanks.
d90e212
to
ed46cfe
Compare
@luo-cheng2021 Please review CPU PA changes. |
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
Show resolved
Hide resolved
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
Show resolved
Hide resolved
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
Show resolved
Hide resolved
CT cache_value_1 = *cache_value_1_ptr; | ||
|
||
*cache_value_0_ptr = cache_value_0 * rotation_value_cos - cache_value_1 * rotation_value_sin; | ||
*cache_value_1_ptr = cache_value_0 * rotation_value_sin + cache_value_1 * rotation_value_cos; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the algorithm same with the following code?
openvino/src/plugins/intel_cpu/src/nodes/rope.cpp
Lines 158 to 161 in c4d6d2b
auto src0 = src[i]; | |
auto src1 = src[i + half_rotary_dims]; | |
dst[i] = cos[i] * src0 - sin[i] * src1; | |
dst[i + half_rotary_dims] = cos[i + half_rotary_dims] * src1 + sin[i + half_rotary_dims] * src0; |
If so, the following code can be used as reference:
openvino/src/plugins/intel_cpu/src/nodes/rope.cpp
Lines 35 to 102 in c4d6d2b
static std::shared_ptr<kernel::JitKernelBase> createJitKernel(const jit_rotary_compile_params& param, bool check_vec_size2 = false) { | |
std::shared_ptr<kernel::JitKernelBase> res; | |
MAYBE_UNUSED(param); | |
MAYBE_UNUSED(check_vec_size2); | |
#if defined(OPENVINO_ARCH_X86_64) | |
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) { | |
bool flag = true; | |
if (check_vec_size2) { | |
auto vec_size = jit_rotary_kernel<dnnl::impl::cpu::x64::avx512_core>::vec_size; | |
if (param.rotary_ndims % (vec_size * 2) != 0) | |
flag = false; | |
} | |
if (flag) | |
res = std::make_shared<jit_rotary_kernel<dnnl::impl::cpu::x64::avx512_core>>(param); | |
} else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) { | |
bool flag = true; | |
if (check_vec_size2) { | |
auto vec_size = jit_rotary_kernel<dnnl::impl::cpu::x64::avx2>::vec_size; | |
if (param.rotary_ndims % (vec_size * 2) != 0) | |
flag = false; | |
} | |
if (flag) | |
res = std::make_shared<jit_rotary_kernel<dnnl::impl::cpu::x64::avx2>>(param); | |
} | |
if (res) | |
res->create_kernel(); | |
#endif // OPENVINO_ARCH_X86_64 | |
return res; | |
} | |
static void execJitKernel(const std::shared_ptr<kernel::JitKernelBase>& ker, const void* src, void* dst, const float* cos, const float* sin) { | |
MAYBE_UNUSED(ker); | |
MAYBE_UNUSED(src); | |
MAYBE_UNUSED(dst); | |
MAYBE_UNUSED(cos); | |
MAYBE_UNUSED(sin); | |
#if defined(OPENVINO_ARCH_X86_64) | |
jit_rotary_call_args call_args; | |
call_args.src = src; | |
call_args.cos = cos; | |
call_args.sin = sin; | |
call_args.dst = dst; | |
(*ker)(&call_args); | |
#endif // OPENVINO_ARCH_X86_64 | |
} | |
template <typename T> | |
struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor { | |
const op::internal::RoPE::Config& m_config; | |
std::shared_ptr<kernel::JitKernelBase> m_rotaryKernel; | |
RoPEExecutorRotateHalf(const op::internal::RoPE::Config& config) : m_config(config) { | |
jit_rotary_compile_params jcp; | |
jcp.src_prc = precision_of<T>::value; | |
jcp.dst_prc = precision_of<T>::value; | |
jcp.rotary_ndims = config.rotary_ndims; | |
jcp.interleave = false; | |
m_rotaryKernel = createJitKernel(jcp); | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have already written and tested my implementation, besides, the code you've sent me probably cannot be reused without modifications or bulky instantiations.
|
||
template<class CT> | ||
inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_rotation_coefficients_ptr, size_t num_heads, size_t block_size, size_t embedding_size) { | ||
#if !defined(HAVE_AVX2) && !defined(HAVE_AVX512F) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be cleaner if rotate_kv_cache_block_hw
and rotate_kv_cache_block_sw
are merged and let the rotate_kv_cache_chunk_xxx
to handle the tails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need HW and SW available as separate functions for testing purposes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test style CPU plugin used is to use layer/subgraph(sdpa sample) test to cover, due to CI infrastructure can cover avx2/avx512/amx, so there is no need to create a new test structure, we'd better to create a subgraph test just like sdpa.
But PagedAttention node is a little special, it's no reference now which means no reference result to compare, I think the reference will be coming.
Base on current status, I suggest the new test module may be removed and file a ticket to record the following up.
Correct me if I'm wrong. @dmitry-gorokhov @slyalin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally don't mind splitting vector and scalar impls into separate functions. I would say it is even better from code readability standpoint.
Regarding unit tests - that is actaully good developer practice. The fact we haven't implemented unit tests for intrinsics optimizations does't mean we should prohibit it at all. That seems to be convinient for developer purposes, so I am happy to have such an infrastructure.
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
Show resolved
Hide resolved
afa851c
to
8a355f3
Compare
8a355f3
to
a33f255
Compare
138de47
to
a0818c2
Compare
...mon/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
Show resolved
Hide resolved
7777b4e
to
86d281f
Compare
template <class T> | ||
void test_chunk_rotation_for_type() { | ||
auto instruction_set = std::get<0>(GetParam()); | ||
if (instruction_set == TargetInstructionSet::AVX512 && (!ov::with_cpu_x86_avx512f())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need the same condition for avx2. We still officially support CPUs with SSE isa only (even though we don't have them in pre-commit) - need to keep the tests green there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -143,20 +143,34 @@ inline void mm512_uni_storeu_tail_ps(ov::float16* addr, __m512 v, size_t count) | |||
} | |||
#endif | |||
|
|||
#if defined(HAVE_AVX2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Tickets:
153783