Skip to content

Commit

Permalink
Add stride constraint to XNN MaxPool
Browse files Browse the repository at this point in the history
Summary:
Add an XNNPACK partitioner constraint for MaxPool2d to enforce stride <= kernel_size.

See https://github.com/google/XNNPACK/blob/860f2b9ad9d3602599aff49a41d0131d2a350e00/src/subgraph/max-pooling-2d.c#L327.

Differential Revision: D67380978
  • Loading branch information
GregoryComer authored and facebook-github-bot committed Dec 18, 2024
1 parent f28e9a5 commit 5dc0b4f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
12 changes: 10 additions & 2 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,23 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """
    Check whether this max_pool2d node can be delegated to XNNPACK.

    XNNPACK's maxpool2d does not support ceil mode and requires
    stride <= kernel_size (see XNNPACK's max-pooling-2d subgraph
    validation).

    Args:
        node: The aten.max_pool2d FX node. Per the aten schema, args are
            (input, kernel_size, stride=[], padding=0, dilation=1,
            ceil_mode=False); kernel_size/stride may be length-1 lists
            (square pooling), and stride may be omitted or empty, in
            which case it defaults to kernel_size.
        ep: The exported program containing the node.

    Returns:
        True if the node satisfies XNNPACK's maxpool2d constraints.
    """
    if not self.check_common_constraints(node, ep):
        return False

    kernel_size = cast(List[int], node.args[1])
    # stride is optional; an absent or empty stride means
    # "stride == kernel_size" per the aten max_pool2d schema.
    stride = cast(List[int], node.args[2]) if len(node.args) >= 3 else []
    if not stride:
        stride = kernel_size
    is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])

    if is_ceil_mode:
        why(node, reason="ceil mode is not supported")
        return False

    # Normalize length-1 (square) specs: [k] means (k, k), so use the
    # last element for the width dimension in either case.
    if stride[0] > kernel_size[0] or stride[-1] > kernel_size[-1]:
        why(node, reason="stride must be less than or equal to kernel size")
        return False

    return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
Expand Down
22 changes: 22 additions & 0 deletions backends/xnnpack/test/ops/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,28 @@ def test_fp32_maxpool2d_unsupported_ceilmode(self):
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_maxpool2d_unsupported_stride(self):
    """
    XNNPACK MaxPool2d requires stride <= kernel_size.

    With kernel_size=2 and stride=3, the partitioner must reject the
    node, leaving the edge-dialect max_pool2d op un-delegated while the
    model still runs correctly via the portable fallback.
    """
    inputs = (torch.randn(1, 32, 23, 23),)
    (
        Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
        .export()
        .check_count({"torch.ops.aten.max_pool2d.default": 1})
        .to_edge_transform_and_lower()
        # We expect it not to be delegated: no call_delegate nodes, and
        # the original max_pool2d op remains in the edge graph.
        .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
        .check_count(
            {
                "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
            }
        )
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )

def test_qs8_maxpool2d(self):
class MaxPool(torch.nn.Module):
Expand Down

0 comments on commit 5dc0b4f

Please sign in to comment.