From fc50a1350d5d78671f7a3af20ef85393524ee8f1 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Sat, 21 Dec 2024 00:28:13 -0800 Subject: [PATCH] Add dim_order compat support (#7420) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/7420 Differential Revision: D67542995 --- backends/apple/mps/mps_preprocess.py | 6 ++++++ backends/apple/mps/operators/constant_ops.py | 16 ++++++++++++++++ backends/apple/mps/operators/op_clone.py | 16 ++++++++++++++++ backends/apple/mps/test/test_mps.py | 10 ++++++++++ backends/apple/mps/test/test_mps_utils.py | 2 +- 5 files changed, 49 insertions(+), 1 deletion(-) diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 8362774fa9..749f32a04e 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -32,6 +32,9 @@ CompileSpec, PreprocessResult, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass +from executorch.exir.program._program import _transform from torch.export.exported_program import ExportedProgram FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -83,6 +86,9 @@ def preprocess( # FlatBuffer graph, process the `output` nodes and add their id to # the `output_ids` array in the schema. + # TODO: Remove this once we have a better support for the dim-order ops. + edge_program = _transform(edge_program, DimOrderOpsRevertPass()) + mps_graph = MPSGraph( version="0", mps_nodes=[], diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index dacb09215c..f517c5ab46 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -78,6 +78,22 @@ def define_node( ) ) +@register_node_visitor +class ToDimOrderEmptyVisitor(NodeVisitor): + target = ["dim_order_ops._empty_dim_order.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError("dim_order_ops._empty_dim_order.default is not supported yet") + @register_node_visitor class FullLikeVisitor(NodeVisitor): diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index 2310ae02da..b2e6c131dc 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -33,3 +33,19 @@ def define_node( ) input_id = self.define_tensor(get_input_node(node, 0), mps_graph) self.tensor_to_id[node] = input_id + +@register_node_visitor +class ToDimOrderCopyVisitor(NodeVisitor): + target = ["dim_order_ops._to_dim_order_copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError("dim_order_ops._to_dim_order_copy.default is not supported yet") diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index fe64a30f3c..31f8d3be1b 100644 --- a/backends/apple/mps/test/test_mps.py +++ b/backends/apple/mps/test/test_mps.py @@ -1829,6 +1829,16 @@ def forward(self, x): Clone(), model_inputs, func_name=inspect.stack()[0].function[5:] ) + def test_mps_backend_to_copy(self): + class Copy(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._to_copy.default(x + 2, memory_format=torch.contiguous_format) + x + + model_inputs = (torch.randn(1, 3, 3),) + self.lower_and_test_with_partitioner( + Copy(), model_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_floor(self): class Floor(torch.nn.Module): def forward(self, x): diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 43ae9aa0f0..a052e0df68 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -219,7 +219,7 @@ def lower_module_and_test_output( dynamic_shapes=dynamic_shapes, edge_compile_config=EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. + _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. ), )