From 671b04c2c62efd0a7396481bb9522cd8a43034d0 Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Wed, 25 Dec 2024 23:19:06 -0800 Subject: [PATCH] graph: sdpa: Fix k/v only pattern to allow for optional mask/scale --- src/graph/backend/dnnl/patterns/sdp.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/graph/backend/dnnl/patterns/sdp.cpp b/src/graph/backend/dnnl/patterns/sdp.cpp index f49b830d173..781f270b59d 100644 --- a/src/graph/backend/dnnl/patterns/sdp.cpp +++ b/src/graph/backend/dnnl/patterns/sdp.cpp @@ -436,9 +436,13 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_v_fusion) [](const std::shared_ptr &pgraph) -> void { auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul); - auto fscore_scale = pgraph->append_alternation( - {graph::op_kind::Divide, graph::op_kind::Multiply}, - {in_edge(0, matmul_qk, 0)}); + auto scale_graph = std::make_shared(); + auto scale = scale_graph->append_alternation( + {graph::op_kind::Divide, graph::op_kind::Multiply}); + scale_graph->create_input_port(0, scale, 0); + scale_graph->create_output_port(0, scale, 0); + auto optional_scale = pgraph->append_optional( + scale_graph, {in_edge(0, matmul_qk, 0)}); auto optional_mask = std::make_shared(); auto fscore_add @@ -446,7 +450,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_v_fusion) optional_mask->create_input_port(0, fscore_add, 0); optional_mask->create_output_port(0, fscore_add, 0); auto mask = pgraph->append_optional( - optional_mask, {in_edge(0, fscore_scale, 0)}); + optional_mask, {in_edge(0, optional_scale, 0)}); // Optional select for distilbert auto p_select2 = optional_select(pgraph, mask, 2); @@ -474,9 +478,13 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_k_fusion) auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul, {in_edge(1, dequantize_key, 0)}); - auto fscore_scale = pgraph->append_alternation( - {graph::op_kind::Divide, graph::op_kind::Multiply}, - {in_edge(0, matmul_qk, 0)}); + auto scale_graph = std::make_shared(); + auto scale = scale_graph->append_alternation( + {graph::op_kind::Divide, graph::op_kind::Multiply}); + scale_graph->create_input_port(0, scale, 0); + scale_graph->create_output_port(0, scale, 0); + auto optional_scale = pgraph->append_optional( + scale_graph, {in_edge(0, matmul_qk, 0)}); auto optional_mask = std::make_shared(); auto fscore_add @@ -484,7 +492,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, sdp_with_compressed_k_fusion) optional_mask->create_input_port(0, fscore_add, 0); optional_mask->create_output_port(0, fscore_add, 0); auto mask = pgraph->append_optional( - optional_mask, {in_edge(0, fscore_scale, 0)}); + optional_mask, {in_edge(0, optional_scale, 0)}); // Optional select for distilbert auto p_select2 = optional_select(pgraph, mask, 2);