diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 220db37371..39910f0150 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -8,7 +8,7 @@ python_library( typing = True, deps = [ ":arm_backend", - "//executorch/backends/arm/passes:passes", + "//executorch/backends/arm/_passes:passes", "//executorch/exir:lib", ], ) @@ -27,7 +27,7 @@ python_library( ":arm_vela", "//executorch/backends/arm/operators:lib", "//executorch/backends/arm/operators:node_visitor", - "//executorch/backends/arm/passes:passes", + "//executorch/backends/arm/_passes:passes", ], ) diff --git a/backends/arm/passes/TARGETS b/backends/arm/_passes/TARGETS similarity index 100% rename from backends/arm/passes/TARGETS rename to backends/arm/_passes/TARGETS diff --git a/backends/arm/passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py similarity index 100% rename from backends/arm/passes/annotate_channels_last_dim_order_pass.py rename to backends/arm/_passes/annotate_channels_last_dim_order_pass.py diff --git a/backends/arm/passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py similarity index 75% rename from backends/arm/passes/arm_pass_manager.py rename to backends/arm/_passes/arm_pass_manager.py index 75ef551171..614caf1e92 100644 --- a/backends/arm/passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -8,20 +8,20 @@ # pyre-unsafe import torch -from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import ( +from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import ( AnnotateChannelsLastDimOrder, ) -from executorch.backends.arm.passes.convert_expand_copy_to_repeat import ( +from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, ) -from executorch.backends.arm.passes.convert_split_to_slice import ( +from executorch.backends.arm._passes.convert_split_to_slice import ( 
def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    quantize: bool = False,
    q_params: Optional[tuple] = None,
) -> torch.fx.Node:
    """
    Adds a 'call_function' node calling 'op_target' to 'graph'.

    The insertion point is not chosen here: graph.inserting_before/after()
    should be used before the call to decide where to insert the node.

    Args:
        graph: Graph the node is added to.
        op_target: Operator the new node calls.
        args: Positional arguments of the new node.
        kwargs: Keyword arguments of the new node; None means no kwargs.
        quantize: If true and 'q_params' is non-empty, a q dq pair is
            inserted after the newly created node.
        q_params: Extra arguments appended to the q and dq nodes after their
            input. None or an empty tuple disables the q dq insertion.

    Returns:
        The created node, or the dq node when a q dq pair was inserted.
    """
    node = graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs or {},
    )
    if quantize and q_params:
        return insert_q_dq_pair(graph, node, q_params)
    return node


def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
) -> torch.fx.Node:
    """
    Inserts a q dq node pair after the node 'anchor'.

    All existing users of 'anchor' are redirected to the dq node, so the
    resulting chain is anchor -> q -> dq -> (former users of anchor).

    Args:
        graph: Graph that contains 'anchor'.
        anchor: Node whose output is quantized and dequantized.
        q_params: Arguments passed to both the q and the dq node after their
            input node.

    Returns:
        The dq node.
    """
    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # We add the argument last
        )
        # Shallow copy so that later edits to one node's meta do not leak
        # into the other nodes through a shared dict.
        q.meta = anchor.meta.copy()
    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q, *q_params),
        )
        dq.meta = q.meta.copy()
    anchor.replace_all_uses_with(dq)
    # We add this last so the replace all uses above does not replace the quantized
    # node's first use
    q.args = (anchor, *q_params)
    return dq
class CastInt64ToInt32Pass(ExportPass):
    """
    Casts the int64 buffers of an exported program to int32.

    For every graph node whose meta 'val' is an int64 FakeTensor and whose
    name maps to a buffer in the program's graph signature, both the node's
    meta 'val' and the matching state_dict entry are converted to int32.
    All other nodes are left untouched.
    """

    def __init__(self, exported_program: torch.export.ExportedProgram):
        """
        Args:
            exported_program: Program whose graph signature and state_dict
                are used to find and replace the int64 buffers.
        """
        super().__init__()
        self.exported_program = exported_program

    def _to_int32(self, graph_module: torch.fx.GraphModule):
        """Casts int64 buffer nodes of 'graph_module' and their tensors to int32."""
        # Read the mapping once instead of once per node.
        inputs_to_buffers = self.exported_program.graph_signature.inputs_to_buffers
        for node in graph_module.graph.nodes:
            # Not every node is guaranteed to carry a 'val' entry.
            fake_tensor = node.meta.get("val")
            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                continue
            if fake_tensor.dtype != torch.int64:
                continue

            # Validate before mutating anything, so an int64 node that is not
            # a buffer held in the state_dict is skipped instead of raising
            # KeyError with its meta already rewritten.
            buffer_name = inputs_to_buffers.get(node.name)
            if (
                buffer_name is None
                or buffer_name not in self.exported_program.state_dict
            ):
                continue

            node.meta["val"] = fake_tensor.to(torch.int32)
            new_tensor = self.exported_program.state_dict[buffer_name].to(
                torch.int32
            )
            self.exported_program.state_dict[buffer_name] = new_tensor

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Casts the buffers, recompiles, and then runs the base class pass.

        Returns:
            A PassResult holding the graph module produced by the base class
            call; it is always reported as modified.
        """
        self._to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
def get_div_decomposition(op) -> tuple:
    """
    Returns the pair (reciprocal_op, mul_op) used to decompose 'op'.

    The returned ops come from the same namespace as the div op: exir_ops.edge
    ops for the edge div, torch.ops.aten ops for the aten div.

    Raises:
        RuntimeError: If 'op' is neither of the two supported div ops.
    """
    if op == exir_ops.edge.aten.div.Tensor:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.div.Tensor:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get div decomposition for op {op}")
    return (namespace.reciprocal.default, namespace.mul.Tensor)


class DecomposeDivPass(ExportPass):
    """
    Rewrites each div as a reciprocal node followed by a mul node.

    Example:
        y = div(a,b)
    Becomes:
        x = reciprocal(b)
        y = mul(a,x)
    """

    def call_operator(self, op, args, kwargs, meta):
        """Decomposes div ops; every other op is forwarded unchanged."""
        is_div = op in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor)
        if not is_div:
            return super().call_operator(op, args, kwargs, meta)

        reciprocal_op, mul_op = get_div_decomposition(op)
        numerator, denominator = args[0], args[1]

        # Both new nodes reuse the meta of the div they replace.
        inverse = super().call_operator(reciprocal_op, (denominator,), {}, meta)
        return super().call_operator(mul_op, (numerator, inverse), {}, meta)
class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert inputs that are scalar values
    to attribute Nodes that output the same value.

    Each scalar is stored as a float tensor buffer on the graph module,
    reshaped to (1,) * rank, where rank is the largest rank among the op's
    Node arguments (at least 1).
    """

    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.div.Tensor,
    ]

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Replaces scalar arguments of the targeted ops with get_attr nodes.

        Args:
            graph_module: Module that is rewritten in place.

        Returns:
            A PassResult with the same graph module, marked as modified only
            if at least one scalar argument was replaced.
        """
        # The name generator only depends on the prefix, so build it once.
        get_new_attr_name = get_new_attr_name_with_prefix("_tensor_constant_")
        modified = False

        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue
            # Nothing to do when every argument is already a Node.
            if all(isinstance(arg, Node) for arg in n.args):
                continue

            # extract_tensor_meta returns a 3-tuple whose second item is the shape.
            tensor_ranks = [
                len(extract_tensor_meta(arg.meta)[1])
                for arg in n.args
                if isinstance(arg, Node)
            ]
            biggest_rank = max([1, *tensor_ranks])
            fake_mode = n.meta["val"].fake_mode

            new_args = []
            for arg in n.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                    continue

                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg))
                ).reshape((1,) * biggest_rank)
                graph_module.register_buffer(tensor_constant_name, float_tensor)

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                get_attr_node.meta["val"] = fake_mode.from_tensor(
                    float_tensor, static_shapes=True
                )
                new_args.append(get_attr_node)

            n.args = tuple(new_args)
            modified = True

        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)
a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -7,7 +7,7 @@ import unittest import torch -from executorch.backends.arm.passes.meandim_to_averagepool_pass import ( +from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( ConvertMeanDimToAveragePool, )