Skip casting model inputs to fp32 if weights and inputs are all fp16 #2274
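
In short: when a traced TorchScript model has all of its trainable weights in fp16 and all of its inputs are fp16, the converter no longer inserts a "cast inputs to fp32" op at the function boundary. A minimal sketch of the user scenario this targets (the model and shapes are illustrative, not taken from the PR):

    import numpy as np
    import torch
    import coremltools as ct

    # An all-fp16 model, e.g. produced by an explicit .half() call.
    model = torch.nn.Linear(10, 4).eval().half()
    traced = torch.jit.trace(model, torch.rand(1, 10).half())

    # With fp16 weights and fp16 inputs, the new logic keeps the inputs
    # fp16 instead of first casting them to fp32.
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(shape=(1, 10), dtype=np.float16)],
        outputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.macOS13,
    )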

coremltools/converters/mil/frontend/torch/converter.py (22 additions, 1 deletion)

@@ -585,7 +585,23 @@ def __init__(
        self.opset_version = _target(opset_version) if opset_version is not None else None
        self._prog = mil.Program()

        self.src_model_has_all_fp16_weights = False

        if isinstance(loaded_model, torch.jit.ScriptModule):
            # src_model_has_all_fp16_weights is True when the model has at
            # least one trainable parameter and every trainable parameter
            # has the fp16 dtype, e.g. when pytorch_model.half() has been
            # called explicitly.
            num_trainable_params = 0
            num_trainable_fp16_params = 0
            for param in loaded_model.parameters():
                if param.requires_grad:
                    num_trainable_params += 1
                    if param.dtype == torch.float16:
                        num_trainable_fp16_params += 1
            if num_trainable_params > 0:
                self.src_model_has_all_fp16_weights = (
                    num_trainable_params == num_trainable_fp16_params
                )

            self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
            self.graph = InternalTorchIRGraph.from_torchscript(
                torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
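
For reference, the detection above reduces to a standalone predicate (a sketch; has_all_fp16_trainable_weights is a hypothetical name, not part of the PR):

    import torch

    def has_all_fp16_trainable_weights(module: torch.nn.Module) -> bool:
        # True only if there is at least one trainable parameter and every
        # trainable parameter is stored as fp16.
        trainable = [p for p in module.parameters() if p.requires_grad]
        return bool(trainable) and all(p.dtype == torch.float16 for p in trainable)

    # e.g. has_all_fp16_trainable_weights(torch.nn.Linear(10, 4).half())  -> True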

@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
        user_names = list(ssa_func_inputs.keys())
        internal_names = list(self.graph.inputs.keys())
        internal_names.extend(user_names[len(internal_names) :])
        # True only when every model input is fp16; used below to decide
        # whether the "cast input to fp32" step can be skipped.
        all_fp16_inputs = all(
            ssa_func.inputs[ssa_name].dtype == types.fp16 for ssa_name in user_names
        )
        for torch_name, ssa_name in zip(internal_names, user_names):
            input_var = ssa_func.inputs[ssa_name]
            if self.context.frontend == TorchFrontend.TORCHSCRIPT:

@@ -1272,7 +1293,7 @@
                # So here we perform the "cast input to fp32" step, unless
                # the weights and all of the inputs are already fp16.
                if (
                    (types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type))
                    and input_var.dtype == types.fp16
                    and not (all_fp16_inputs and self.src_model_has_all_fp16_weights)
                ):
                    # This cast should have placeholder scope
                    with mb.scope(
                        ScopeInfo(

@@ -1522,6 +1522,30 @@ def forward(self, x, y):
            result[name], expected.detach().numpy(), rtol=rtol, atol=atol
        )

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_torch_fp16_model_with_fp16_inputs(torch_model, backend):
        if backend[0] == "neuralnetwork":
            pytest.skip(
                "Float16 input requires target >= iOS16, which does not support the neuralnetwork backend."
            )
        traced_torch_model = torch.jit.trace(torch_model.half(), torch.rand(1, 10).half())
        ct.convert(
            traced_torch_model,
            source="pytorch",
            inputs=[
                ct.TensorType(
                    shape=(1, 10),
                )
            ],
            outputs=[ct.TensorType(dtype=np.float16)],
            convert_to=backend[0],
            minimum_deployment_target=ct.target.macOS13,
        )
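
To check that the cast was actually skipped, one could inspect the resulting MIL program (a sketch, not part of the PR; _mil_program is a coremltools-internal attribute and get_op_types_in_program is a test utility, so both may change between releases):

    import numpy as np
    import torch
    import coremltools as ct
    from coremltools.converters.mil.testing_utils import get_op_types_in_program

    model = torch.nn.Linear(10, 4).eval().half()
    traced = torch.jit.trace(model, torch.rand(1, 10).half())
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(shape=(1, 10), dtype=np.float16)],
        outputs=[ct.TensorType(dtype=np.float16)],
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS13,
    )

    # With all-fp16 weights and inputs, no fp32 cast should be needed at the
    # function inputs; the op list can be inspected to confirm.
    print(get_op_types_in_program(mlmodel._mil_program))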


@pytest.fixture
def int32_input_model():