[Docs] Update Torch.Export Guides (#2401)
* update doc to reflect new torch.export features in coremltools 8.1

* torch.export openelm issue has been fixed in torch 2.5

---------

Co-authored-by: yifan_shen3 <[email protected]>
YifanShenSZ and yifan_shen3 authored Nov 21, 2024
1 parent ff2fef4 commit be29fb9
Showing 5 changed files with 46 additions and 68 deletions.
80 changes: 28 additions & 52 deletions docs-guides/source/convert-a-torchvision-model-from-pytorch.md
@@ -135,14 +135,9 @@ mlmodel_from_trace = ct.convert(
classifier_config = ct.ClassifierConfig(class_labels),
compute_units=ct.ComputeUnit.CPU_ONLY,
)
```

As of Core ML Tools 8.1, the torch.export path also supports image input, passed through the `inputs` parameter:
```python
# Using image_input in the inputs parameter:
# Convert to Core ML using the Unified Conversion API.
mlmodel_from_export = ct.convert(
exported_program,
inputs=[image_input],
classifier_config = ct.ClassifierConfig(class_labels),
compute_units=ct.ComputeUnit.CPU_ONLY,
)
@@ -152,7 +147,7 @@ Save the ML program using the `.mlpackage` extension. It may also be helpful to

```python
# Save the converted model.
mlmodel_from_trace.save("mobilenet.mlpackage")
mlmodel_from_export.save("mobilenet.mlpackage")
# Print a confirmation message.
print("model converted and saved")
```
@@ -254,7 +249,7 @@ To get the fields and types used in the model, get the protobuf spec with [get_s

```python
# Get the protobuf spec of the model.
spec = mlmodel_from_trace.get_spec()
spec = mlmodel_from_export.get_spec()
for out in spec.description.output:
if out.type.WhichOneof("Type") == "dictionaryType":
coreml_dict_name = out.name
@@ -266,60 +261,41 @@ You can now make a prediction with the converted model, using the test image. To

```python
# Make a prediction with the Core ML version of the model.
coreml_out_dict = mlmodel_from_trace.predict({"x": img})
print("coreml predictions: ")
print("top class label: ", coreml_out_dict["classLabel"])

coreml_prob_dict = coreml_out_dict[coreml_dict_name]

values_vector = np.array(list(coreml_prob_dict.values()))
keys_vector = list(coreml_prob_dict.keys())
top_3_indices_coreml = np.argsort(-values_vector)[:3]
for i in range(3):
idx = top_3_indices_coreml[i]
score_value = values_vector[idx]
class_id = keys_vector[idx]
print("class name: {}, raw score value: {}".format(class_id, score_value))
def predict_with_coreml(mlmodel):
coreml_out_dict = mlmodel.predict({"x": img})
print("top class label: ", coreml_out_dict["classLabel"])

coreml_prob_dict = coreml_out_dict[coreml_dict_name]

values_vector = np.array(list(coreml_prob_dict.values()))
keys_vector = list(coreml_prob_dict.keys())
top_3_indices_coreml = np.argsort(-values_vector)[:3]
for i in range(3):
idx = top_3_indices_coreml[i]
score_value = values_vector[idx]
class_id = keys_vector[idx]
print("class name: {}, raw score value: {}".format(class_id, score_value))

print("coreml (converted from torch.jit.trace) predictions: ")
predict_with_coreml(mlmodel_from_trace)

print("coreml (converted from torch.export) predictions: ")
predict_with_coreml(mlmodel_from_export)
```

When you run this example, the output should be something like the following, using the image of a daisy as the input:

```text Output
coreml predictions:
coreml (converted from torch.jit.trace) predictions:
top class label: daisy
class name: daisy, raw score value: 15.8046875
class name: vase, raw score value: 8.4921875
class name: ant, raw score value: 8.2109375
```
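The top-3 selection logic inside `predict_with_coreml` can be sanity-checked in isolation with plain NumPy. The probability dict below is hypothetical, standing in for the Core ML output dictionary:

```python
import numpy as np

# Hypothetical class-probability dict, standing in for
# coreml_out_dict[coreml_dict_name] from the prediction above.
coreml_prob_dict = {"daisy": 15.80, "vase": 8.49, "ant": 8.21, "bee": 1.02}

values_vector = np.array(list(coreml_prob_dict.values()))
keys_vector = list(coreml_prob_dict.keys())
# argsort on the negated scores yields indices in descending score order.
top_3_indices = np.argsort(-values_vector)[:3]
top_3_labels = [keys_vector[i] for i in top_3_indices]
print(top_3_labels)  # → ['daisy', 'vase', 'ant']
```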

The model converted from torch.export needs to use the same image tensor as the torch model:

```python
# Make a prediction with the Core ML version of the model.
coreml_out_dict = mlmodel_from_export.predict({"x": img_torch.detach().numpy()})
print("coreml predictions: ")
print("top class label: ", coreml_out_dict["classLabel"])

coreml_prob_dict = coreml_out_dict[coreml_dict_name]

values_vector = np.array(list(coreml_prob_dict.values()))
keys_vector = list(coreml_prob_dict.keys())
top_3_indices_coreml = np.argsort(-values_vector)[:3]
for i in range(3):
idx = top_3_indices_coreml[i]
score_value = values_vector[idx]
class_id = keys_vector[idx]
print("class name: {}, raw score value: {}".format(class_id, score_value))
```

When you run this example, the output should be something like the following, using the image of a daisy as the input:

```text Output
coreml predictions:
coreml (converted from torch.export) predictions:
top class label: daisy
class name: daisy, raw score value: 15.7890625
class name: vase, raw score value: 8.546875
class name: ant, raw score value: 8.34375
class name: daisy, raw score value: 15.8046875
class name: vase, raw score value: 8.4921875
class name: ant, raw score value: 8.2109375
```

As the results show, the converted model closely matches the original: the raw score values are very similar.
9 changes: 3 additions & 6 deletions docs-guides/source/convert-openelm.md
@@ -23,9 +23,9 @@ pip install coremltools

At the time of writing this example, the author's environment was:
```text Output
torch 2.4.1
transformers 4.45.1
coremltools 8.0
torch 2.5.1
transformers 4.46.3
coremltools 8.1
```

## Import Libraries and Set Up the Model
@@ -66,9 +66,6 @@ exported_program = torch.export.export(
torch_model,
(example_input_ids,),
dynamic_shapes=dynamic_shapes,
# Because of https://github.com/pytorch/pytorch/issues/133252
# we need to use strict=False until torch 2.5
strict=False,
)
```

2 changes: 1 addition & 1 deletion docs-guides/source/convert-pytorch-workflow.md
@@ -20,7 +20,7 @@ The conversion from a graph captured via `torch.jit.trace` has been supported fo

The conversion from `torch.export` graph has been newly added to Core ML Tools 8.0.
It is currently in beta state, in line with the export API status in PyTorch.
As of Core ML Tools 8.0, representative models such as MobileBert, ResNet, ViT, [MobileNet](convert-a-torchvision-model-from-pytorch), [DeepLab](convert-a-pytorch-segmentation-model), [OpenELM](convert-openelm) can be converted, and the total PyTorch op translation test coverage is roughly ~60%. You can start trying the torch.export path on your models that are working with torch.jit.trace already, so as to gradually move them to the export path as PyTorch also [moves](https://github.com/pytorch/pytorch/issues/103841#issuecomment-1605017153) its support and development to that path over a period of time. In case you hit issues (e.g. models converted via export path are slower than the ones converted from jit.trace path), please report them on Github.
As of Core ML Tools 8.0, representative models such as MobileBert, ResNet, ViT, [MobileNet](convert-a-torchvision-model-from-pytorch), [DeepLab](convert-a-pytorch-segmentation-model), [OpenELM](convert-openelm) can be converted, and the total PyTorch op translation test coverage is roughly ~70%. You can start trying the torch.export path on your models that are working with torch.jit.trace already, so as to gradually move them to the export path as PyTorch also [moves](https://github.com/pytorch/pytorch/issues/103841#issuecomment-1605017153) its support and development to that path over a period of time. In case you hit issues (e.g. models converted via export path are slower than the ones converted from jit.trace path), please report them on Github.

Now let us take a closer look at how to convert from PyTorch to Core ML through an example.

5 changes: 5 additions & 0 deletions docs-guides/source/flexible-inputs.md
@@ -95,6 +95,11 @@ Core ML preallocates the memory for the default shape, so the first prediction w
For a multi-input model, only one of the inputs can be marked with `EnumeratedShapes`; the rest must have fixed single shapes. If you require multiple inputs to be flexible, set the range for each dimension.
```

```{admonition} Torch.Export Dynamism
If the source PyTorch model is exported by [`torch.export.export`](https://pytorch.org/docs/stable/export.html#torch.export.export), then the user needs to [express dynamism in torch.export](https://pytorch.org/docs/stable/export.html#expressing-dynamism), and only the torch.export dynamic dimensions are allowed to have more than one possible size; see [Model Exporting](model-exporting).
```

```{eval-rst}
.. index::
single: RangeDim
18 changes: 9 additions & 9 deletions docs-guides/source/model-exporting.md
Expand Up @@ -13,7 +13,7 @@ The recommended way to generate ExportedProgram for your model is to use PyTorch
The conversion from `torch.export` graph has been newly added to Core ML Tools 8.0.
It is currently in beta state, in line with the export API status in PyTorch.
As of Core ML Tools 8.0, representative models such as MobileBert, ResNet, ViT, [MobileNet](convert-a-torchvision-model-from-pytorch), [DeepLab](convert-a-pytorch-segmentation-model), [OpenELM](convert-openelm) can be converted, and the total PyTorch op translation test coverage is roughly ~60%. You can start trying the torch.export path on your models that are working with torch.jit.trace already, so as to gradually move them to the export path as PyTorch also [moves](https://github.com/pytorch/pytorch/issues/103841#issuecomment-1605017153) its support and development to that path over a period of time. In case you hit issues (e.g. models converted via export path are slower than the ones converted from jit.trace path), please report them on Github.
As of Core ML Tools 8.0, representative models such as MobileBert, ResNet, ViT, [MobileNet](convert-a-torchvision-model-from-pytorch), [DeepLab](convert-a-pytorch-segmentation-model), [OpenELM](convert-openelm) can be converted, and the total PyTorch op translation test coverage is roughly ~70%. You can start trying the torch.export path on your models that are working with torch.jit.trace already, so as to gradually move them to the export path as PyTorch also [moves](https://github.com/pytorch/pytorch/issues/103841#issuecomment-1605017153) its support and development to that path over a period of time. In case you hit issues (e.g. models converted via export path are slower than the ones converted from jit.trace path), please report them on Github.
Also, torch.export has its own [limitations](https://pytorch.org/docs/stable/export.html#limitations-of-torch-export).
```
@@ -110,11 +110,11 @@ The following example builds a simple model from scratch and exports it to gener
```

## Difference from Tracing
For tracing, `ct.convert` requires the `inputs` arg from the user. This is no longer the case for exporting, since the ExportedProgram object carries all name, shape, and dtype info.

Subsequently, for exporting, dynamic shape is no longer specified through the `inputs` arg of `ct.convert`. Instead, the user needs to [express dynamism in torch.export](https://pytorch.org/docs/stable/export.html#expressing-dynamism), which is then automatically converted to Core ML [`RangeDim`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#rangedim).

As of Core ML Tools 8.0, there are several features that the torch.export conversion path does not yet support, compared to the mature torch.jit.trace path, such as:
* [EnumeratedShapes](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#enumeratedshapes)
* [ImageType](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#coremltools.converters.mil.input_types.ImageType)
* Setting custom names and dtypes for model inputs and outputs
For tracing, [`ct.convert`](https://apple.github.io/coremltools/source/coremltools.converters.convert.html#coremltools.converters._converters_entry.convert) requires the `inputs` arg from the user. This is no longer required for exporting, since the ExportedProgram object carries all name, shape, and dtype info: [`TensorType`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#tensortype), [`RangeDim`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#rangedim), and [`StateType`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#statetype) are automatically created from the ExportedProgram if `inputs` is absent. There are 3 cases where `inputs` is still necessary:
1. [`ImageType`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#coremltools.converters.mil.input_types.ImageType)
2. [`EnumeratedShapes`](https://apple.github.io/coremltools/source/coremltools.converters.mil.input_types.html#enumeratedshapes)
3. Customizing input / output names or dtypes

Another difference between tracing and exporting is how to create dynamic shapes. Torch.jit.trace simply traces the executed torch ops and does not have the concept of dynamism, so dynamic shapes are specified and propagated in `ct.convert`. Torch.export, however, [rigorously expresses dynamism](https://pytorch.org/docs/stable/export.html#expressing-dynamism), so dynamic shapes are first specified and propagated in torch.export, then handled as follows when calling `ct.convert`:
* If `RangeDim` is desired, then nothing more is needed, since it is automatically converted from [`torch.export.Dim`](https://pytorch.org/docs/stable/export.html#torch.export.dynamic_shapes.Dim)
* Else if `EnumeratedShapes` are desired, then the user needs to specify the shape enumeration in the `inputs` arg, and only the torch.export dynamic dimensions are allowed to have more than one possible size
