Skip to content

Commit

Permalink
Fix PositionIdsReplacer for Qwen (#28203)
Browse files Browse the repository at this point in the history
### Details:
The "current len" pattern might be created in 2 different ways in the
Qwen model.
After one of the comments in the review
(https://github.com/openvinotoolkit/openvino/pull/28067/files) was
resolved, we made the pattern stricter, it stopped covering one of the
cases.


### Tickets:
 - *CVS-157308*
  • Loading branch information
itikhono authored Dec 26, 2024
1 parent e8ae56e commit df75d0c
Showing 1 changed file with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,19 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p
auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

// current seg len
auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, _const()});
auto p_shape_of = wrap_type<v3::ShapeOf>({p_unsqueeze});
// current seq len:
// it might be present in 2 different ways:
// input_ids -> unsqueeze -> reshape -> convert -> shape_of -> gather
// QKV -> variadic_split(Q or K) -> rope Q/K -> shape_of -> gather
// Probably we can use the symbols to re-use one of these ways.
// Currently, "any_input" is used to detect the both places.
auto p_shape_of = wrap_type<v3::ShapeOf>({any_input()});
auto p_current_len = wrap_type<v8::Gather>({p_shape_of, _const(), _const()});

auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

Expand Down

0 comments on commit df75d0c

Please sign in to comment.