From 18557ac7e4b300ab18d7b617005344546e994afa Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Fri, 27 Dec 2024 19:09:27 -0800
Subject: [PATCH] Add stride constraint to XNNPACK MaxPool (#7354)

Summary:
Add an XNNPACK partitioner constraint for MaxPool2d to enforce
stride <= kernel_size. See
https://github.com/google/XNNPACK/blob/860f2b9ad9d3602599aff49a41d0131d2a350e00/src/subgraph/max-pooling-2d.c#L327.

Reviewed By: digantdesai

Differential Revision: D67380978

Pulled By: GregoryComer
---
 .../partition/config/generic_node_configs.py | 27 ++++++++++++++++-----------
 backends/xnnpack/test/ops/test_maxpool2d.py  | 22 ++++++++++++++++++++++
 exir/backend/utils.py                        | 13 +++++++++++++
 3 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index c97f27700e..bdb1f3802b 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -19,7 +19,7 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
-from executorch.exir.backend.utils import WhyNoPartition
+from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
 from torch.export import ExportedProgram
 
 logger = logging.getLogger(__name__)
 
@@ -284,19 +284,27 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         """
-        XNNPACK's maxpool2d does not support ceil mode
+        XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
         """
         if not self.check_common_constraints(node, ep):
             return False
 
-        # Ceil mode is supported via op padding, which must be statically known.
+        kernel_size = node.args[1]
+        stride = node.args[2]
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        is_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_ceil_mode and is_dynamic:
+
+        # Ceil mode is supported via op padding, which must be statically known.
+        if is_ceil_mode and is_shape_dynamic(node):
             why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
+
+        if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]:  # pyre-ignore[16]
+            why(
+                node,
+                reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
+            )
+            return False
+
         return True
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
@@ -316,10 +324,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
-        is_output_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_output_dynamic:
+        if is_shape_dynamic(node):
             why(node, reason="dynamic output sizes are not supported")
             return False
         return True
diff --git a/backends/xnnpack/test/ops/test_maxpool2d.py b/backends/xnnpack/test/ops/test_maxpool2d.py
index 4247fa1a46..521235232a 100644
--- a/backends/xnnpack/test/ops/test_maxpool2d.py
+++ b/backends/xnnpack/test/ops/test_maxpool2d.py
@@ -163,6 +163,28 @@ def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
             .run_method_and_compare_outputs()
         )
 
+    def test_fp32_maxpool2d_unsupported_stride(self):
+        """
+        XNNPACK MaxPool2d requires stride <= kernel_size.
+ """ + inputs = (torch.randn(1, 32, 23, 23),) + ( + Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .to_edge_transform_and_lower() + # We expect it not be be delegated. + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + .check_count( + { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_maxpool2d(self): class MaxPool(torch.nn.Module): def __init__(self, maxpool_params): diff --git a/exir/backend/utils.py b/exir/backend/utils.py index fb5e16c6bd..50d1e73fd7 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging import operator from collections import defaultdict @@ -417,6 +419,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None: node.meta["delegation_tag"] = user_tags.pop() +def is_shape_dynamic(node: torch.fx.Node) -> bool: + """ + Check if the node shape is dynamic. + """ + + # Shape is dynamic if any of the dimensions don't evaluate to a static value + return "val" in node.meta and any( + isinstance(d, torch.SymInt) for d in node.meta["val"].shape + ) + + # TODO - style: use templated types class DelegateMappingBuilder: """