From 9b4e1064541c8fee85e19a8b51aceea3fa1426bb Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sun, 27 Nov 2022 01:10:50 +0000 Subject: [PATCH 1/2] switch to fp16 as the go-to float dtype (i.e. how to infer a python float literal, or how to convert ints to floats), to simplify the conversion process (rather than starting in fp32, spraying fp16 casts everywhere, and trying to remove them during conversion). --- .../mil/frontend/torch/converter.py | 5 +-- .../converters/mil/frontend/torch/ops.py | 38 ++++++++++--------- .../mil/mil/ops/defs/iOS15/control_flow.py | 8 ++-- .../mil/passes/apply_common_pass_pipeline.py | 4 +- .../passes/elementwise_batchnorm_fusion.py | 4 +- .../converters/mil/mil/types/type_mapping.py | 10 ++--- 6 files changed, 34 insertions(+), 35 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py index 27cf9ada2..4796d5bf2 100644 --- a/coremltools/converters/mil/frontend/torch/converter.py +++ b/coremltools/converters/mil/frontend/torch/converter.py @@ -232,6 +232,8 @@ def convert_const(self): raise ValueError("unsupported class for {} in PyTorch graph: {}".format(name, type(val))) if val.dtype == _np.uint8: val = val.astype(_np.int32) + if val.dtype == _np.float32: + val = val.astype(_np.float16) const = mb.const(val=val, name=name) self.context.add(const) @@ -263,8 +265,6 @@ def convert(self): self.graph.inputs.keys(), ssa_func_inputs.keys() ): input_var = ssa_func.inputs[users_name] - if (types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)) \ - and (input_var.dtype == types.fp16 or input_var.dtype == types.fp64): # cast the input var to float32 # We need to do this because the type inference is very buggy when started from # float16/float64 typed inputs. Until that is fixed in the following radar @@ -272,7 +272,6 @@ def convert(self): # These casts will later get removed, if compute_precision=Float16 is # provided, which will cause the FP16ComputePrecision pass to run. # TODO: remove this when this radar is fixed: rdar://93731970 - input_var = mb.cast(x=input_var, dtype="fp32") self.context.add(input_var, torch_name=internal_name) self.convert_const() diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 370f4401a..45dcbaea5 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -241,6 +241,8 @@ def _construct_constant(val, name): if val is None: return None else: + if isinstance(val, float): + return mb.const(val=val, name=name) return mb.const(val=val, name=name) @@ -454,7 +456,7 @@ def norm(context, node): def _vector_norm(x, order, dim, keep_dims, name): if order.val == 0: # sum(x!=0) - x = mb.cast(x=x, dtype="fp32") + x = mb.cast(x=x, dtype="fp16") temp = mb.not_equal(x=x, y=0.) temp = mb.cast(x=temp, dtype='int32') temp = mb.reduce_sum(x=temp, axes=dim, keep_dims=keep_dims, name=name) @@ -1070,8 +1072,8 @@ def linspace(context, node): start = inputs[0] end = inputs[1] nums = inputs[2] - start = mb.cast(x=start, dtype="fp32") - end = mb.cast(x=end, dtype="fp32") + start = mb.cast(x=start, dtype="fp16") + end = mb.cast(x=end, dtype="fp16") if start.can_be_folded_to_const() and end.can_be_folded_to_const() and nums.can_be_folded_to_const(): start_val = start.val @@ -1092,8 +1094,8 @@ def linspace(context, node): # step = (end - start) / (nums - 1) x = mb.sub(x=end, y=start) y = mb.sub(x=nums, y=1) - x = mb.cast(x=x, dtype="fp32") - y = mb.cast(x=y, dtype="fp32") + x = mb.cast(x=x, dtype="fp16") + y = mb.cast(x=y, dtype="fp16") step = mb.real_div(x=x, y=y) # Note that the range_1d op excluded the end point, @@ -1351,8 +1353,8 @@ def div(context, node): # e.g.: # values before trunc: [2.6, -3.4, -3.6] # values after trunc: [2, -3, -3] - x = mb.cast(x=inputs[0], dtype="fp32") - y = mb.cast(x=inputs[1], dtype="fp32") + x = mb.cast(x=inputs[0], dtype="fp16") + y = mb.cast(x=inputs[1], dtype="fp16") z = mb.real_div(x=x, y=y) s = mb.sign(x=z) all_positive = mb.mul(x=z, y=s) @@ -1363,8 +1365,8 @@ def div(context, node): 'rounding mode "{}" not supported in the "div" op'.format(rounding_mode) ) else: - x = mb.cast(x=inputs[0], dtype="fp32") - y = mb.cast(x=inputs[1], dtype="fp32") + x = mb.cast(x=inputs[0], dtype="fp16") + y = mb.cast(x=inputs[1], dtype="fp16") res = mb.real_div(x=x, y=y, name=node.name) context.add(res) @@ -1374,7 +1376,7 @@ def floor_divide(context, node): inputs = _get_inputs(context, node, expected=2) div_res = mb.floor_div(x=inputs[0], y=inputs[1]) # Pytorch's floor_divide always returns fp32, even if the inputs are int - res = mb.cast(x=div_res, dtype='fp32', name=node.name) + res = mb.cast(x=div_res, dtype='fp16', name=node.name) context.add(res) @@ -1440,7 +1442,7 @@ def mean(context, node): if types.is_bool(x.dtype): # TODO: In the future when MIL op supports bool, we need to use curr_opset_version to decide # if we want to cast or not. - x = mb.cast(x=x, dtype="fp32") + x = mb.cast(x=x, dtype="fp16") kwargs = {"x": x, "name": node.name} # @axes is optional, so omit if None. @@ -1790,7 +1792,7 @@ def group_norm(context, node): x = mb.reshape(x=x, shape=new_shape) mean = mb.reduce_mean(x=x, axes=axes_, keep_dims=True) - var = _std(x,axes_,True,False,eps.val) + var = _std(x,axes_,True,False,eps.val.astype(_np.dtype(f'float{x.dtype.get_bitwidth()}'))) x = mb.sub(x=x,y=mean) x = mb.real_div(x=x,y=var) x = mb.reshape(x=x, shape=input_shape) @@ -3593,8 +3595,8 @@ def new_full(context, node): @register_torch_op def randint(context, node): inputs = _get_inputs(context, node, expected=8) - low = mb.cast(x=inputs[0], dtype="fp32") - high = mb.cast(x=inputs[1], dtype="fp32") + low = mb.cast(x=inputs[0], dtype="fp16") + high = mb.cast(x=inputs[1], dtype="fp16") shape = inputs[2] rand_uniform = mb.random_uniform(shape=shape, low=low, high=high) rand_int = mb.cast(x=rand_uniform, dtype="int32", name=node.name) @@ -4067,9 +4069,9 @@ def arange(context, node): int_step = isinstance(step, int) or types.is_int(step.dtype) if int_start != int_end or int_start != int_step: - start = mb.cast(x=start, dtype="fp32") - end = mb.cast(x=end, dtype="fp32") - step = mb.cast(x=step, dtype="fp32") + start = mb.cast(x=start, dtype="fp16") + end = mb.cast(x=end, dtype="fp16") + step = mb.cast(x=step, dtype="fp16") res = mb.range_1d(start=start, end=end, step=step, name=node.name) context.add(res) @@ -4086,7 +4088,7 @@ def masked_fill(context, node): if types.is_int(value.dtype): # @mb.fill cannot handle value with dtype integer # so we cast the value. - value = mb.cast(x=value, dtype="fp32") + value = mb.cast(x=value, dtype="fp16") if not types.is_bool(mask.dtype): # cond must be bool type diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py b/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py index 73552f6f2..84c61241f 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py @@ -160,7 +160,7 @@ def value_inference(self): def _get_type_val(self, value): if isinstance(value, (float, np.float64)): - value = np.float32(value) + value = np.float16(value) elif isinstance(value, bool): pass elif isinstance(value, (int, np.int64)): @@ -176,11 +176,11 @@ def _get_type_val(self, value): value = value.astype(np.int32) - # For the float type, we use float32 by default + # For the float type, we use float16 by default elif value.dtype == np.float64: - msg = "Downcast const op {} data fp64 as fp32".format(self.name) + msg = "Downcast const op {} data fp64 as fp16".format(self.name) logger.debug(msg) - value = value.astype(np.float32) + value = value.astype(np.float16) elif isinstance(value, mil_list): # if val that was passed in is of type mil_list, which is just a wrapper on top of python list diff --git a/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py b/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py index 1cebc5c8d..f3fe28d8c 100644 --- a/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py +++ b/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py @@ -51,7 +51,7 @@ def _apply(passes, name="common"): "common::noop_elimination", "common::fuse_matmul_weight_bias", "common::fuse_linear_bias", - "common::fuse_gelu_tanh_approximation", + # "common::fuse_gelu_tanh_approximation", "common::fuse_gelu_exact", "common::fuse_leaky_relu", "common::rank0_expand_dims_swap", @@ -73,7 +73,7 @@ def _apply(passes, name="common"): "common::fuse_conv_batchnorm", # In some cases, we need to run conv / batch_norm fusion again after the fuse_conv_scale and fuse_conv_bias passes "common::detect_concat_interleave", "common::concat_to_pixel_shuffle", # should come after detect_concat_interleave and after replace_stack_reshape - "common::fuse_prelu", # reduce_transpose pass should run before and after this pass (the one after will be run during the cleanup passes stage) + # "common::fuse_prelu", # reduce_transpose pass should run before and after this pass (the one after will be run during the cleanup passes stage) "common::prelu_to_lrelu", "common::merge_consecutive_relus", # "remove_redundant_ops" pass should be applied towards the end, once other graph passes have done their optimizations. diff --git a/coremltools/converters/mil/mil/passes/elementwise_batchnorm_fusion.py b/coremltools/converters/mil/mil/passes/elementwise_batchnorm_fusion.py index 7a67269cf..404d2a5d5 100644 --- a/coremltools/converters/mil/mil/passes/elementwise_batchnorm_fusion.py +++ b/coremltools/converters/mil/mil/passes/elementwise_batchnorm_fusion.py @@ -73,8 +73,8 @@ def _try_to_transform(mul_op, add_op, block): out_name = add_op.outputs[0].name x = mb.batch_norm( x=non_const_input_mul, - mean=np.zeros((C,), np.float32), - variance=np.ones((C,), np.float32), + mean=np.zeros((C,), np.float16), + variance=np.ones((C,), np.float16), gamma=np.squeeze(gamma), beta=np.squeeze(beta), name=out_name, diff --git a/coremltools/converters/mil/mil/types/type_mapping.py b/coremltools/converters/mil/mil/types/type_mapping.py index e686b8bed..5f8681030 100644 --- a/coremltools/converters/mil/mil/types/type_mapping.py +++ b/coremltools/converters/mil/mil/types/type_mapping.py @@ -102,7 +102,7 @@ def np_dtype_to_py_type(np_dtype): return int if np_dtype in [bool, np.bool_]: return bool - if np_dtype in [np.float32, np.float64]: + if np_dtype in [np.float16, np.float32, np.float64]: return float if np_dtype in [np.complex64, np.complex128]: return complex @@ -304,11 +304,9 @@ def numpy_type_to_builtin_type(nptype): elif np.issubclass_(nptype, np.object_): # symbolic shape is considered int32 return types_int32 - elif np.issubclass_(nptype, np.float16): + elif np.issubclass_(nptype, (np.float16, np.half)) or nptype == float: return types_fp16 - elif ( - np.issubclass_(nptype, (np.float32, np.single)) or nptype == float - ): + elif np.issubclass_(nptype, (np.float32, np.single)): return types_fp32 elif np.issubclass_(nptype, (np.float64, np.double)): return types_fp64 @@ -337,7 +335,7 @@ def type_to_builtin_type(type): elif np.issubclass_(type, str): return types_str elif np.issubclass_(type, float): - return types_fp32 + return types_fp16 elif np.issubclass_(type, complex): return types_complex64 else: From 0fa03a0c4c72f786068e4c8d8583e236fedf1b04 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sun, 12 Mar 2023 00:34:02 +0000 Subject: [PATCH 2/2] upsample_nearest_neighbor: cast float16 scale factors to a supported dtype --- coremltools/converters/mil/frontend/torch/ops.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 45dcbaea5..fa0ac485e 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -2949,6 +2949,19 @@ def upsample_nearest2d(context, node): else: raise ValueError("Failed to infer scale factors from inputs.") + # CoreML only supports upsampling int32 or float32. + # pixel-doubling is a common use-case, so prefer integer if it looks close enough could be the intention. + if scales_h.dtype == _np.float16: + scales_h_int32 = scales_h.astype(_np.int32) + scales_h = scales_h_int32 if ( + _np.allclose(scales_h, scales_h_int32) + ) else scales_h.astype(_np.float32) + if scales_w.dtype == _np.float16: + scales_w_int32 = scales_w.astype(_np.int32) + scales_w = scales_w_int32 if ( + _np.allclose(scales_w, scales_w_int32) + ) else scales_w.astype(_np.float32) + upsample_nearest2d = mb.upsample_nearest_neighbor( x=_input, scale_factor_height=scales_h,