Skip to content

Commit

Permalink
Add stride constraint to XNN MaxPool (#7354)
Browse files Browse the repository at this point in the history
Summary:
Add an XNNPACK partitioner constraint for MaxPool2d to enforce stride <= kernel_size.

See https://github.com/google/XNNPACK/blob/860f2b9ad9d3602599aff49a41d0131d2a350e00/src/subgraph/max-pooling-2d.c#L327.


Reviewed By: digantdesai

Differential Revision: D67380978

Pulled By: GregoryComer
  • Loading branch information
GregoryComer authored and facebook-github-bot committed Dec 28, 2024
1 parent fc04436 commit 18557ac
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 11 deletions.
27 changes: 16 additions & 11 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -284,19 +284,27 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's maxpool2d does not support ceil mode
XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
"""
if not self.check_common_constraints(node, ep):
return False

# Ceil mode is supported via op padding, which must be statically known.
kernel_size = node.args[1]
stride = node.args[2]
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
is_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_ceil_mode and is_dynamic:

# Ceil mode is supported via op padding, which must be statically known.
if is_ceil_mode and is_shape_dynamic(node):
why(node, reason="ceil mode is not supported for dynamic shapes")
return False

if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]: # pyre-ignore[16]
why(
node,
reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
)
return False

return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
Expand All @@ -316,10 +324,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

is_output_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_output_dynamic:
if is_shape_dynamic(node):
why(node, reason="dynamic output sizes are not supported")
return False
return True
Expand Down
22 changes: 22 additions & 0 deletions backends/xnnpack/test/ops/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
.run_method_and_compare_outputs()
)

def test_fp32_maxpool2d_unsupported_stride(self):
"""
XNNPACK MaxPool2d requires stride <= kernel_size.
"""
inputs = (torch.randn(1, 32, 23, 23),)
(
Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
.export()
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.to_edge_transform_and_lower()
# We expect it not be be delegated.
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
.check_count(
{
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
}
)
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_qs8_maxpool2d(self):
class MaxPool(torch.nn.Module):
def __init__(self, maxpool_params):
Expand Down
13 changes: 13 additions & 0 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import operator
from collections import defaultdict
Expand Down Expand Up @@ -417,6 +419,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
node.meta["delegation_tag"] = user_tags.pop()


def is_shape_dynamic(node: torch.fx.Node) -> bool:
"""
Check if the node shape is dynamic.
"""

# Shape is dynamic if any of the dimensions don't evaluate to a static value
return "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)


# TODO - style: use templated types
class DelegateMappingBuilder:
"""
Expand Down

0 comments on commit 18557ac

Please sign in to comment.