From 713b5e19c0ea1b59cae5d3089bf7bd8243aee700 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Thu, 19 Dec 2024 00:57:37 -0800 Subject: [PATCH] Support 5-input concat in XNNPACK delegate Summary: I noticed that support for 5-input concatenate ops was added to the XNNPACK library subgraph layer. We can support this in the delegate. Differential Revision: D67439458 --- backends/xnnpack/operators/op_cat.py | 15 +++++ .../partition/config/generic_node_configs.py | 6 +- backends/xnnpack/runtime/XNNCompiler.cpp | 40 ++++++++++++- .../xnnpack/serialization/runtime_schema.fbs | 2 + backends/xnnpack/serialization/schema.fbs | 2 + .../serialization/xnnpack_graph_schema.py | 6 ++ backends/xnnpack/test/ops/test_cat.py | 60 +++++++++---------- 7 files changed, 93 insertions(+), 38 deletions(-) diff --git a/backends/xnnpack/operators/op_cat.py b/backends/xnnpack/operators/op_cat.py index 706073ef9b8..d2181c210e5 100644 --- a/backends/xnnpack/operators/op_cat.py +++ b/backends/xnnpack/operators/op_cat.py @@ -17,6 +17,7 @@ XNNConcatenate2, XNNConcatenate3, XNNConcatenate4, + XNNConcatenate5, XNNGraph, XNode, ) @@ -71,6 +72,7 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=XNN_INVALID_VALUE_ID, input4_id=XNN_INVALID_VALUE_ID, + input5_id=XNN_INVALID_VALUE_ID, output_id=vals_to_ids[node], flags=0, ) @@ -81,6 +83,7 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=vals_to_ids[list_of_tensors[2]], input4_id=XNN_INVALID_VALUE_ID, + input5_id=XNN_INVALID_VALUE_ID, output_id=vals_to_ids[node], flags=0, ) @@ -91,6 +94,18 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=vals_to_ids[list_of_tensors[2]], input4_id=vals_to_ids[list_of_tensors[3]], + input5_id=XNN_INVALID_VALUE_ID, + output_id=vals_to_ids[node], + flags=0, + ) + elif num_tensors_to_cat == 5: + xnode = XNNConcatenate5( + axis=axis, + input1_id=vals_to_ids[list_of_tensors[0]], + input2_id=vals_to_ids[list_of_tensors[1]], + 
input3_id=vals_to_ids[list_of_tensors[2]], + input4_id=vals_to_ids[list_of_tensors[3]], + input5_id=vals_to_ids[list_of_tensors[4]], output_id=vals_to_ids[node], flags=0, ) diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index f08b8ccb3c0..83a87afd3d0 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -172,17 +172,17 @@ class CatConfig(GenericNodePartitionerConfig): def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: """ - Only support concatenation of 2 - 4 tensors + Only support concatenation of 2 - 5 tensors """ if not self.check_common_constraints(node, ep): return False num_tensors = len(node.all_input_nodes) - if not (num_tensors >= 2 and num_tensors <= 4): + if not (num_tensors >= 2 and num_tensors <= 5): why( node, - reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors", + reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors", ) return False diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index b948aa8623d..3d4d2e68219 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1600,7 +1600,7 @@ Error defineConcatenate2Node( } /* -Defines serialized concatenate2 node into the subgraph, +Defines serialized concatenate3 node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value */ @@ -1633,7 +1633,7 @@ Error defineConcatenate3Node( } /* -Defines serialized concatenate2 node into the subgraph, +Defines serialized concatenate4 node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value */ @@ -1666,6 +1666,41 @@ Error defineConcatenate4Node( return Error::Ok; } +/* +Defines serialized 
concatenate5 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value +*/ +Error defineConcatenate5Node( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map<uint32_t, uint32_t>& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNConcatenate5(); + + xnn_status status = xnn_define_concatenate5( + subgraph_ptr, + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), + remapped_ids.at(graph_node->input4_id()), + remapped_ids.at(graph_node->input5_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create cat5 node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Defines serialized static_slice node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value @@ -1832,6 +1867,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate2) _DEFINE(Concatenate3) _DEFINE(Concatenate4) + _DEFINE(Concatenate5) _DEFINE(StaticSlice) _DEFINE(ScaledDotProductAttention) _DEFINE(BatchMatrixMultiply) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index efe717e085e..08cb00911ab 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -133,6 +133,7 @@ union XNodeUnion { XNNConcatenate2: _XNNCat, XNNConcatenate3: _XNNCat, XNNConcatenate4: _XNNCat, + XNNConcatenate5: _XNNCat, XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, @@ -209,6 +210,7 @@ table _XNNCat { input4_id: uint; output_id: uint; flags: uint; + input5_id: uint; } table XNNELU { diff --git 
a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 33571195d63..6b0d6509fdb 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -129,6 +129,7 @@ union XNodeUnion { XNNConcatenate2: _XNNCat, XNNConcatenate3: _XNNCat, XNNConcatenate4: _XNNCat, + XNNConcatenate5: _XNNCat, XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, @@ -205,6 +206,7 @@ table _XNNCat { input4_id: uint; output_id: uint; flags: uint; + input5_id: uint; } table XNNELU { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index e3e699c58f8..7c3769d3ca7 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -42,6 +42,7 @@ class XNNCat: input4_id: int output_id: int flags: int + input5_id: int # Generic node data class for convolution type nodes @@ -177,6 +178,11 @@ class XNNConcatenate4(XNNCat): pass +@dataclass +class XNNConcatenate5(XNNCat): + pass + + @dataclass class XNNBatchMatrixMultiply(XNNNode2x1): pass diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 039da2c0755..fdfe119d8a9 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -11,27 +11,9 @@ class TestCat(unittest.TestCase): - class Cat2(torch.nn.Module): - def forward(self, arg1, arg2): - xs = [arg1, arg2] - x = torch.cat(xs) - return x + x # Quantize by propagation. - - class Cat3(torch.nn.Module): - def forward(self, arg1, arg2, arg3): - xs = [arg1, arg2, arg3] - x = torch.cat(xs) - return x + x # Quantize by propagation. - - class Cat4(torch.nn.Module): - def forward(self, arg1, arg2, arg3, arg4): - xs = [arg1, arg2, arg3, arg4] - x = torch.cat(xs) - return x + x # Quantize by propagation. 
- - class Cat5(torch.nn.Module): - def forward(self, arg1, arg2, arg3, arg4, arg5): - xs = [arg1, arg2, arg3, arg4, arg5] + class Cat(torch.nn.Module): + def forward(self, *args): + xs = [*args] x = torch.cat(xs) return x + x # Quantize by propagation. @@ -84,7 +66,7 @@ def test_fp16_cat2(self): torch.randn(1, 2, 3).to(torch.float16), torch.randn(3, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp16_cat3(self): """ @@ -95,7 +77,7 @@ def test_fp16_cat3(self): torch.randn(3, 2, 3).to(torch.float16), torch.randn(2, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat3(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp16_cat4(self): """ @@ -107,15 +89,15 @@ def test_fp16_cat4(self): torch.randn(2, 2, 3).to(torch.float16), torch.randn(5, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat4(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat2(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3)) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat3(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3)) - self._test_cat(self.Cat3(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat4(self): inputs = ( @@ -124,15 +106,25 @@ def test_fp32_cat4(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), ) - self._test_cat(self.Cat4(), inputs) + self._test_cat(self.Cat(), inputs) + + def test_fp32_cat5(self): + inputs = ( + torch.randn(1, 2, 3), + torch.randn(3, 2, 3), + torch.randn(2, 2, 3), + torch.randn(5, 2, 3), + torch.randn(1, 2, 3), + ) + self._test_cat(self.Cat(), inputs) def test_qs8_cat2(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3)) - self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=2, quant=True) def test_qs8_cat3(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3)) - self._test_cat(self.Cat3(), inputs, 
cat_num=3, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=3, quant=True) def test_qs8_cat4(self): inputs = ( @@ -141,7 +133,7 @@ def test_qs8_cat4(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), ) - self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) def test_fp32_cat_unsupported(self): """ @@ -153,9 +145,10 @@ def test_fp32_cat_unsupported(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), torch.randn(1, 2, 3), + torch.randn(2, 2, 3), ) ( - Tester(self.Cat5(), inputs) + Tester(self.Cat(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge_transform_and_lower() @@ -164,7 +157,7 @@ def test_fp32_cat_unsupported(self): def test_fp32_cat_unsupported_legacy_mode(self): """ - XNNPACK only supports concatenating up to 4 values, so it should not delegate here. + XNNPACK only supports concatenating up to 5 values, so it should not delegate here. """ inputs = ( torch.randn(1, 2, 3), @@ -172,9 +165,10 @@ def test_fp32_cat_unsupported_legacy_mode(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), torch.randn(1, 2, 3), + torch.randn(6, 2, 3), ) ( - Tester(self.Cat5(), inputs) + Tester(self.Cat(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge()