forked from tensorflow/tflite-micro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add compression metadata flatbuffer schema and tests
Add a flatbuffer schema for describing compressed models. Flatbuffers with this schema are to be used as the value in a .tflite model flatbuffer metadata field, and contain the extra information necessary to describe a compressed model. Include tests to ensure basic functionality and demonstrate integration with C++, Python, and Bazel. BUG=tensorflow#2636
- Loading branch information
Showing
8 changed files
with
782 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
load( | ||
"//tensorflow/lite/micro:build_def.bzl", | ||
"tflm_cc_library", | ||
"tflm_cc_test", | ||
) | ||
load( | ||
"@flatbuffers//:build_defs.bzl", | ||
"flatbuffer_cc_library", | ||
"flatbuffer_py_library", | ||
) | ||
load("@rules_python//python:defs.bzl", "py_test") | ||
load("@tflm_pip_deps//:requirements.bzl", "requirement") | ||
|
||
package( | ||
default_visibility = [ | ||
"//visibility:public", | ||
], | ||
) | ||
|
||
flatbuffer_cc_library( | ||
# Generates the header-only library "metadata_generated.h", used to read | ||
# the metadata flatbuffer. | ||
name = "metadata_cc", | ||
srcs = ["metadata.fbs"], | ||
) | ||
|
||
tflm_cc_library( | ||
# The header-only library generated by flatc in ":metadata_cc" is saved to | ||
# the source tree and committed to git as "metadata_saved.h", which is used | ||
# by code which builds via the Make build system, which has no means of | ||
# generating the header on the fly. Code which builds via both bazel and | ||
# Make should #include the saved header and use this target in its bazel | ||
# BUILD deps. Code built exclusively via bazel would typically depend | ||
# directly on ":metadata_cc", which would generate a header from the schema | ||
# on the fly, during the build. | ||
# | ||
# When the schema definition "metadata.fbs" is changed, this saved header | ||
# should be updated by running the script "./metadata_saved_update.sh", | ||
# outside of bazel (because bazel cannot modify the source tree). The | ||
# script regenerates the header from the schema and copies it to the source | ||
# tree as "metadata_saved.h". | ||
# | ||
# Committing the generated file risks inconsistency between the schema and | ||
# the saved header, so consistency is ensured by the unit test | ||
# ":metadata_saved_test". | ||
# | ||
name = "metadata_saved", | ||
hdrs = ["metadata_saved.h"], | ||
) | ||
|
||
sh_test( | ||
# Ensures consistency between the schema and the saved generated header. | ||
# Fails if they mismatch, in which case, ./metadata_saved_update.sh should | ||
# be run. See :metadata_saved above. | ||
name = "metadata_saved_test", | ||
size = "small", | ||
srcs = ["metadata_saved_test.sh"], | ||
args = [ | ||
"$(location metadata_saved.h)", | ||
"$(location :metadata_cc_srcs)", | ||
], | ||
data = [ | ||
"metadata_saved.h", | ||
":metadata_cc_srcs", | ||
], | ||
) | ||
|
||
tflm_cc_test( | ||
name = "metadata_test_cc", | ||
size = "small", | ||
srcs = ["metadata_test.cc"], | ||
deps = [ | ||
":metadata_saved", | ||
"//tensorflow/lite/micro:hexdump", | ||
"//tensorflow/lite/micro/testing:micro_test", | ||
"@flatbuffers//:runtime_cc", | ||
], | ||
) | ||
|
||
flatbuffer_py_library( | ||
# Generates the Python module "metadata_py_generated", used to read the | ||
# metadata flatbuffer. | ||
name = "metadata_py", | ||
srcs = ["metadata.fbs"], | ||
) | ||
|
||
py_test( | ||
name = "metadata_test_py", | ||
size = "small", | ||
srcs = ["metadata_test.py"], | ||
main = "metadata_test.py", | ||
deps = [ | ||
"metadata_py", | ||
"@flatbuffers//:runtime_py", | ||
requirement("hexdump"), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
namespace tflite.micro.compression; | ||
|
||
table Metadata { | ||
// Compression data root, to be used in a tflite.Model.metadata field with | ||
// the key "COMPRESSION_METADATA". | ||
|
||
schema_version:int = 1; | ||
// ^ Incremented whenever there are backward-incompatible changes. Code | ||
// should accept models with versions less than or equal to the version | ||
// for which the code is built. I.e., code should accept older models, | ||
// but not necessarily newer ones. | ||
|
||
subgraphs:[Subgraph]; | ||
// ^ Compression data indexed by subgraph index. | ||
} | ||
|
||
table Subgraph { | ||
// Per-subgraph compression metadata. | ||
|
||
lut_tensors:[LutTensor]; | ||
// ^ A list of tensors which are compressed using the | ||
// (L)ook-(U)p-(T)able method. The indices of this vector are not | ||
// significant. | ||
} | ||
|
||
table LutTensor { | ||
// Look-Up-Table Tensor: a tensor representation where elements are | ||
// compressed into indices into a table of values. The indices are unsigned | ||
// integers, index_bitwidth-wide, in big-endian bit order, packed into the | ||
// buffer identified by the corresponding tflite.Tensor's buffer field. The | ||
// values are located in a newly-created buffer, encoded according to the | ||
// tflite.Tensor.type. Tensors with multiple channels have distinct value | ||
// tables for each channel, typically along their quantization axis, | ||
// concatenated one after another. An element's index must be looked up in | ||
// the value table corresponding to its channel. | ||
|
||
tensor:int; // index of the corresponding tflite.Tensor | ||
value_buffer:uint; // index of the buffer containing LUT values | ||
index_bitwidth:uint8; // bit-width of LUT indexes | ||
} | ||
|
||
root_type Metadata; |
Oops, something went wrong.