Skip to content

Commit

Permalink
graph: sdpa: Fix k/v only pattern to allow for optional mask/scale
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Dec 26, 2024
1 parent 22fdb3a commit 671b04c
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/graph/backend/dnnl/patterns/sdp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,21 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_v_fusion)
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul);

- auto fscore_scale = pgraph->append_alternation(
- {graph::op_kind::Divide, graph::op_kind::Multiply},
- {in_edge(0, matmul_qk, 0)});
+ auto scale_graph = std::make_shared<pb_graph_t>();
+ auto scale = scale_graph->append_alternation(
+ {graph::op_kind::Divide, graph::op_kind::Multiply});
+ scale_graph->create_input_port(0, scale, 0);
+ scale_graph->create_output_port(0, scale, 0);
+ auto optional_scale = pgraph->append_optional(
+ scale_graph, {in_edge(0, matmul_qk, 0)});

auto optional_mask = std::make_shared<pb_graph_t>();
auto fscore_add
= optional_mask->append_op(graph::op_kind::Add);
optional_mask->create_input_port(0, fscore_add, 0);
optional_mask->create_output_port(0, fscore_add, 0);
auto mask = pgraph->append_optional(
- optional_mask, {in_edge(0, fscore_scale, 0)});
+ optional_mask, {in_edge(0, optional_scale, 0)});

// Optional select for distilbert
auto p_select2 = optional_select(pgraph, mask, 2);
Expand Down Expand Up @@ -474,17 +478,21 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_k_fusion)
auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul,
{in_edge(1, dequantize_key, 0)});

- auto fscore_scale = pgraph->append_alternation(
- {graph::op_kind::Divide, graph::op_kind::Multiply},
- {in_edge(0, matmul_qk, 0)});
+ auto scale_graph = std::make_shared<pb_graph_t>();
+ auto scale = scale_graph->append_alternation(
+ {graph::op_kind::Divide, graph::op_kind::Multiply});
+ scale_graph->create_input_port(0, scale, 0);
+ scale_graph->create_output_port(0, scale, 0);
+ auto optional_scale = pgraph->append_optional(
+ scale_graph, {in_edge(0, matmul_qk, 0)});

auto optional_mask = std::make_shared<pb_graph_t>();
auto fscore_add
= optional_mask->append_op(graph::op_kind::Add);
optional_mask->create_input_port(0, fscore_add, 0);
optional_mask->create_output_port(0, fscore_add, 0);
auto mask = pgraph->append_optional(
- optional_mask, {in_edge(0, fscore_scale, 0)});
+ optional_mask, {in_edge(0, optional_scale, 0)});

// Optional select for distilbert
auto p_select2 = optional_select(pgraph, mask, 2);
Expand Down

0 comments on commit 671b04c

Please sign in to comment.