[Illustrative; not for merge] How to prefer float16 as the main float type #1802

Open · wants to merge 2 commits into base: main
5 changes: 2 additions & 3 deletions coremltools/converters/mil/frontend/torch/converter.py
@@ -232,6 +232,8 @@ def convert_const(self):
raise ValueError("unsupported class for {} in PyTorch graph: {}".format(name, type(val)))
if val.dtype == _np.uint8:
val = val.astype(_np.int32)
if val.dtype == _np.float32:
val = val.astype(_np.float16)
const = mb.const(val=val, name=name)
self.context.add(const)
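For context on what this const downcast does, here is a minimal NumPy-only sketch of the same astype chain (uint8 → int32, float32 → float16) and the float16 caveats it inherits; `downcast_const` is a hypothetical stand-in, not a coremltools function:

```python
import numpy as np

# Minimal sketch of the cast chain in convert_const above (not coremltools code):
# uint8 -> int32, float32 -> float16.
def downcast_const(val: np.ndarray) -> np.ndarray:
    if val.dtype == np.uint8:
        val = val.astype(np.int32)
    if val.dtype == np.float32:
        val = val.astype(np.float16)
    return val

weights = np.array([1e-8, 0.1, 3.14159, 70000.0], dtype=np.float32)
converted = downcast_const(weights)
print(converted.dtype)  # float16
print(converted)        # roughly [0., 0.1, 3.14, inf]: the tiny value flushes to zero
                        # and 70000 overflows float16's max finite value of 65504
```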

@@ -263,16 +265,13 @@ def convert(self):
self.graph.inputs.keys(), ssa_func_inputs.keys()
):
input_var = ssa_func.inputs[users_name]
if (types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)) \
and (input_var.dtype == types.fp16 or input_var.dtype == types.fp64):
# cast the input var to float32
# We need to do this because the type inference is very buggy when started from
# float16/float64 typed inputs. Until that is fixed in the following radar
# we cast all inputs of type float16/float64 to float32 as the first step.
# These casts will later get removed, if compute_precision=Float16 is
# provided, which will cause the FP16ComputePrecision pass to run.
# TODO: remove this when this radar is fixed: rdar://93731970
input_var = mb.cast(x=input_var, dtype="fp32")
self.context.add(input_var, torch_name=internal_name)

self.convert_const()
51 changes: 33 additions & 18 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -241,6 +241,8 @@ def _construct_constant(val, name):
if val is None:
return None
else:
if isinstance(val, float):
return mb.const(val=val, name=name)
return mb.const(val=val, name=name)


@@ -454,7 +456,7 @@ def norm(context, node):
def _vector_norm(x, order, dim, keep_dims, name):
if order.val == 0:
# sum(x!=0)
x = mb.cast(x=x, dtype="fp32")
x = mb.cast(x=x, dtype="fp16")
temp = mb.not_equal(x=x, y=0.)
temp = mb.cast(x=temp, dtype='int32')
temp = mb.reduce_sum(x=temp, axes=dim, keep_dims=keep_dims, name=name)
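For reference, the order-0 branch above is just a nonzero count; a minimal NumPy analogue of the cast → not_equal → reduce_sum chain (`vector_norm_l0` is a hypothetical helper, not MIL code):

```python
import numpy as np

# Minimal NumPy analogue of the order == 0 branch above (not MIL code):
# count the nonzero entries along the reduction axes.
def vector_norm_l0(x: np.ndarray, dim=None, keep_dims=False) -> np.ndarray:
    x = x.astype(np.float16)               # mirrors mb.cast(x=x, dtype="fp16")
    nonzero = (x != 0.0).astype(np.int32)  # mirrors not_equal + cast to int32
    # Note: values smaller than ~6e-8 flush to zero in float16, so the count can
    # differ from the old float32 cast for extremely small inputs.
    return nonzero.sum(axis=dim, keepdims=keep_dims)

x = np.array([[0.0, 2.5, 0.0], [1.0, 0.0, 3.0]], dtype=np.float32)
print(vector_norm_l0(x, dim=1))  # [1 2]
```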
@@ -1070,8 +1072,8 @@ def linspace(context, node):
start = inputs[0]
end = inputs[1]
nums = inputs[2]
start = mb.cast(x=start, dtype="fp32")
end = mb.cast(x=end, dtype="fp32")
start = mb.cast(x=start, dtype="fp16")
end = mb.cast(x=end, dtype="fp16")

if start.can_be_folded_to_const() and end.can_be_folded_to_const() and nums.can_be_folded_to_const():
start_val = start.val
@@ -1092,8 +1094,8 @@
# step = (end - start) / (nums - 1)
x = mb.sub(x=end, y=start)
y = mb.sub(x=nums, y=1)
x = mb.cast(x=x, dtype="fp32")
y = mb.cast(x=y, dtype="fp32")
x = mb.cast(x=x, dtype="fp16")
y = mb.cast(x=y, dtype="fp16")
step = mb.real_div(x=x, y=y)

# Note that the range_1d op excluded the end point,
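A minimal NumPy sketch of the non-const path above, assuming the step formula `(end - start) / (nums - 1)` is now evaluated in float16; `linspace_fp16` is a hypothetical stand-in that also shows the range caveat of the narrower dtype:

```python
import numpy as np

# Minimal NumPy sketch of the non-const linspace path (not MIL code), with the
# step computed in float16 as the hunk now does.
def linspace_fp16(start: float, end: float, nums: int) -> np.ndarray:
    start16 = np.float16(start)
    end16 = np.float16(end)
    step = (end16 - start16) / np.float16(nums - 1)
    return start16 + step * np.arange(nums, dtype=np.float16)

print(linspace_fp16(0.0, 1.0, 5))  # [0.  0.25 0.5  0.75 1.  ]
# Caveat: float16 tops out at 65504, so a wide range degenerates --
# np.float16(1e5) is inf, the step becomes inf, and the first element is nan.
print(linspace_fp16(0.0, 1e5, 5))
```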
@@ -1351,8 +1353,8 @@ def div(context, node):
# e.g.:
# values before trunc: [2.6, -3.4, -3.6]
# values after trunc: [2, -3, -3]
x = mb.cast(x=inputs[0], dtype="fp32")
y = mb.cast(x=inputs[1], dtype="fp32")
x = mb.cast(x=inputs[0], dtype="fp16")
y = mb.cast(x=inputs[1], dtype="fp16")
z = mb.real_div(x=x, y=y)
s = mb.sign(x=z)
all_positive = mb.mul(x=z, y=s)
@@ -1363,8 +1365,8 @@
'rounding mode "{}" not supported in the "div" op'.format(rounding_mode)
)
else:
x = mb.cast(x=inputs[0], dtype="fp32")
y = mb.cast(x=inputs[1], dtype="fp32")
x = mb.cast(x=inputs[0], dtype="fp16")
y = mb.cast(x=inputs[1], dtype="fp16")
res = mb.real_div(x=x, y=y, name=node.name)
context.add(res)
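The "trunc" branch above emulates truncation with `real_div`, `sign`, abs-via-multiply, and `floor`; a minimal NumPy version of that identity, reproducing the worked example in the comment (`div_trunc` is a hypothetical helper, not MIL code):

```python
import numpy as np

# Minimal NumPy version of the rounding_mode == "trunc" branch (not MIL code):
# trunc(x / y) == sign(z) * floor(|z|), with z = x / y.
def div_trunc(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    z = x.astype(np.float16) / y.astype(np.float16)  # mirrors real_div after the fp16 casts
    s = np.sign(z)
    return s * np.floor(z * s)  # floor of the magnitude, then restore the sign

z = np.array([2.6, -3.4, -3.6], dtype=np.float32)
print(div_trunc(z, np.ones_like(z)))  # [ 2. -3. -3.], matching the comment's example
```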

@@ -1374,7 +1376,7 @@ def floor_divide(context, node):
inputs = _get_inputs(context, node, expected=2)
div_res = mb.floor_div(x=inputs[0], y=inputs[1])
# Pytorch's floor_divide always returns fp32, even if the inputs are int
res = mb.cast(x=div_res, dtype='fp32', name=node.name)
res = mb.cast(x=div_res, dtype='fp16', name=node.name)
context.add(res)


@@ -1440,7 +1442,7 @@ def mean(context, node):
if types.is_bool(x.dtype):
# TODO: In the future when MIL op supports bool, we need to use curr_opset_version to decide
# if we want to cast or not.
x = mb.cast(x=x, dtype="fp32")
x = mb.cast(x=x, dtype="fp16")
kwargs = {"x": x, "name": node.name}

# @axes is optional, so omit if None.
@@ -1790,7 +1792,7 @@ def group_norm(context, node):

x = mb.reshape(x=x, shape=new_shape)
mean = mb.reduce_mean(x=x, axes=axes_, keep_dims=True)
var = _std(x,axes_,True,False,eps.val)
var = _std(x,axes_,True,False,eps.val.astype(_np.dtype(f'float{x.dtype.get_bitwidth()}')))
x = mb.sub(x=x,y=mean)
x = mb.real_div(x=x,y=var)
x = mb.reshape(x=x, shape=input_shape)
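The `eps` change above builds a NumPy dtype string from the tensor's bit width so the epsilon constant matches `x`'s float precision; a minimal sketch of that trick in plain NumPy (`match_eps_to_bitwidth` is a hypothetical helper; the diff itself calls `x.dtype.get_bitwidth()` on the MIL type):

```python
import numpy as np

# Minimal sketch of the eps-dtype matching trick above (not coremltools code):
# build a NumPy float dtype whose bit width matches the tensor's, so the epsilon
# constant does not silently promote a float16 graph back to float32.
def match_eps_to_bitwidth(eps: float, bitwidth: int) -> np.floating:
    return np.dtype(f"float{bitwidth}").type(eps)  # e.g. bitwidth 16 -> np.float16(eps)

print(match_eps_to_bitwidth(1e-5, 16).dtype)  # float16
print(match_eps_to_bitwidth(1e-5, 32).dtype)  # float32
```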
@@ -2947,6 +2949,19 @@ def upsample_nearest2d(context, node):
else:
raise ValueError("Failed to infer scale factors from inputs.")

# CoreML only supports upsampling with int32 or float32 scale factors.
# Pixel doubling is a common use case, so prefer an integer scale when the value is
# close enough to an integer that this was likely the intention.
if scales_h.dtype == _np.float16:
scales_h_int32 = scales_h.astype(_np.int32)
scales_h = scales_h_int32 if (
_np.allclose(scales_h, scales_h_int32)
) else scales_h.astype(_np.float32)
if scales_w.dtype == _np.float16:
scales_w_int32 = scales_w.astype(_np.int32)
scales_w = scales_w_int32 if (
_np.allclose(scales_w, scales_w_int32)
) else scales_w.astype(_np.float32)

upsample_nearest2d = mb.upsample_nearest_neighbor(
x=_input,
scale_factor_height=scales_h,
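A standalone sketch of the scale-factor normalization added in this hunk, assuming the same allclose-to-int heuristic; `normalize_scale` is a hypothetical helper, not coremltools code:

```python
import numpy as np

# Illustrative sketch of the scale-factor normalization added above (not coremltools code).
# A float16 scale is promoted to int32 when it is effectively an integer (the common
# pixel-doubling case), otherwise to float32, since those are the dtypes the op accepts.
def normalize_scale(scale):
    if scale.dtype == np.float16:
        as_int = scale.astype(np.int32)  # truncates, so 2.5 -> 2 and allclose fails below
        return as_int if np.allclose(scale, as_int) else scale.astype(np.float32)
    return scale

print(normalize_scale(np.float16(2.0)))  # 2    (int32, pixel doubling)
print(normalize_scale(np.float16(2.5)))  # 2.5  (float32 fallback)
```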
@@ -3593,8 +3608,8 @@ def new_full(context, node):
@register_torch_op
def randint(context, node):
inputs = _get_inputs(context, node, expected=8)
low = mb.cast(x=inputs[0], dtype="fp32")
high = mb.cast(x=inputs[1], dtype="fp32")
low = mb.cast(x=inputs[0], dtype="fp16")
high = mb.cast(x=inputs[1], dtype="fp16")
shape = inputs[2]
rand_uniform = mb.random_uniform(shape=shape, low=low, high=high)
rand_int = mb.cast(x=rand_uniform, dtype="int32", name=node.name)
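One caveat worth noting with the float16 casts here: float16 has a 10-bit mantissa, so integer bounds above 2048 are no longer exactly representable, and anything above 65504 overflows. A quick NumPy illustration:

```python
import numpy as np

# Float16 integer precision check relevant to the low/high casts above.
for bound in (2048, 4097, 4099):
    print(bound, "->", int(np.float16(bound)))
# 2048 -> 2048
# 4097 -> 4096  (rounded to the nearest representable float16)
# 4099 -> 4100
```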
@@ -4067,9 +4082,9 @@ def arange(context, node):
int_step = isinstance(step, int) or types.is_int(step.dtype)

if int_start != int_end or int_start != int_step:
start = mb.cast(x=start, dtype="fp32")
end = mb.cast(x=end, dtype="fp32")
step = mb.cast(x=step, dtype="fp32")
start = mb.cast(x=start, dtype="fp16")
end = mb.cast(x=end, dtype="fp16")
step = mb.cast(x=step, dtype="fp16")
res = mb.range_1d(start=start, end=end, step=step, name=node.name)
context.add(res)

@@ -4086,7 +4101,7 @@ def masked_fill(context, node):
if types.is_int(value.dtype):
# @mb.fill cannot handle value with dtype integer
# so we cast the value.
value = mb.cast(x=value, dtype="fp32")
value = mb.cast(x=value, dtype="fp16")

if not types.is_bool(mask.dtype):
# cond must be bool type
8 changes: 4 additions & 4 deletions coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py
@@ -160,7 +160,7 @@ def value_inference(self):
def _get_type_val(self, value):

if isinstance(value, (float, np.float64)):
value = np.float32(value)
value = np.float16(value)
elif isinstance(value, bool):
pass
elif isinstance(value, (int, np.int64)):
@@ -176,11 +176,11 @@ def _get_type_val(self, value):
value = value.astype(np.int32)


# For the float type, we use float32 by default
# For the float type, we use float16 by default
elif value.dtype == np.float64:
msg = "Downcast const op {} data fp64 as fp32".format(self.name)
msg = "Downcast const op {} data fp64 as fp16".format(self.name)
logger.debug(msg)
value = value.astype(np.float32)
value = value.astype(np.float16)

elif isinstance(value, mil_list):
# if val that was passed in is of type mil_list, which is just a wrapper on top of python list
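With this control_flow.py change, a plain Python float (or an fp64 array) handed to `const` resolves to float16 rather than float32. A simplified, illustrative sketch of just the float handling in `_get_type_val` (`resolve_float_value` is a hypothetical stand-in; the real method also covers bool, int, list, and mil_list values):

```python
import numpy as np

# Simplified, illustrative version of the float handling in _get_type_val above.
def resolve_float_value(value):
    if isinstance(value, (float, np.float64)):
        return np.float16(value)             # Python/np float scalar -> fp16
    if isinstance(value, np.ndarray) and value.dtype == np.float64:
        return value.astype(np.float16)      # fp64 array downcast to fp16
    return value

print(resolve_float_value(3.14159).dtype)    # float16
print(resolve_float_value(np.ones(3)).dtype) # float16 (np.ones defaults to float64)
```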
@@ -51,7 +51,7 @@ def _apply(passes, name="common"):
"common::noop_elimination",
"common::fuse_matmul_weight_bias",
"common::fuse_linear_bias",
"common::fuse_gelu_tanh_approximation",
# "common::fuse_gelu_tanh_approximation",
"common::fuse_gelu_exact",
"common::fuse_leaky_relu",
"common::rank0_expand_dims_swap",
@@ -73,7 +73,7 @@
"common::fuse_conv_batchnorm", # In some cases, we need to run conv / batch_norm fusion again after the fuse_conv_scale and fuse_conv_bias passes
"common::detect_concat_interleave",
"common::concat_to_pixel_shuffle", # should come after detect_concat_interleave and after replace_stack_reshape
"common::fuse_prelu", # reduce_transpose pass should run before and after this pass (the one after will be run during the cleanup passes stage)
# "common::fuse_prelu", # reduce_transpose pass should run before and after this pass (the one after will be run during the cleanup passes stage)
"common::prelu_to_lrelu",
"common::merge_consecutive_relus",
# "remove_redundant_ops" pass should be applied towards the end, once other graph passes have done their optimizations.
@@ -73,8 +73,8 @@ def _try_to_transform(mul_op, add_op, block):
out_name = add_op.outputs[0].name
x = mb.batch_norm(
x=non_const_input_mul,
mean=np.zeros((C,), np.float32),
variance=np.ones((C,), np.float32),
mean=np.zeros((C,), np.float16),
variance=np.ones((C,), np.float16),
gamma=np.squeeze(gamma),
beta=np.squeeze(beta),
name=out_name,
10 changes: 4 additions & 6 deletions coremltools/converters/mil/mil/types/type_mapping.py
@@ -102,7 +102,7 @@ def np_dtype_to_py_type(np_dtype):
return int
if np_dtype in [bool, np.bool_]:
return bool
if np_dtype in [np.float32, np.float64]:
if np_dtype in [np.float16, np.float32, np.float64]:
return float
if np_dtype in [np.complex64, np.complex128]:
return complex
@@ -304,11 +304,9 @@ def numpy_type_to_builtin_type(nptype):
elif np.issubclass_(nptype, np.object_):
# symbolic shape is considered int32
return types_int32
elif np.issubclass_(nptype, np.float16):
elif np.issubclass_(nptype, (np.float16, np.half)) or nptype == float:
return types_fp16
elif (
np.issubclass_(nptype, (np.float32, np.single)) or nptype == float
):
elif np.issubclass_(nptype, (np.float32, np.single)):
return types_fp32
elif np.issubclass_(nptype, (np.float64, np.double)):
return types_fp64
@@ -337,7 +335,7 @@ def type_to_builtin_type(type):
elif np.issubclass_(type, str):
return types_str
elif np.issubclass_(type, float):
return types_fp32
return types_fp16
elif np.issubclass_(type, complex):
return types_complex64
else:
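To summarize the type_mapping.py changes: the Python built-in `float` now resolves to fp16, while explicit np.float32/np.float64 inputs keep resolving to fp32/fp64. A hedged recap of the float branches, with string names standing in for the `types_fp16`/`types_fp32`/`types_fp64` builtins (`float_builtin_for` is a hypothetical sketch, not the library function):

```python
import numpy as np

# Hedged recap of the float branches after this diff; the strings stand in for the
# types_fp16 / types_fp32 / types_fp64 builtins (this is not the library function).
def float_builtin_for(nptype) -> str:
    if nptype is float or issubclass(nptype, np.float16):
        return "fp16"  # the Python built-in float now lands here instead of fp32
    if issubclass(nptype, np.float32):
        return "fp32"
    if issubclass(nptype, np.float64):
        return "fp64"
    raise TypeError(f"unsupported float type: {nptype}")

print(float_builtin_for(float), float_builtin_for(np.float32), float_builtin_for(np.float64))
# fp16 fp32 fp64
```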