diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp index 1cc9be37606950..397746c75bb84d 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -61,16 +61,19 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& p auto p_opt_convert = optional(p_max_context_len); auto p_opt_reshape = optional({p_opt_convert, any_input()}); - // current seg len - auto p_input_ids = wrap_type(); - auto p_unsqueeze = wrap_type({p_input_ids, _const()}); - auto p_shape_of = wrap_type({p_unsqueeze}); + // current seq len: + // it might be present in 2 different ways: + // input_ids -> unsqueeze -> reshape -> convert -> shape_of -> gather + // QKV -> variadic_split(Q or K) -> rope Q/K -> shape_of -> gather + // Probably we can use the symbols to re-use one of these ways. + // Currently, "any_input" is used to detect the both places. + auto p_shape_of = wrap_type({any_input()}); auto p_current_len = wrap_type({p_shape_of, _const(), _const()}); - auto p_rotary_emb_sincos = wrap_type(); auto p_neg_const = wrap_type(); auto p_neg_mul = wrap_type({p_current_len, p_neg_const}); // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128] + auto p_rotary_emb_sincos = wrap_type(); auto p_slice_1 = wrap_type({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()}); auto p_slice_2 = wrap_type({p_slice_1, p_neg_mul, _const(), _const(), _const()}); diff --git a/tests/constraints.txt b/tests/constraints.txt index 4f46cd0cc8b2e9..c339ac3c65d56f 100644 --- a/tests/constraints.txt +++ b/tests/constraints.txt @@ -21,11 +21,8 @@ pytest>=5.0,<8.4 pytest-dependency==0.5.1 pytest-html==4.1.1 pytest-timeout==2.3.1 -jax<=0.4.36 -jaxlib<=0.4.36 kornia==0.7.0 networkx<=3.3 -flax<=0.10.2 --extra-index-url https://download.pytorch.org/whl/cpu torch~=2.5.1; platform_system != "Darwin" or platform_machine != "x86_64" diff --git a/tests/layer_tests/requirements.txt b/tests/layer_tests/requirements.txt index 04889ebce10a39..2ba12cc5e2bece 100644 --- a/tests/layer_tests/requirements.txt +++ b/tests/layer_tests/requirements.txt @@ -16,5 +16,3 @@ pytest defusedxml tensorflow tensorflow-addons; python_version <= '3.10' -jax; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only -jaxlib; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only