-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Summary: Changing arm.passes to arm._passes to indicate that these passes are not covered under the API stability guarantee. Pull Request resolved: #5918 Reviewed By: malfet, helunwencser Differential Revision: D63926055 fbshipit-source-id: 141a5be9f3a81e75784825357bacbab91904620c (cherry picked from commit 83c95df)
- Loading branch information
Showing
17 changed files
with
229 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
|
||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from torch._ops import OpOverload | ||
|
||
|
||
def create_node( | ||
graph: torch.fx.Graph, | ||
op_target: OpOverload, | ||
args: tuple = (), | ||
kwargs: Optional[dict] = None, | ||
quantize: bool = False, | ||
q_params: Optional[tuple] = None, | ||
): | ||
""" | ||
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node. | ||
If quantize is true and q_params is not None, a q dq pair is inserted after the newly created node. | ||
""" | ||
|
||
node = graph.create_node( | ||
"call_function", | ||
op_target, | ||
args=args, | ||
kwargs=kwargs or {}, | ||
) | ||
if quantize and q_params: | ||
return insert_q_dq_pair(graph, node, q_params) | ||
return node | ||
|
||
|
||
def insert_q_dq_pair( | ||
graph: torch.fx.Graph, | ||
anchor: torch.fx.Node, | ||
q_params: tuple, | ||
): | ||
""" | ||
Inserts a q dq node pair after the node 'anchor'. | ||
""" | ||
|
||
with graph.inserting_after(anchor): | ||
q = create_node( | ||
graph=graph, | ||
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, | ||
args=(), # We add the argument last | ||
) | ||
q.meta = anchor.meta | ||
with graph.inserting_after(q): | ||
dq = create_node( | ||
graph=graph, | ||
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, | ||
args=(q,) + q_params, | ||
) | ||
dq.meta = q.meta | ||
anchor.replace_all_uses_with(dq) | ||
# We add this last so the replace all uses above does not replace the quantized | ||
# node's first use | ||
q.args = (anchor,) + q_params | ||
return dq |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class CastInt64ToInt32Pass(ExportPass): | ||
def __init__(self, exported_program: torch.export.ExportedProgram): | ||
super(CastInt64ToInt32Pass, self).__init__() | ||
self.exported_program = exported_program | ||
|
||
def _to_int32(self, graph_module: torch.fx.GraphModule): | ||
for node in graph_module.graph.nodes: | ||
fake_tensor = node.meta["val"] | ||
if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): | ||
if node.meta["val"].dtype == torch.int64: | ||
node.meta["val"] = node.meta["val"].to(torch.int32) | ||
buffer_name = ( | ||
self.exported_program.graph_signature.inputs_to_buffers[ | ||
node.name | ||
] | ||
) | ||
new_tensor = self.exported_program.state_dict[buffer_name].to( | ||
torch.int32 | ||
) | ||
self.exported_program.state_dict[buffer_name] = new_tensor | ||
|
||
def call(self, graph_module: torch.fx.GraphModule): | ||
self._to_int32(graph_module) | ||
graph_module.recompile() | ||
graph_module = super().call(graph_module).graph_module | ||
return PassResult(graph_module, True) |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass | ||
|
||
|
||
def get_div_decomposition(op) -> tuple: | ||
""" | ||
Returns the the (reciprocal_op, mul_op), where the ops depends on if | ||
the div op is in exir_ops torch.ops.aten. | ||
""" | ||
if op == exir_ops.edge.aten.div.Tensor: | ||
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor) | ||
if op == torch.ops.aten.div.Tensor: | ||
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor) | ||
raise RuntimeError(f"Can't get div decomposition for op {op}") | ||
|
||
|
||
class DecomposeDivPass(ExportPass): | ||
""" | ||
This pass decomposes div into a mul and a reciprocal node. | ||
Example: | ||
y = div(a,b) | ||
Becomes: | ||
x = reciprocal(b) | ||
y = mul(a,x) | ||
""" | ||
|
||
def call_operator(self, op, args, kwargs, meta): | ||
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor): | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
reciprocal_op, mul_op = get_div_decomposition(op) | ||
|
||
numerator = args[0] | ||
denominator = args[1] | ||
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta) | ||
|
||
return super().call_operator(mul_op, (numerator, reciprocal), {}, meta) |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast, Union | ||
|
||
import torch | ||
from executorch.backends.arm.tosa_mapping import extract_tensor_meta | ||
|
||
from executorch.exir.pass_base import ExportPass, PassResult | ||
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix | ||
from torch.fx import GraphModule, Node | ||
|
||
|
||
class ScalarsToAttributePass(ExportPass): | ||
""" | ||
For ops in 'targeted_ops', convert inputs that are scalar values | ||
to attribute Nodes that output the same value. | ||
""" | ||
|
||
targeted_ops = [ | ||
torch.ops.aten.add.Tensor, | ||
torch.ops.aten.sub.Tensor, | ||
torch.ops.aten.sub_.Tensor, | ||
torch.ops.aten.mul.Tensor, | ||
torch.ops.aten.div.Tensor, | ||
] | ||
|
||
def call(self, graph_module: GraphModule) -> PassResult: | ||
for n in graph_module.graph.nodes: | ||
n = cast(Node, n) | ||
if n.op != "call_function" or n.target not in self.targeted_ops: | ||
continue | ||
|
||
biggest_rank = 1 | ||
for arg in n.args: | ||
if isinstance(arg, Node): | ||
_, shape, _ = extract_tensor_meta(arg.meta) | ||
biggest_rank = max(biggest_rank, len(shape)) | ||
|
||
new_args = [] | ||
for arg in n.args: | ||
if isinstance(arg, Node): | ||
new_args.append(arg) | ||
continue | ||
|
||
prefix = "_tensor_constant_" | ||
get_new_attr_name = get_new_attr_name_with_prefix(prefix) | ||
tensor_constant_name = get_new_attr_name(graph_module) | ||
float_tensor = torch.tensor( | ||
float(cast(Union[int, float], arg)) | ||
).reshape((1,) * biggest_rank) | ||
graph_module.register_buffer(tensor_constant_name, float_tensor) | ||
fake_mode = n.meta["val"].fake_mode | ||
|
||
with graph_module.graph.inserting_before(n): | ||
get_attr_node = graph_module.graph.create_node( | ||
"get_attr", tensor_constant_name, (), {} | ||
) | ||
get_attr_node.meta["val"] = fake_mode.from_tensor( | ||
float_tensor, static_shapes=True | ||
) | ||
new_args.append(get_attr_node) | ||
n.args = tuple(new_args) | ||
|
||
graph_module.recompile() | ||
return PassResult(graph_module, True) |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters