From 32ebc557d692e6e7d28d5dc799f21c2afdcd654e Mon Sep 17 00:00:00 2001 From: Yifan Shen Date: Wed, 20 Nov 2024 23:11:30 -0800 Subject: [PATCH] [Torch.Export] Support Some New Ops Translation (#2403) * minor polish: move nan_to_num helper inside it * support new ops: nan_to_num, cumprod, searchsorted, one_hot --------- Co-authored-by: yifan_shen3 --- .../converters/mil/frontend/torch/ops.py | 181 ++++++++++++++---- .../mil/frontend/torch/test/test_torch_ops.py | 55 ++++-- 2 files changed, 181 insertions(+), 55 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 8cba474ef..8610b435b 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -78,6 +78,10 @@ # norm "fro", + + # searchsorted side + "left", + "right", } @@ -239,6 +243,7 @@ def _get_bindings(context, alist) -> List[Var]: elif i in TORCH_STRING_ARGS: results.append(i) else: + results.append(i) logger.warning( f"Binding {i} is neither a name of exisitng var in context, " "nor a torch string argument." @@ -8293,33 +8298,79 @@ def linalg_inv(context, node): context.add(mb.const(val=np.linalg.inv(x.val), name=node.name)) -def _replace_values_by_bool_mask(data: Var, mask: Var, new_value: Union[int, float]): - """Replace the position in data where mask has True element to new_value.""" - indices = mb.non_zero(x=mb.cast(x=mask, dtype="int32")) - # If there is no replacement needed, just use identity op. - if 0 in indices.shape: - return mb.identity(x=data) - - # Expand the replacement value to the compatible shape for scatter_nd. - replacement_values = mb.expand_dims(x=new_value, axes=[0]) - reps = mb.expand_dims(x=value_at(mb.shape(x=indices), 0), axes=[0]) - replacement_values = mb.tile(x=replacement_values, reps=reps) - - # Replace all nan to the corresponding values. - return mb.scatter_nd(data=data, indices=indices, updates=replacement_values, mode="update") +@register_torch_op +def isnan(context, node): + x = _get_inputs(context, node, expected=1)[0] + # Find indices of NaN based on "NaN is never equal to itself". + nan_indices = mb.not_equal(x=x, y=x, name=node.name) + context.add(nan_indices) @register_torch_op def nan_to_num(context, node): - inputs = _get_inputs(context, node, expected=4) - x = inputs[0] - nan = inputs[1].val if inputs[1] is not None else 0.0 - posinf = inputs[2].val if inputs[2] is not None else None - neginf = inputs[3].val if inputs[3] is not None else None - if posinf is None: - posinf = types.type_mapping.builtin_to_range(x.dtype).high - if neginf is None: - neginf = types.type_mapping.builtin_to_range(x.dtype).low + def _parse_positional_args(context, node) -> Tuple[Var]: + inputs = _get_inputs( + context, + node, + expected={TorchFrontend.TORCHSCRIPT: 4}, + min_expected={TorchFrontend.TORCHEXPORT: 1, TorchFrontend.EXECUTORCH: 1}, + ) + nargs = len(inputs) + + x = inputs[0] + nan = inputs[1] if nargs > 1 else None + posinf = inputs[2] if nargs > 2 else None + neginf = inputs[3] if nargs > 3 else None + + return x, nan, posinf, neginf + + def _parse_keyword_args(context, node, nan, posinf, neginf) -> Tuple[Var]: + # Only torch.export may have kwargs + if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS: + return nan, posinf, neginf + + nan = _get_kwinputs(context, node, "nan", default=[nan])[0] + posinf = _get_kwinputs(context, node, "posinf", default=[posinf])[0] + neginf = _get_kwinputs(context, node, "neginf", default=[neginf])[0] + + return nan, posinf, neginf + + def _translate_torch_args(x, nan, posinf, neginf) -> Tuple[Var]: + if nan is None: + nan = 0.0 + else: + if isinstance(nan, Var): + nan = nan.val + + if posinf is None: + posinf = types.type_mapping.builtin_to_range(x.dtype).high + else: + if isinstance(posinf, Var): + posinf = posinf.val + + if neginf is None: + neginf = types.type_mapping.builtin_to_range(x.dtype).low + else: + if isinstance(neginf, Var): + neginf = neginf.val + + return nan, posinf, neginf + + def _replace_values_by_bool_mask(data: Var, mask: Var, new_value: Union[int, float]): + """Replace the position in data where mask has True element to new_value.""" + indices = mb.non_zero(x=mb.cast(x=mask, dtype="int32")) + + # Expand the replacement value to the compatible shape for scatter_nd. + replacement_values = mb.expand_dims(x=new_value, axes=[0]) + reps = mb.expand_dims(x=value_at(mb.shape(x=indices), 0), axes=[0]) + replacement_values = mb.tile(x=replacement_values, reps=reps) + + # Replace all nan to the corresponding values. + return mb.scatter_nd(data=data, indices=indices, updates=replacement_values, mode="update") + + x, nan, posinf, neginf = _parse_positional_args(context, node) + nan, posinf, neginf = _parse_keyword_args(context, node, nan, posinf, neginf) + nan, posinf, neginf = _translate_torch_args(x, nan, posinf, neginf) if x.val is not None: res = mb.const(val=np.nan_to_num(x.val, nan=nan, posinf=posinf, neginf=neginf)) @@ -8341,9 +8392,10 @@ def nan_to_num(context, node): @register_torch_op def cumprod(context, node): - inputs = _get_inputs(context, node, expected=3) + inputs = _get_inputs(context, node, min_expected=2) x = inputs[0] dim = inputs[1].val + # dtype may be the 3rd input, but we will not use it size = x.shape[dim] if is_symbolic(size): @@ -8360,16 +8412,52 @@ def cumprod(context, node): @register_torch_op def searchsorted(context, node): - inputs = _get_inputs(context, node, expected=6) - sorted_sequence = inputs[0] - values = inputs[1] - side = inputs[4].val if inputs[4] is not None else False - if side is not None: - # The `side` parameter is preferred than `right` in torch. - right = side == "right" - else: - # If side is not specified, use the `right` parameter to determine. - right = inputs[3].val if inputs[3] is not None else False + def _parse_positional_args(context, node) -> Tuple[Var]: + inputs = _get_inputs( + context, + node, + expected={TorchFrontend.TORCHSCRIPT: 6}, + min_expected={TorchFrontend.TORCHEXPORT: 2, TorchFrontend.EXECUTORCH: 2}, + ) + nargs = len(inputs) + + sorted_sequence = inputs[0] + values = inputs[1] + # we will not use `out_int32` + right = inputs[3] if nargs > 3 else False + side = inputs[4] if nargs > 4 else None + # we will not use `sorter` + + return sorted_sequence, values, right, side + + def _parse_keyword_args(context, node, right, side) -> Tuple[Var]: + # Only torch.export may have kwargs + if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS: + return right, side + + right = _get_kwinputs(context, node, "right", default=[right])[0] + side = _get_kwinputs(context, node, "side", default=[side])[0] + + return right, side + + def _translate_torch_args(right, side) -> Tuple[Var]: + if side is not None: + if isinstance(side, Var): + side = side.val + # The `side` parameter is preferred than `right` in torch. + right = side == "right" + else: + # If side is not specified, use the `right` parameter to determine. + if right is None: + right = False + else: + if isinstance(right, Var): + right = right.val + return right + + sorted_sequence, values, right, side = _parse_positional_args(context, node) + right, side = _parse_keyword_args(context, node, right, side) + right = _translate_torch_args(right, side) if sorted_sequence.rank != values.rank: raise NotImplementedError( @@ -8399,9 +8487,28 @@ def searchsorted(context, node): @register_torch_op def one_hot(context, node): - inputs = _get_inputs(context, node, expected=2) - labels = inputs[0] - num_classes = inputs[1].val + def _parse_positional_args(context, node) -> Tuple[Var]: + inputs = _get_inputs(context, node, expected=(1, 2)) + nargs = len(inputs) + + labels = inputs[0] + num_classes = inputs[1] if nargs > 1 else -1 + + return labels, num_classes + + def _parse_keyword_args(context, node, num_classes) -> Var: + # Only torch.export may have kwargs + if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS: + return num_classes + + num_classes = _get_kwinputs(context, node, "num_classes", default=[num_classes])[0] + + return num_classes + + labels, num_classes = _parse_positional_args(context, node) + num_classes = _parse_keyword_args(context, node, num_classes) + if isinstance(num_classes, Var): + num_classes = num_classes.val res = mb.one_hot(indices=labels, one_hot_vector_size=num_classes, name=node.name) context.add(res) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 1db603a3e..18c42b61f 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -13606,10 +13606,12 @@ def forward(self, x): class TestNanToNum(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, nan, posinf, neginf", - itertools.product(compute_units, backends, [None, 1.0], [None, 1000.0], [None, -1000.0]), + "compute_unit, backend, frontend, nan, posinf, neginf", + itertools.product( + compute_units, backends, frontends, [None, 1.0], [None, 1000.0], [None, -1000.0] + ), ) - def test_nan_to_num_const(self, compute_unit, backend, nan, posinf, neginf): + def test_nan_to_num_const(self, compute_unit, backend, frontend, nan, posinf, neginf): class TestModel(nn.Module): def forward(self, x): input_data = torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]) @@ -13618,15 +13620,16 @@ def forward(self, x): self.run_compare_torch( (2, 3), TestModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, ) @pytest.mark.parametrize( - "compute_unit, backend, nan, posinf, neginf", - itertools.product(compute_units, backends, [None, 1.0], [1000.0], [-1000.0]), + "compute_unit, backend, frontend, nan, posinf, neginf", + itertools.product(compute_units, backends, frontends, [None, 1.0], [1000.0], [-1000.0]), ) - def test_nan_to_num_non_const(self, compute_unit, backend, nan, posinf, neginf): + def test_nan_to_num_non_const(self, compute_unit, backend, frontend, nan, posinf, neginf): class TestModel(nn.Module): def forward(self, x): return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) @@ -13635,6 +13638,7 @@ def forward(self, x): self.run_compare_torch( input_data, TestModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, @@ -13643,10 +13647,13 @@ def forward(self, x): class TestCumprod(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, axis", - itertools.product(compute_units, backends, [0, 1, 2, -1]), + "compute_unit, backend, frontend, axis", + itertools.product(compute_units, backends, frontends, [0, 1, 2, -1]), ) - def test_cumprod(self, compute_unit, backend, axis): + def test_cumprod(self, compute_unit, backend, frontend, axis): + if frontend == TorchFrontend.EXECUTORCH: + pytest.skip("torch._ops.aten.cumprod.default is not Aten Canonical") + class TestModel(nn.Module): def forward(self, x): return torch.cumprod(x, axis) @@ -13654,6 +13661,7 @@ def forward(self, x): self.run_compare_torch( (2, 3, 4), TestModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, ) @@ -13661,12 +13669,15 @@ def forward(self, x): class TestSearchsorted(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, side", - itertools.product(compute_units, backends, [None, "left", "right"]), + "compute_unit, backend, frontend, side", + itertools.product(compute_units, backends, frontends, [None, "left", "right"]), ) - def test_searchsorted_basic(self, compute_unit, backend, side): + def test_searchsorted_basic(self, compute_unit, backend, frontend, side): """This is the test case same as PyTorch doc for `torch.searchsorted`.""" + if frontend == TorchFrontend.EXECUTORCH: + pytest.skip("torch._ops.aten.searchsorted.Tensor is not Aten Canonical") + class TestModel(nn.Module): def forward(self, input_data, values): return torch.searchsorted(input_data, values, side=side) @@ -13676,16 +13687,22 @@ def forward(self, input_data, values): self.run_compare_torch( (input_data, values), TestModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, ) @pytest.mark.parametrize( - "compute_unit, backend, values_shape, side", - itertools.product(compute_units, backends, [(2, 1), (2, 10)], [None, "left", "right"]), + "compute_unit, backend, frontend, values_shape, side", + itertools.product( + compute_units, backends, frontends, [(2, 1), (2, 10)], [None, "left", "right"] + ), ) - def test_searchsorted_stress(self, compute_unit, backend, values_shape, side): + def test_searchsorted_stress(self, compute_unit, backend, frontend, values_shape, side): + if frontend == TorchFrontend.EXECUTORCH: + pytest.skip("torch._ops.aten.searchsorted.Tensor is not Aten Canonical") + class TestModel(nn.Module): def forward(self, input_data, values): return torch.searchsorted(input_data, values, side=side) @@ -13695,6 +13712,7 @@ def forward(self, input_data, values): self.run_compare_torch( (input_data, values), TestModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, @@ -13703,16 +13721,17 @@ def forward(self, input_data, values): class TestOneHot(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, num_classes, rank", - itertools.product(compute_units, backends, range(1, 5), range(1, 5)), + "compute_unit, backend, frontend, num_classes, rank", + itertools.product(compute_units, backends, frontends, range(1, 5), range(1, 5)), ) - def test_one_hot(self, compute_unit, backend, num_classes, rank): + def test_one_hot(self, compute_unit, backend, frontend, num_classes, rank): model = ModuleWrapper(function=torch.nn.functional.one_hot, kwargs={"num_classes": num_classes}).eval() shape = torch.randint(1, 10, (rank,)).tolist() labels = torch.randint(0, num_classes, shape) self.run_compare_torch( torch.LongTensor(labels), model, + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False,