Skip to content

Commit

Permalink
gtests: graph: unit: add gtests for verifying fill_local_in_map
Browse files Browse the repository at this point in the history
  • Loading branch information
ElaineBao authored and TaoLv committed Dec 25, 2024
1 parent 880bf7d commit 22fdb3a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/gtests/graph/unit/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,14 @@ static inline std::vector<int64_t> compute_dense_strides(
}

static inline std::vector<dnnl::impl::graph::logical_tensor_t>
create_logical_tensors(
size_t num_lt, impl::data_type_t dtype = impl::data_type::f32) {
create_logical_tensors(size_t num_lt,
impl::data_type_t dtype = impl::data_type::f32,
size_t id_start_from = 0) {
size_t count = 0;
std::vector<dnnl::impl::graph::logical_tensor_t> lt_vec;
lt_vec.reserve(num_lt);
while (count < num_lt) {
lt_vec.emplace_back(logical_tensor_init(count, dtype));
lt_vec.emplace_back(logical_tensor_init(id_start_from + count, dtype));
count++;
}
return lt_vec;
Expand Down
108 changes: 108 additions & 0 deletions tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,114 @@ TEST(test_utils_pattern_matcher, SharedInput) {
ASSERT_EQ(fusion_ops2.size(), 0U);
}

TEST(test_utils_pattern_matcher, SharedInputCase2) {
/* Pattern that captures shared input to two MatMuls
Dequant Dequant
IN0 | IN0 |
\ / \ /
MatMul MatMul
\ /
Multiply
|
*/
auto graphp = std::make_shared<pb_graph_t>();
auto pdequant1 = graphp->append_op(DynamicDequantize);
auto pmm1 = graphp->append_op(MatMul, {in_edge(1, pdequant1, 0)});
auto pdequant2 = graphp->append_op(DynamicDequantize);
auto pmm2 = graphp->append_op(MatMul, {in_edge(1, pdequant2, 0)});
graphp->create_input_port(0, pmm1, 0);
graphp->create_input_port(0, pmm2, 0);
auto pmul = graphp->append_op(
Multiply, {in_edge(0, pmm1, 0), in_edge(1, pmm2, 0)});
UNUSED(pmul);

// test with a graph that has the shared input
graph_t agraph;
op_t dequant1 {0, DynamicDequantize, "dequant1"};
op_t matmul1 {1, MatMul, "matmul1"};
op_t dequant2 {2, DynamicDequantize, "dequant2"};
op_t matmul2 {3, MatMul, "matmul2"};
op_t multiply {4, Multiply, "multiply"};

std::vector<logical_tensor_t> lt_vec = create_logical_tensors(8);
std::vector<logical_tensor_t> lt_vec_s8
= create_logical_tensors(4, data_type::s8, 8);
dequant1.add_input(lt_vec_s8[0]);
dequant1.add_input(lt_vec[0]);
dequant1.add_input(lt_vec_s8[1]);
dequant1.add_output(lt_vec[1]);

matmul1.add_input(lt_vec[2]);
matmul1.add_input(lt_vec[1]);
matmul1.add_output(lt_vec[3]);

dequant2.add_input(lt_vec_s8[2]);
dequant2.add_input(lt_vec[4]);
dequant2.add_input(lt_vec_s8[3]);
dequant2.add_output(lt_vec[5]);

matmul2.add_input(lt_vec[2]);
matmul2.add_input(lt_vec[5]);
matmul2.add_output(lt_vec[6]);

multiply.add_input(lt_vec[3]);
multiply.add_input(lt_vec[6]);
multiply.add_output(lt_vec[7]);

ASSERT_EQ(agraph.add_op(&dequant1), status::success);
ASSERT_EQ(agraph.add_op(&matmul1), status::success);
ASSERT_EQ(agraph.add_op(&dequant2), status::success);
ASSERT_EQ(agraph.add_op(&matmul2), status::success);
ASSERT_EQ(agraph.add_op(&multiply), status::success);
ASSERT_EQ(agraph.finalize(), status::success);

std::vector<op_t *> fusion_ops;
EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops));
ASSERT_EQ(fusion_ops.size(), 5U);

// test with a graph that does not have the shared input
graph_t agraph2;
op_t dequant3 {0, DynamicDequantize, "dequant1"};
op_t matmul3 {1, MatMul, "matmul1"};
op_t dequant4 {2, DynamicDequantize, "dequant2"};
op_t matmul4 {3, MatMul, "matmul2"};
op_t multiply2 {4, Multiply, "multiply"};

std::vector<logical_tensor_t> lt_vec2 = create_logical_tensors(9);
std::vector<logical_tensor_t> lt_vec2_s8
= create_logical_tensors(4, data_type::s8, 9);
dequant3.add_input(lt_vec2_s8[0]);
dequant3.add_input(lt_vec2[0]);
dequant3.add_input(lt_vec2_s8[1]);
dequant3.add_output(lt_vec2[1]);
matmul3.add_input(lt_vec2[2]);
matmul3.add_input(lt_vec2[1]);
matmul3.add_output(lt_vec2[3]);
dequant4.add_input(lt_vec2_s8[2]);
dequant4.add_input(lt_vec2[4]);
dequant4.add_input(lt_vec2_s8[3]);
dequant4.add_output(lt_vec2[5]);
matmul4.add_input(lt_vec2[8]);
matmul4.add_input(lt_vec2[5]);
matmul4.add_output(lt_vec2[6]);
multiply2.add_input(lt_vec2[3]);
multiply2.add_input(lt_vec2[6]);
multiply2.add_output(lt_vec2[7]);

ASSERT_EQ(agraph2.add_op(&dequant3), status::success);
ASSERT_EQ(agraph2.add_op(&matmul3), status::success);
ASSERT_EQ(agraph2.add_op(&dequant4), status::success);
ASSERT_EQ(agraph2.add_op(&matmul4), status::success);
ASSERT_EQ(agraph2.add_op(&multiply2), status::success);
agraph2.finalize();

std::vector<op_t *> fusion_ops2;
EXPECT_FALSE(
match_pattern(agraph2.get_ops()[0].get(), graphp, fusion_ops2));
ASSERT_EQ(fusion_ops2.size(), 0U);
}

TEST(test_utils_pattern_matcher, ParallelMatmul) {
auto graphp = std::make_shared<pb_graph_t>();
// Pattern that captures shared input to three MatMuls
Expand Down

0 comments on commit 22fdb3a

Please sign in to comment.