Skip to content

Commit

Permalink
Merge branch 'master' into hunyuan_model
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaohb authored Dec 26, 2024
2 parents b53a89a + df75d0c commit 9e3ba96
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,19 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p
auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

// current seg len
auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, _const()});
auto p_shape_of = wrap_type<v3::ShapeOf>({p_unsqueeze});
// current seq len:
// it might be present in 2 different ways:
// input_ids -> unsqueeze -> reshape -> convert -> shape_of -> gather
// QKV -> variadic_split(Q or K) -> rope Q/K -> shape_of -> gather
// Probably we can use the symbols to re-use one of these ways.
// Currently, "any_input" is used to detect the both places.
auto p_shape_of = wrap_type<v3::ShapeOf>({any_input()});
auto p_current_len = wrap_type<v8::Gather>({p_shape_of, _const(), _const()});

auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

Expand Down
3 changes: 0 additions & 3 deletions tests/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ pytest>=5.0,<8.4
pytest-dependency==0.5.1
pytest-html==4.1.1
pytest-timeout==2.3.1
jax<=0.4.36
jaxlib<=0.4.36
kornia==0.7.0
networkx<=3.3
flax<=0.10.2

--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
Expand Down
2 changes: 0 additions & 2 deletions tests/layer_tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ pytest
defusedxml
tensorflow
tensorflow-addons; python_version <= '3.10'
jax; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only
jaxlib; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only

0 comments on commit 9e3ba96

Please sign in to comment.