From 93f2f269bff6d54e80f06c5943a6a0985349a3d0 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Wed, 25 Dec 2024 18:53:42 +0400 Subject: [PATCH] Fix PositionIdsReplacer for Qwen --- .../position_ids_replacer.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp index 1cc9be37606950..397746c75bb84d 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -61,16 +61,19 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& p auto p_opt_convert = optional(p_max_context_len); auto p_opt_reshape = optional({p_opt_convert, any_input()}); - // current seg len - auto p_input_ids = wrap_type(); - auto p_unsqueeze = wrap_type({p_input_ids, _const()}); - auto p_shape_of = wrap_type({p_unsqueeze}); + // current seq len: + // it might be present in 2 different ways: + // input_ids -> unsqueeze -> reshape -> convert -> shape_of -> gather + // QKV -> variadic_split(Q or K) -> rope Q/K -> shape_of -> gather + // Probably we can use the symbols to re-use one of these ways. + // Currently, "any_input" is used to detect the both places. + auto p_shape_of = wrap_type({any_input()}); auto p_current_len = wrap_type({p_shape_of, _const(), _const()}); - auto p_rotary_emb_sincos = wrap_type(); auto p_neg_const = wrap_type(); auto p_neg_mul = wrap_type({p_current_len, p_neg_const}); // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128] + auto p_rotary_emb_sincos = wrap_type(); auto p_slice_1 = wrap_type({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()}); auto p_slice_2 = wrap_type({p_slice_1, p_neg_mul, _const(), _const(), _const()});