Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compression): implement tensor decompression in op concatenation #3014

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 62 additions & 50 deletions tensorflow/lite/micro/kernels/concatenation.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;

struct OpData {
ConcatenationParams params;

#ifdef USE_TFLM_COMPRESSION

// scratch buffers for compressed tensors
int scratch_indices[kMaxInputNum];

#endif // USE_TFLM_COMPRESSION
};

// Handles negative axis index, coerces to positive index value.
Expand All @@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
inline void GetAllInputTensorShapes(const TfLiteContext* context,
const TfLiteNode* node,
RuntimeShape all_shapes[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
RuntimeShape shape = tflite::micro::GetTensorShape(t);
Expand All @@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
template <typename T>
inline void GetAllInputTensorData(const TfLiteContext* context,
const TfLiteNode* node,
T* all_data[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
const T* all_data[kMaxInputNum]) {
#ifdef USE_TFLM_COMPRESSION
const OpData* data = static_cast<const OpData*>(node->user_data);
MicroContext* micro_context = GetMicroContext(context);
#endif // USE_TFLM_COMPRESSION

for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
#ifdef USE_TFLM_COMPRESSION
const CompressionTensorData* comp_td =
micro_context->GetTensorCompressionData(node, i);
all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
data->scratch_indices[i]);
#else // USE_TFLM_COMPRESSION
all_data[i] = tflite::micro::GetTensorData<T>(t);
#endif // USE_TFLM_COMPRESSION
}
}

Expand All @@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape inputs_shape[kMaxInputNum];
const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
const data_type* inputs_data[kMaxInputNum];
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);
GetAllInputTensorShapes(context, node, inputs_shape);
GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
GetAllInputTensorData(context, node, inputs_data);

TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);

reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<data_type>(output));
Expand Down Expand Up @@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteType output_type = output_tensor->type;

micro_context->DeallocateTempTfLiteTensor(input_tensor);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

// Check activation and input type
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
Expand All @@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt64 || input_type == kTfLiteBool);

// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, output_type, input_type);

// This implementation does not support large number of input tensors
const int num_inputs = NumInputs(node);
TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);

// Shapes with dimensions >4 are not yet supported with static allocation.
// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

// Shapes with dimensions > kMaxSmallSize are not yet supported with static
// allocation.
for (int i = 0; i < num_inputs; ++i) {
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, input_type);
int num_dimensions = NumDimensions(input);

if (num_dimensions > RuntimeShape::kMaxSmallSize) {
Expand All @@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}

if (input_type == kTfLiteInt8) {
// Make sure there is no re-scaling needed for Int8 quantized kernel. This
// is a restriction we introduced to Int8 kernels.
TF_LITE_ENSURE_EQ(context, static_cast<double>(input->params.scale),
static_cast<double>(output_tensor->params.scale));
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output_tensor->params.zero_point);
} else if (input_type == kTfLiteInt16) {
// Make sure that all Int16 inputs have a null zero-point.
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why the scale is not checked in the int16 case? AFAICS the (at least the reference) concatenation kernel does not take into account any scaling/quantization at all.

inline void Concatenation(const ConcatenationParams& params,

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your conclusion, however this is a copy-paste of the TfLite code.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, also is there are reason why it has to be zero? Seems that concatenation in general shouldn't care about quantization parameters, as long as they match in and out.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tinskip This same check is performed by the LiteRT (TfLite) reference implementation.


#ifdef USE_TFLM_COMPRESSION

// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
data->scratch_indices[i] =
micro_context->AllocateDecompressionScratchBuffer(node, i);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(input);
}

// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output_tensor->params.zero_point, 0);
}

switch (output_type) { // Already know in/outtypes are same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt32:
case kTfLiteInt64: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.inputs_count = node->inputs->size;
break;
}
case kTfLiteInt8: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.axis = CalculatePositiveAxis(params->axis, output_tensor);
data->params.inputs_count = node->inputs->size;

float* input_scales =
reinterpret_cast<float*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(float)));

int32_t* input_zero_points =
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(int32_t)));

// Allocate persistent scale and zeropoint buffers.
// Store input scale and zero point values in OpParams:
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, t != nullptr);
input_scales[i] = t->params.scale;
input_zero_points[i] = t->params.zero_point;
micro_context->DeallocateTempTfLiteTensor(t);
}

data->params.input_scale = input_scales;
data->params.input_zeropoint = input_zero_points;
data->params.output_zeropoint = output->params.zero_point;
data->params.output_scale = output->params.scale;
break;
suleshahid marked this conversation as resolved.
Show resolved Hide resolved
}
default:
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
MicroPrintf("Op Concatenation does not currently support type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

return kTfLiteOk;
}
Expand Down
Loading
Loading