Skip to content

Commit

Permalink
[Torch.Export] Support Some New Ops Translation (#2403)
Browse files Browse the repository at this point in the history
* minor polish: move nan_to_num helper inside it

* support new ops: nan_to_num, cumprod, searchsorted, one_hot

---------

Co-authored-by: yifan_shen3 <[email protected]>
  • Loading branch information
YifanShenSZ and yifan_shen3 authored Nov 21, 2024
1 parent be29fb9 commit 32ebc55
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 55 deletions.
181 changes: 144 additions & 37 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@

# norm
"fro",

# searchsorted side
"left",
"right",
}


Expand Down Expand Up @@ -239,6 +243,7 @@ def _get_bindings(context, alist) -> List[Var]:
elif i in TORCH_STRING_ARGS:
results.append(i)
else:
results.append(i)
logger.warning(
f"Binding {i} is neither a name of exisitng var in context, "
"nor a torch string argument."
Expand Down Expand Up @@ -8293,33 +8298,79 @@ def linalg_inv(context, node):
context.add(mb.const(val=np.linalg.inv(x.val), name=node.name))


def _replace_values_by_bool_mask(data: Var, mask: Var, new_value: Union[int, float]):
"""Replace the position in data where mask has True element to new_value."""
indices = mb.non_zero(x=mb.cast(x=mask, dtype="int32"))
# If there is no replacement needed, just use identity op.
if 0 in indices.shape:
return mb.identity(x=data)

# Expand the replacement value to the compatible shape for scatter_nd.
replacement_values = mb.expand_dims(x=new_value, axes=[0])
reps = mb.expand_dims(x=value_at(mb.shape(x=indices), 0), axes=[0])
replacement_values = mb.tile(x=replacement_values, reps=reps)

# Replace all nan to the corresponding values.
return mb.scatter_nd(data=data, indices=indices, updates=replacement_values, mode="update")
@register_torch_op
def isnan(context, node):
x = _get_inputs(context, node, expected=1)[0]
# Find indices of NaN based on "NaN is never equal to itself".
nan_indices = mb.not_equal(x=x, y=x, name=node.name)
context.add(nan_indices)


@register_torch_op
def nan_to_num(context, node):
inputs = _get_inputs(context, node, expected=4)
x = inputs[0]
nan = inputs[1].val if inputs[1] is not None else 0.0
posinf = inputs[2].val if inputs[2] is not None else None
neginf = inputs[3].val if inputs[3] is not None else None
if posinf is None:
posinf = types.type_mapping.builtin_to_range(x.dtype).high
if neginf is None:
neginf = types.type_mapping.builtin_to_range(x.dtype).low
def _parse_positional_args(context, node) -> Tuple[Var]:
inputs = _get_inputs(
context,
node,
expected={TorchFrontend.TORCHSCRIPT: 4},
min_expected={TorchFrontend.TORCHEXPORT: 1, TorchFrontend.EXECUTORCH: 1},
)
nargs = len(inputs)

x = inputs[0]
nan = inputs[1] if nargs > 1 else None
posinf = inputs[2] if nargs > 2 else None
neginf = inputs[3] if nargs > 3 else None

return x, nan, posinf, neginf

def _parse_keyword_args(context, node, nan, posinf, neginf) -> Tuple[Var]:
# Only torch.export may have kwargs
if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS:
return nan, posinf, neginf

nan = _get_kwinputs(context, node, "nan", default=[nan])[0]
posinf = _get_kwinputs(context, node, "posinf", default=[posinf])[0]
neginf = _get_kwinputs(context, node, "neginf", default=[neginf])[0]

return nan, posinf, neginf

def _translate_torch_args(x, nan, posinf, neginf) -> Tuple[Var]:
if nan is None:
nan = 0.0
else:
if isinstance(nan, Var):
nan = nan.val

if posinf is None:
posinf = types.type_mapping.builtin_to_range(x.dtype).high
else:
if isinstance(posinf, Var):
posinf = posinf.val

if neginf is None:
neginf = types.type_mapping.builtin_to_range(x.dtype).low
else:
if isinstance(neginf, Var):
neginf = neginf.val

return nan, posinf, neginf

def _replace_values_by_bool_mask(data: Var, mask: Var, new_value: Union[int, float]):
"""Replace the position in data where mask has True element to new_value."""
indices = mb.non_zero(x=mb.cast(x=mask, dtype="int32"))

# Expand the replacement value to the compatible shape for scatter_nd.
replacement_values = mb.expand_dims(x=new_value, axes=[0])
reps = mb.expand_dims(x=value_at(mb.shape(x=indices), 0), axes=[0])
replacement_values = mb.tile(x=replacement_values, reps=reps)

# Replace all nan to the corresponding values.
return mb.scatter_nd(data=data, indices=indices, updates=replacement_values, mode="update")

x, nan, posinf, neginf = _parse_positional_args(context, node)
nan, posinf, neginf = _parse_keyword_args(context, node, nan, posinf, neginf)
nan, posinf, neginf = _translate_torch_args(x, nan, posinf, neginf)

if x.val is not None:
res = mb.const(val=np.nan_to_num(x.val, nan=nan, posinf=posinf, neginf=neginf))
Expand All @@ -8341,9 +8392,10 @@ def nan_to_num(context, node):

@register_torch_op
def cumprod(context, node):
inputs = _get_inputs(context, node, expected=3)
inputs = _get_inputs(context, node, min_expected=2)
x = inputs[0]
dim = inputs[1].val
# dtype may be the 3rd input, but we will not use it

size = x.shape[dim]
if is_symbolic(size):
Expand All @@ -8360,16 +8412,52 @@ def cumprod(context, node):

@register_torch_op
def searchsorted(context, node):
inputs = _get_inputs(context, node, expected=6)
sorted_sequence = inputs[0]
values = inputs[1]
side = inputs[4].val if inputs[4] is not None else False
if side is not None:
# The `side` parameter is preferred than `right` in torch.
right = side == "right"
else:
# If side is not specified, use the `right` parameter to determine.
right = inputs[3].val if inputs[3] is not None else False
def _parse_positional_args(context, node) -> Tuple[Var]:
inputs = _get_inputs(
context,
node,
expected={TorchFrontend.TORCHSCRIPT: 6},
min_expected={TorchFrontend.TORCHEXPORT: 2, TorchFrontend.EXECUTORCH: 2},
)
nargs = len(inputs)

sorted_sequence = inputs[0]
values = inputs[1]
# we will not use `out_int32`
right = inputs[3] if nargs > 3 else False
side = inputs[4] if nargs > 4 else None
# we will not use `sorter`

return sorted_sequence, values, right, side

def _parse_keyword_args(context, node, right, side) -> Tuple[Var]:
# Only torch.export may have kwargs
if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS:
return right, side

right = _get_kwinputs(context, node, "right", default=[right])[0]
side = _get_kwinputs(context, node, "side", default=[side])[0]

return right, side

def _translate_torch_args(right, side) -> Tuple[Var]:
if side is not None:
if isinstance(side, Var):
side = side.val
# The `side` parameter is preferred than `right` in torch.
right = side == "right"
else:
# If side is not specified, use the `right` parameter to determine.
if right is None:
right = False
else:
if isinstance(right, Var):
right = right.val
return right

sorted_sequence, values, right, side = _parse_positional_args(context, node)
right, side = _parse_keyword_args(context, node, right, side)
right = _translate_torch_args(right, side)

if sorted_sequence.rank != values.rank:
raise NotImplementedError(
Expand Down Expand Up @@ -8399,9 +8487,28 @@ def searchsorted(context, node):

@register_torch_op
def one_hot(context, node):
inputs = _get_inputs(context, node, expected=2)
labels = inputs[0]
num_classes = inputs[1].val
def _parse_positional_args(context, node) -> Tuple[Var]:
inputs = _get_inputs(context, node, expected=(1, 2))
nargs = len(inputs)

labels = inputs[0]
num_classes = inputs[1] if nargs > 1 else -1

return labels, num_classes

def _parse_keyword_args(context, node, num_classes) -> Var:
# Only torch.export may have kwargs
if context.frontend not in TORCH_EXPORT_BASED_FRONTENDS:
return num_classes

num_classes = _get_kwinputs(context, node, "num_classes", default=[num_classes])[0]

return num_classes

labels, num_classes = _parse_positional_args(context, node)
num_classes = _parse_keyword_args(context, node, num_classes)
if isinstance(num_classes, Var):
num_classes = num_classes.val

res = mb.one_hot(indices=labels, one_hot_vector_size=num_classes, name=node.name)
context.add(res)
55 changes: 37 additions & 18 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13606,10 +13606,12 @@ def forward(self, x):

class TestNanToNum(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, nan, posinf, neginf",
itertools.product(compute_units, backends, [None, 1.0], [None, 1000.0], [None, -1000.0]),
"compute_unit, backend, frontend, nan, posinf, neginf",
itertools.product(
compute_units, backends, frontends, [None, 1.0], [None, 1000.0], [None, -1000.0]
),
)
def test_nan_to_num_const(self, compute_unit, backend, nan, posinf, neginf):
def test_nan_to_num_const(self, compute_unit, backend, frontend, nan, posinf, neginf):
class TestModel(nn.Module):
def forward(self, x):
input_data = torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14])
Expand All @@ -13618,15 +13620,16 @@ def forward(self, x):
self.run_compare_torch(
(2, 3),
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
)

@pytest.mark.parametrize(
"compute_unit, backend, nan, posinf, neginf",
itertools.product(compute_units, backends, [None, 1.0], [1000.0], [-1000.0]),
"compute_unit, backend, frontend, nan, posinf, neginf",
itertools.product(compute_units, backends, frontends, [None, 1.0], [1000.0], [-1000.0]),
)
def test_nan_to_num_non_const(self, compute_unit, backend, nan, posinf, neginf):
def test_nan_to_num_non_const(self, compute_unit, backend, frontend, nan, posinf, neginf):
class TestModel(nn.Module):
def forward(self, x):
return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
Expand All @@ -13635,6 +13638,7 @@ def forward(self, x):
self.run_compare_torch(
input_data,
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
Expand All @@ -13643,30 +13647,37 @@ def forward(self, x):

class TestCumprod(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, axis",
itertools.product(compute_units, backends, [0, 1, 2, -1]),
"compute_unit, backend, frontend, axis",
itertools.product(compute_units, backends, frontends, [0, 1, 2, -1]),
)
def test_cumprod(self, compute_unit, backend, axis):
def test_cumprod(self, compute_unit, backend, frontend, axis):
if frontend == TorchFrontend.EXECUTORCH:
pytest.skip("torch._ops.aten.cumprod.default is not Aten Canonical")

class TestModel(nn.Module):
def forward(self, x):
return torch.cumprod(x, axis)

self.run_compare_torch(
(2, 3, 4),
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
)


class TestSearchsorted(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, side",
itertools.product(compute_units, backends, [None, "left", "right"]),
"compute_unit, backend, frontend, side",
itertools.product(compute_units, backends, frontends, [None, "left", "right"]),
)
def test_searchsorted_basic(self, compute_unit, backend, side):
def test_searchsorted_basic(self, compute_unit, backend, frontend, side):
"""This is the test case same as PyTorch doc for `torch.searchsorted`."""

if frontend == TorchFrontend.EXECUTORCH:
pytest.skip("torch._ops.aten.searchsorted.Tensor is not Aten Canonical")

class TestModel(nn.Module):
def forward(self, input_data, values):
return torch.searchsorted(input_data, values, side=side)
Expand All @@ -13676,16 +13687,22 @@ def forward(self, input_data, values):
self.run_compare_torch(
(input_data, values),
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
)

@pytest.mark.parametrize(
"compute_unit, backend, values_shape, side",
itertools.product(compute_units, backends, [(2, 1), (2, 10)], [None, "left", "right"]),
"compute_unit, backend, frontend, values_shape, side",
itertools.product(
compute_units, backends, frontends, [(2, 1), (2, 10)], [None, "left", "right"]
),
)
def test_searchsorted_stress(self, compute_unit, backend, values_shape, side):
def test_searchsorted_stress(self, compute_unit, backend, frontend, values_shape, side):
if frontend == TorchFrontend.EXECUTORCH:
pytest.skip("torch._ops.aten.searchsorted.Tensor is not Aten Canonical")

class TestModel(nn.Module):
def forward(self, input_data, values):
return torch.searchsorted(input_data, values, side=side)
Expand All @@ -13695,6 +13712,7 @@ def forward(self, input_data, values):
self.run_compare_torch(
(input_data, values),
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
Expand All @@ -13703,16 +13721,17 @@ def forward(self, input_data, values):

class TestOneHot(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, num_classes, rank",
itertools.product(compute_units, backends, range(1, 5), range(1, 5)),
"compute_unit, backend, frontend, num_classes, rank",
itertools.product(compute_units, backends, frontends, range(1, 5), range(1, 5)),
)
def test_one_hot(self, compute_unit, backend, num_classes, rank):
def test_one_hot(self, compute_unit, backend, frontend, num_classes, rank):
model = ModuleWrapper(function=torch.nn.functional.one_hot, kwargs={"num_classes": num_classes}).eval()
shape = torch.randint(1, 10, (rank,)).tolist()
labels = torch.randint(0, num_classes, shape)
self.run_compare_torch(
torch.LongTensor(labels),
model,
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
Expand Down

0 comments on commit 32ebc55

Please sign in to comment.