diff --git a/tests/gtests/graph/unit/utils.hpp b/tests/gtests/graph/unit/utils.hpp index 9feab434b38..bf25145a17f 100644 --- a/tests/gtests/graph/unit/utils.hpp +++ b/tests/gtests/graph/unit/utils.hpp @@ -144,13 +144,14 @@ static inline std::vector compute_dense_strides( } static inline std::vector -create_logical_tensors( - size_t num_lt, impl::data_type_t dtype = impl::data_type::f32) { +create_logical_tensors(size_t num_lt, + impl::data_type_t dtype = impl::data_type::f32, + size_t id_start_from = 0) { size_t count = 0; std::vector lt_vec; lt_vec.reserve(num_lt); while (count < num_lt) { - lt_vec.emplace_back(logical_tensor_init(count, dtype)); + lt_vec.emplace_back(logical_tensor_init(id_start_from + count, dtype)); count++; } return lt_vec; diff --git a/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp b/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp index ffafc22a390..575269e1cd2 100644 --- a/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp +++ b/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp @@ -1338,6 +1338,114 @@ TEST(test_utils_pattern_matcher, SharedInput) { ASSERT_EQ(fusion_ops2.size(), 0U); } +TEST(test_utils_pattern_matcher, SharedInputCase2) { + /* Pattern that captures shared input to two MatMuls + + Dequant Dequant + IN0 | IN0 | + \ / \ / + MatMul MatMul + \ / + Multiply + | + */ + auto graphp = std::make_shared(); + auto pdequant1 = graphp->append_op(DynamicDequantize); + auto pmm1 = graphp->append_op(MatMul, {in_edge(1, pdequant1, 0)}); + auto pdequant2 = graphp->append_op(DynamicDequantize); + auto pmm2 = graphp->append_op(MatMul, {in_edge(1, pdequant2, 0)}); + graphp->create_input_port(0, pmm1, 0); + graphp->create_input_port(0, pmm2, 0); + auto pmul = graphp->append_op( + Multiply, {in_edge(0, pmm1, 0), in_edge(1, pmm2, 0)}); + UNUSED(pmul); + + // test with a graph that has the shared input + graph_t agraph; + op_t dequant1 {0, DynamicDequantize, "dequant1"}; + op_t matmul1 {1, MatMul, "matmul1"}; + op_t dequant2 {2, DynamicDequantize, "dequant2"}; + op_t matmul2 {3, MatMul, "matmul2"}; + op_t multiply {4, Multiply, "multiply"}; + + std::vector lt_vec = create_logical_tensors(8); + std::vector lt_vec_s8 + = create_logical_tensors(4, data_type::s8, 8); + dequant1.add_input(lt_vec_s8[0]); + dequant1.add_input(lt_vec[0]); + dequant1.add_input(lt_vec_s8[1]); + dequant1.add_output(lt_vec[1]); + + matmul1.add_input(lt_vec[2]); + matmul1.add_input(lt_vec[1]); + matmul1.add_output(lt_vec[3]); + + dequant2.add_input(lt_vec_s8[2]); + dequant2.add_input(lt_vec[4]); + dequant2.add_input(lt_vec_s8[3]); + dequant2.add_output(lt_vec[5]); + + matmul2.add_input(lt_vec[2]); + matmul2.add_input(lt_vec[5]); + matmul2.add_output(lt_vec[6]); + + multiply.add_input(lt_vec[3]); + multiply.add_input(lt_vec[6]); + multiply.add_output(lt_vec[7]); + + ASSERT_EQ(agraph.add_op(&dequant1), status::success); + ASSERT_EQ(agraph.add_op(&matmul1), status::success); + ASSERT_EQ(agraph.add_op(&dequant2), status::success); + ASSERT_EQ(agraph.add_op(&matmul2), status::success); + ASSERT_EQ(agraph.add_op(&multiply), status::success); + ASSERT_EQ(agraph.finalize(), status::success); + + std::vector fusion_ops; + EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops)); + ASSERT_EQ(fusion_ops.size(), 5U); + + // test with a graph that does not have the shared input + graph_t agraph2; + op_t dequant3 {0, DynamicDequantize, "dequant1"}; + op_t matmul3 {1, MatMul, "matmul1"}; + op_t dequant4 {2, DynamicDequantize, "dequant2"}; + op_t matmul4 {3, MatMul, "matmul2"}; + op_t multiply2 {4, Multiply, "multiply"}; + + std::vector lt_vec2 = create_logical_tensors(9); + std::vector lt_vec2_s8 + = create_logical_tensors(4, data_type::s8, 9); + dequant3.add_input(lt_vec2_s8[0]); + dequant3.add_input(lt_vec2[0]); + dequant3.add_input(lt_vec2_s8[1]); + dequant3.add_output(lt_vec2[1]); + matmul3.add_input(lt_vec2[2]); + matmul3.add_input(lt_vec2[1]); + matmul3.add_output(lt_vec2[3]); + dequant4.add_input(lt_vec2_s8[2]); + dequant4.add_input(lt_vec2[4]); + dequant4.add_input(lt_vec2_s8[3]); + dequant4.add_output(lt_vec2[5]); + matmul4.add_input(lt_vec2[8]); + matmul4.add_input(lt_vec2[5]); + matmul4.add_output(lt_vec2[6]); + multiply2.add_input(lt_vec2[3]); + multiply2.add_input(lt_vec2[6]); + multiply2.add_output(lt_vec2[7]); + + ASSERT_EQ(agraph2.add_op(&dequant3), status::success); + ASSERT_EQ(agraph2.add_op(&matmul3), status::success); + ASSERT_EQ(agraph2.add_op(&dequant4), status::success); + ASSERT_EQ(agraph2.add_op(&matmul4), status::success); + ASSERT_EQ(agraph2.add_op(&multiply2), status::success); + agraph2.finalize(); + + std::vector fusion_ops2; + EXPECT_FALSE( + match_pattern(agraph2.get_ops()[0].get(), graphp, fusion_ops2)); + ASSERT_EQ(fusion_ops2.size(), 0U); +} + TEST(test_utils_pattern_matcher, ParallelMatmul) { auto graphp = std::make_shared(); // Pattern that captures shared input to three MatMuls