Fix linter errors in Core ML partitioner. (pytorch#1630)
Summary:
Pull Request resolved: pytorch#1630

.

Reviewed By: GregoryComer

Differential Revision: D52860966

fbshipit-source-id: d8d9eb62e5fddf155c7c766b3259e7a9aee8640f
shoumikhin authored and facebook-github-bot committed Jan 18, 2024
1 parent 33f8887 commit cfa27d6
Showing 2 changed files with 41 additions and 17 deletions.
backends/apple/coreml/partition/coreml_partitioner.py (25 changes: 16 additions & 9 deletions)
Expand Up @@ -3,28 +3,31 @@
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import logging
from typing import List
from typing import List, Optional

import coremltools as ct

import torch
from torch._export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.backends.apple.coreml.compiler.coreml_preprocess import CoreMLBackend

import coremltools as ct
from torch._export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
def __init__(self, skip_ops_for_coreml_delegation: List[str] = []) -> None:
def __init__(
self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
) -> None:
if skip_ops_for_coreml_delegation is None:
skip_ops_for_coreml_delegation = []
super().__init__()
self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation

Expand All @@ -51,7 +54,11 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
class CoreMLPartitioner(Partitioner):
compile_spec = []

def __init__(self, skip_ops_for_coreml_delegation: List[str] = []) -> None:
def __init__(
self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
) -> None:
if skip_ops_for_coreml_delegation is None:
skip_ops_for_coreml_delegation = []
self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
self.delegation_spec = DelegationSpec("CoreMLBackend", self.compile_spec)

Expand Down
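The switch from `List[str] = []` to `Optional[List[str]] = None` fixes the mutable-default-argument pitfall that linters flag: Python evaluates a default value once, at function definition time, so every call (and every instance) shares the same list object. A minimal sketch of the pitfall and the fix, using a hypothetical Registry class rather than the partitioner itself:

from typing import List, Optional

class Registry:
    # Bug: the default list is created once at definition time and shared.
    def __init__(self, items: List[str] = []) -> None:
        self.items = items

class SafeRegistry:
    # Fix (same pattern as this commit): default to None, allocate per call.
    def __init__(self, items: Optional[List[str]] = None) -> None:
        if items is None:
            items = []
        self.items = items

a = Registry()
a.items.append("x")
print(Registry().items)      # ['x'] -- the shared default list was mutated
print(SafeRegistry().items)  # []    -- each instance gets a fresh list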
backends/apple/coreml/test/test_coreml_partitioner.py (33 changes: 25 additions & 8 deletions)
Expand Up @@ -4,13 +4,15 @@

import unittest

import executorch.exir as exir

import torch

import executorch.exir as exir
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
CoreMLPartitioner,
)
from executorch.exir.backend.backend_api import to_backend

from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner


class TestCoreMLPartitioner(unittest.TestCase):
def test_partition_add_mul(self):
Expand All @@ -28,12 +30,20 @@ def forward(self, a, x, b):

model = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
exported_program = exir.capture(model, inputs, exir.CaptureConfig()).to_edge().exported_program
exported_program = (
exir.capture(model, inputs, exir.CaptureConfig()).to_edge().exported_program
)

assert [
node.target.__name__ for node in exported_program.graph.nodes if node.op == "call_function"
node.target.__name__
for node in exported_program.graph.nodes
if node.op == "call_function"
] == [
"aten.mm.default", "aten.add.Tensor", "aten.sub.Tensor", "aten.mm.default", "aten.add.Tensor"
"aten.mm.default",
"aten.add.Tensor",
"aten.sub.Tensor",
"aten.mm.default",
"aten.add.Tensor",
]

exported_to_coreml = to_backend(
Expand All @@ -42,9 +52,16 @@ def forward(self, a, x, b):
)

assert [
node.target.__name__ for node in exported_to_coreml.graph.nodes if node.op == "call_function"
node.target.__name__
for node in exported_to_coreml.graph.nodes
if node.op == "call_function"
] == [
"aten.mm.default", "executorch_call_delegate", "getitem", "aten.mm.default", "executorch_call_delegate", "getitem"
"aten.mm.default",
"executorch_call_delegate",
"getitem",
"aten.mm.default",
"executorch_call_delegate",
"getitem",
]


Expand Down
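The constructor's behavior is unchanged by the reformatting; a brief usage sketch (the skipped op name below is illustrative, chosen for the example rather than taken from this commit):

from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)

# Default: no ops are skipped; the argument is optional after this change.
partitioner = CoreMLPartitioner()

# Hypothetical skip list keeping aten.mm.default out of the delegate.
custom = CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.mm.default"])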
