From 5e0b8397f482d7f0b2ace73aa1cc1d39a94dcb80 Mon Sep 17 00:00:00 2001
From: Ramesh Kunasi
Date: Tue, 24 Dec 2024 11:00:26 +0000
Subject: [PATCH 1/2] Add 16-bit support for unpack, transpose and comparison operators

---
 tensorflow/lite/micro/kernels/comparisons.cc | 13 +++++++++++++
 tensorflow/lite/micro/kernels/transpose.cc   |  8 +++++++-
 tensorflow/lite/micro/kernels/unpack.cc      |  3 +++
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/micro/kernels/comparisons.cc b/tensorflow/lite/micro/kernels/comparisons.cc
index 69b3c61c32d..e5ef8ead3c4 100644
--- a/tensorflow/lite/micro/kernels/comparisons.cc
+++ b/tensorflow/lite/micro/kernels/comparisons.cc
@@ -286,6 +286,19 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                 output_data);
       break;
+    case kTfLiteInt16:
+      requires_broadcast
+          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
+                data->params, input1_shape,
+                tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
+                tflite::micro::GetTensorData<int16_t>(input2), output_shape,
+                output_data)
+          : reference_ops::GreaterWithScaling(
+                data->params, input1_shape,
+                tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
+                tflite::micro::GetTensorData<int16_t>(input2), output_shape,
+                output_data);
+      break;
     default:
       MicroPrintf("Type %s (%d) not supported.",
                   TfLiteTypeGetName(input1->type), input1->type);
diff --git a/tensorflow/lite/micro/kernels/transpose.cc b/tensorflow/lite/micro/kernels/transpose.cc
index fd17e893937..af8819f478e 100644
--- a/tensorflow/lite/micro/kernels/transpose.cc
+++ b/tensorflow/lite/micro/kernels/transpose.cc
@@ -103,10 +103,16 @@ TfLiteStatus TransposeEval(TfLiteContext* context, TfLiteNode* node) {
                                tflite::micro::GetTensorShape(output),
                                tflite::micro::GetTensorData<int8_t>(output));
       break;
+    case kTfLiteInt16:
+      reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
+                               tflite::micro::GetTensorData<int16_t>(input),
+                               tflite::micro::GetTensorShape(output),
+                               tflite::micro::GetTensorData<int16_t>(output));
+      break;
     default:
       MicroPrintf(
           "Type %s is currently not supported by Transpose. "
" - "Only float32 and int8 is supported", + "Only float32, int8 and int16 is supported", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/micro/kernels/unpack.cc b/tensorflow/lite/micro/kernels/unpack.cc index 9ce168384a4..bc868e35a37 100644 --- a/tensorflow/lite/micro/kernels/unpack.cc +++ b/tensorflow/lite/micro/kernels/unpack.cc @@ -89,6 +89,9 @@ TfLiteStatus UnpackEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: { return UnpackImpl(context, node, input, data->num, data->axis); } + case kTfLiteInt16: { + return UnpackImpl(context, node, input, data->num, data->axis); + } default: { MicroPrintf("Type '%s' is not supported by unpack.", TfLiteTypeGetName(input->type)); From 611d82bdad691440a2528d0a2c2e5730b39ec958 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 24 Dec 2024 11:38:21 +0000 Subject: [PATCH 2/2] Added support for FC16x8 per channel quantization --- .../lite/micro/kernels/fully_connected.cc | 102 +++++++++++++++--- .../micro/kernels/fully_connected_common.cc | 11 +- 2 files changed, 94 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 6902728043f..2f117aa4efb 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -238,25 +238,95 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt16: { switch (filter->type) { case kTfLiteInt8: { - tflite::reference_integer_ops::FullyConnected( - FullyConnectedParamsQuantized(data), - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), + if (bias == nullptr || bias->type == kTfLiteInt32) { + data.is_per_channel + ? 
From 611d82bdad691440a2528d0a2c2e5730b39ec958 Mon Sep 17 00:00:00 2001
From: Ramesh Kunasi
Date: Tue, 24 Dec 2024 11:38:21 +0000
Subject: [PATCH 2/2] Add support for FC 16x8 per-channel quantization

---
 .../lite/micro/kernels/fully_connected.cc     | 102 +++++++++++++++---
 .../micro/kernels/fully_connected_common.cc   |  11 +-
 2 files changed, 94 insertions(+), 19 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index 6902728043f..2f117aa4efb 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -238,25 +238,95 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt16: {
       switch (filter->type) {
         case kTfLiteInt8: {
-          tflite::reference_integer_ops::FullyConnected(
-              FullyConnectedParamsQuantized(data),
-              tflite::micro::GetTensorShape(input),
-              tflite::micro::GetTensorData<int16_t>(input),
-              tflite::micro::GetTensorShape(filter),
+          if (bias == nullptr || bias->type == kTfLiteInt32) {
+            data.is_per_channel
+                ? tflite::reference_integer_ops::FullyConnectedPerChannel(
+                      FullyConnectedParamsQuantized(data),
+                      data.per_channel_output_multiplier,
+                      reinterpret_cast<const int*>(data.per_channel_output_shift),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
 #ifdef USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(micro_context, filter,
-                                                   weights_comp_td,
-                                                   data.weights_scratch_index),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(
-                  micro_context, bias, bias_comp_td, data.bias_scratch_index),
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
#else   // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(filter),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(bias),
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(bias),
 #endif  // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorShape(output),
-              tflite::micro::GetTensorData<int16_t>(output));
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output))
+                : tflite::reference_integer_ops::FullyConnected(
+                      FullyConnectedParamsQuantized(data),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output));
+          } else if (bias->type == kTfLiteInt64) {
+            data.is_per_channel
+                ? tflite::reference_integer_ops::FullyConnectedPerChannel(
+                      FullyConnectedParamsQuantized(data),
+                      data.per_channel_output_multiplier,
+                      reinterpret_cast<const int*>(data.per_channel_output_shift),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output))
+                : tflite::reference_integer_ops::FullyConnected(
+                      FullyConnectedParamsQuantized(data),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output));
+          }
           break;
         }
         default: {
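
Note: data.is_per_channel selects FullyConnectedPerChannel, which takes one
output multiplier and shift per output channel instead of the single pair
carried by FullyConnectedParamsQuantized. The sketch below is conceptual only
(the reference kernel uses fixed-point multiplier/shift arithmetic, not
doubles); it shows what those per-channel arrays buy: each channel's
accumulator is rescaled by its own effective scale before saturating to int16.

  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Conceptual model: effective_scale[c] = input_scale * filter_scale[c] /
  // output_scale. Per-tensor quantization shares one scale across channels;
  // per-channel quantization gives every output channel its own.
  std::vector<int16_t> RequantizePerChannel(
      const std::vector<int64_t>& acc,  // one raw accumulator per channel
      const std::vector<double>& effective_scale) {
    std::vector<int16_t> out(acc.size());
    for (size_t c = 0; c < acc.size(); ++c) {
      // TFLite int16 activations are symmetric, so the zero point is 0.
      double v = std::round(static_cast<double>(acc[c]) * effective_scale[c]);
      v = std::min(v, 32767.0);
      v = std::max(v, -32768.0);
      out[c] = static_cast<int16_t>(v);
    }
    return out;
  }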
diff --git a/tensorflow/lite/micro/kernels/fully_connected_common.cc b/tensorflow/lite/micro/kernels/fully_connected_common.cc
index 53709d366bf..858a6ac6f6f 100644
--- a/tensorflow/lite/micro/kernels/fully_connected_common.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected_common.cc
@@ -95,9 +95,14 @@ TfLiteStatus CalculateOpDataFullyConnected(
             filter->quantization.params);
     const int per_channel_quantization_size = affine_quantization->scale->size;
 
-    // Currently only Int8 is supported for per channel quantization.
-    TF_LITE_ENSURE(context,
-                   input->type == kTfLiteInt8 && filter->type != kTfLiteInt4);
+    // Currently only Int8/Int16 are supported for per channel quantization.
+    TF_LITE_ENSURE(
+        context,
+        (input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) ||
+            (input->type == kTfLiteInt16 && filter->type != kTfLiteInt4));
+
+    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                      per_channel_quantization_size);
 
     TF_LITE_ENSURE_EQ(
         context, per_channel_quantization_size,
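
Note: the per-channel checks above read the filter's TfLiteAffineQuantization
metadata, where a per-channel-quantized fully-connected filter carries one
scale per output channel. An illustrative helper (not part of the patch)
stating that invariant:

  #include "tensorflow/lite/c/common.h"

  // True when `filter` carries per-channel scales: one scale per output
  // channel, i.e. per row of the [units, input_depth] filter tensor.
  bool HasPerChannelScales(const TfLiteTensor* filter) {
    if (filter->quantization.type != kTfLiteAffineQuantization) return false;
    const auto* aq = static_cast<const TfLiteAffineQuantization*>(
        filter->quantization.params);
    return aq != nullptr && aq->scale != nullptr &&
           aq->scale->size == filter->dims->data[0];
  }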