Skip to content

Commit

Permalink
feat: add compression metadata flatbuffer schema and tests
Browse files Browse the repository at this point in the history
Add a flatbuffer schema for describing compressed models.
Flatbuffers with this schema are to be used as the value in a
.tflite model flatbuffer metadata field, and contain the extra
information necessary to describe a compressed model.

Include tests to ensure basic functionality and demonstrate
integration with C++, Python, and Bazel.

BUG=tensorflow#2636
  • Loading branch information
rkuester committed Nov 15, 2024
1 parent 8ff28fa commit 16e101e
Show file tree
Hide file tree
Showing 8 changed files with 782 additions and 0 deletions.
97 changes: 97 additions & 0 deletions tensorflow/lite/micro/compression/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
load(
"//tensorflow/lite/micro:build_def.bzl",
"tflm_cc_library",
"tflm_cc_test",
)
load(
"@flatbuffers//:build_defs.bzl",
"flatbuffer_cc_library",
"flatbuffer_py_library",
)
load("@rules_python//python:defs.bzl", "py_test")
load("@tflm_pip_deps//:requirements.bzl", "requirement")

# Every target in this package is visible to all other packages by default.
package(default_visibility = ["//visibility:public"])

# Produces the header-only C++ library "metadata_generated.h" from the schema;
# that header is what C++ code uses to read the metadata flatbuffer.
flatbuffer_cc_library(
    name = "metadata_cc",
    srcs = ["metadata.fbs"],
)

# Exposes "metadata_saved.h": a copy of the header-only library that flatc
# generates in ":metadata_cc", saved to the source tree and committed to git.
#
# Why a saved copy exists: code built via the Make build system has no means
# of generating the header on the fly. Code which builds via both bazel and
# Make should #include the saved header and list this target in its bazel
# BUILD deps. Code built exclusively via bazel would typically depend directly
# on ":metadata_cc" instead, which generates the header from the schema during
# the build.
#
# Keeping it current: whenever the schema definition "metadata.fbs" changes,
# run the script "./metadata_saved_update.sh" outside of bazel (bazel cannot
# modify the source tree). The script regenerates the header from the schema
# and copies it to the source tree as "metadata_saved.h".
#
# Committing a generated file risks the schema and the saved header drifting
# apart, so their consistency is ensured by the unit test
# ":metadata_saved_test".
tflm_cc_library(
    name = "metadata_saved",
    hdrs = ["metadata_saved.h"],
)

# Guards the consistency between the schema and the saved generated header.
# The test fails when they mismatch; in that case run
# ./metadata_saved_update.sh to refresh the saved header. See ":metadata_saved".
sh_test(
    name = "metadata_saved_test",
    size = "small",
    srcs = ["metadata_saved_test.sh"],
    # The script receives the paths of the two files as its arguments.
    args = [
        "$(location metadata_saved.h)",
        "$(location :metadata_cc_srcs)",
    ],
    data = [
        "metadata_saved.h",
        ":metadata_cc_srcs",
    ],
)

# C++ test of the metadata flatbuffer. It depends on the saved header
# (":metadata_saved") rather than on the generated ":metadata_cc" target.
tflm_cc_test(
    name = "metadata_test_cc",
    size = "small",
    srcs = ["metadata_test.cc"],
    deps = [
        ":metadata_saved",
        "//tensorflow/lite/micro:hexdump",
        "//tensorflow/lite/micro/testing:micro_test",
        "@flatbuffers//:runtime_cc",
    ],
)

# Produces the Python module "metadata_py_generated" from the schema; that
# module is what Python code uses to read the metadata flatbuffer.
flatbuffer_py_library(
    name = "metadata_py",
    srcs = ["metadata.fbs"],
)

# Python test of the metadata flatbuffer, run against the module generated by
# ":metadata_py".
py_test(
    name = "metadata_test_py",
    size = "small",
    srcs = ["metadata_test.py"],
    # Named explicitly because the target name differs from the source file's
    # base name.
    main = "metadata_test.py",
    deps = [
        # Written with the leading colon to match every other same-package
        # target reference in this file; resolves to the same target as the
        # bare "metadata_py".
        ":metadata_py",
        "@flatbuffers//:runtime_py",
        requirement("hexdump"),
    ],
)
56 changes: 56 additions & 0 deletions tensorflow/lite/micro/compression/metadata.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace tflite.micro.compression;

// Compression data root, to be used in a tflite.Model.metadata field with the
// key "COMPRESSION_METADATA".
table Metadata {
  // Incremented whenever there are backward-incompatible changes. Code should
  // accept models with versions less than or equal to the version for which
  // the code is built. I.e., code should accept older models, but not
  // necessarily newer ones.
  schema_version:int = 1;

  // Compression data indexed by subgraph index.
  subgraphs:[Subgraph];
}

// Per-subgraph compression metadata.
table Subgraph {
  // The tensors of this subgraph which are compressed using the
  // (L)ook-(U)p-(T)able method. The indices of this vector are not
  // significant.
  lut_tensors:[LutTensor];
}

// Look-Up-Table Tensor: a tensor representation where elements are compressed
// into indices into a table of values.
//
// Indices: unsigned integers, index_bitwidth-wide, in big-endian bit order,
// packed into the buffer identified by the corresponding tflite.Tensor's
// buffer field.
//
// Values: located in a newly-created buffer, encoded according to the
// tflite.Tensor.type. Tensors with multiple channels have distinct value
// tables for each channel, typically along their quantization axis,
// concatenated one after another. An element's index must be looked up in the
// value table corresponding to its channel.
table LutTensor {
  // Index of the corresponding tflite.Tensor.
  tensor:int;

  // Index of the buffer containing LUT values.
  value_buffer:uint;

  // Bit-width of LUT indexes.
  index_bitwidth:uint8;
}

root_type Metadata;
Loading

0 comments on commit 16e101e

Please sign in to comment.