Skip to content

Commit

Permalink
fix(compression): use single value table for per-tensor quantized ten…
Browse files Browse the repository at this point in the history
…sors

Compress using a single value table when a tensor is per-tensor
quantized, as indicated by the presence of only one quantization
scale and zero point. Update unit tests accordingly and augment
`test_models` to accommodate additional quantization fields.

Abandon the logic that a tensor should be compressed along the
NHWC channel dimension if the quantization parameters do not
specify an axis. Instead, fail with an error if the compression
axis cannot be inferred from the quantization parameters.

The interpreter already expects a single value table when a
tensor is per-tensor quantized.

BUG=part of tensorflow#2636
  • Loading branch information
rkuester committed Dec 19, 2024
1 parent e3ac890 commit 25cd19b
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 36 deletions.
89 changes: 62 additions & 27 deletions tensorflow/lite/micro/compression/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import bitarray.util
from dataclasses import dataclass, field
import sys
from typing import ByteString, Iterable
from typing import ByteString, Iterable, Optional

import absl.app
import absl.flags
Expand Down Expand Up @@ -107,7 +107,7 @@ def _add_subgraph(self):

@dataclass
class _LutCompressedArray:
compression_axis: int = 0
compression_axis: Optional[int] = None
lookup_tables: list[np.ndarray] = field(default_factory=list)
indices: np.ndarray = field(default_factory=lambda: np.array([]))

Expand All @@ -121,27 +121,46 @@ def index_bitwidth(self) -> int:
return max_index.bit_length() or 1


def _lut_compress_array(tensor: np.ndarray,
                        axis: Optional[int]) -> _LutCompressedArray:
  """Compresses the given tensor using lookup tables.

  Compressing a tensor with a lookup table per slice along a particular axis
  is analogous to quantizing a tensor with different quantization parameters
  per slice along a particular axis (dimension).

  Args:
    tensor (np.ndarray): The tensor to be compressed.
    axis (Optional[int]): The axis along which to compress the tensor. If an
      axis is given, a lookup table is created for each slice along the
      axis. If axis is None, a single lookup table is used for the entire
      tensor.

  Returns:
    _LutCompressedArray: An object containing the compressed tensor data,
      including the lookup tables and indices.
  """
  compressed = _LutCompressedArray()
  compressed.compression_axis = axis

  if axis is None:
    # A single table of the unique values in the entire tensor. The inverse
    # indices are reshaped explicitly so that the result always has the
    # tensor's shape, whatever shape np.unique hands them back in.
    values, indices = np.unique(tensor, return_inverse=True)
    compressed.lookup_tables.append(values)
    compressed.indices = indices.reshape(tensor.shape)
    return compressed

  # Move the compression axis to the front so that iterating yields one
  # subtensor per position along that axis; each gets its own lookup table.
  per_slice_indices = []
  for subtensor in np.moveaxis(tensor, axis, 0):
    values, indices = np.unique(subtensor, return_inverse=True)
    compressed.lookup_tables.append(values)
    per_slice_indices.append(indices.reshape(subtensor.shape))

  # Reassemble the per-slice indices and move the compression axis back to
  # its original position so the indices line up with the input tensor.
  stacked = np.stack(per_slice_indices, axis=0)
  compressed.indices = np.moveaxis(stacked, 0, axis)

  return compressed

Expand All @@ -155,18 +174,34 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression:
return compression[0]


def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]:
  """Determines the axis along which to compress.

  The axis along which to compress is inferred from the tensor's quantization
  parameters: the number of quantization scales is compared with the tensor's
  shape at the quantized dimension.

  Args:
    tensor: The tensor whose `quantization` and `shape` are inspected.

  Returns:
    The axis along which to compress, or None to indicate one value table for
    the entire tensor.

  Raises:
    CompressionError: If the axis cannot be determined.
  """
  q = tensor.quantization
  if q is not None and q.scale is not None:
    if q.quantizedDimension < len(tensor.shape):
      quantization_channels = len(q.scale)

      if quantization_channels == 1:
        # A single scale means per-tensor quantization: use one value table
        # for the entire tensor.
        return None

      if quantization_channels == tensor.shape[q.quantizedDimension]:
        # One scale per slice along the quantized dimension: compress along
        # that same dimension.
        return q.quantizedDimension

  # Every case in which the axis cannot be inferred falls through to here.
  raise CompressionError(
      "Invalid or no quantization parameters from which to "
      "infer the axis along which tensor should be compressed.")


def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor):
Expand Down Expand Up @@ -204,7 +239,7 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
Pack the value tables of a LutCompressedArray into a bytes object in the
format writable to a value_table buffer in the .tflite flatbuffer. The
tables, one per subarray, are concatenated.
tables are concatenated.
"""
buffer = bytearray()
for t in tables:
Expand Down
99 changes: 95 additions & 4 deletions tensorflow/lite/micro/compression/compress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,40 +200,87 @@ def test_multiple_tables_with_padding(self):
"shape": (16, 1),
"type": tflite.TensorType.UINT8,
"buffer": 1,
"quantization": {
"quantized_dimension": 1,
"scale": (1,),
"zero_point": (0,),
},
},
1: {
"shape": (16, 1),
"type": tflite.TensorType.INT8,
"buffer": 2,
"quantization": {
"quantized_dimension": 1,
"scale": (1,),
"zero_point": (0,),
},
},
2: {
"shape": (16, 1),
"type": tflite.TensorType.INT16,
"buffer": 3,
"quantization": {
"quantized_dimension": 1,
"scale": (1,),
"zero_point": (0,),
},
},
3: {
"shape": (16, 1),
"type": tflite.TensorType.INT32,
"buffer": 4,
"quantization": {
"quantized_dimension": 1,
"scale": (1,),
"zero_point": (0,),
},
},
4: {
"shape": (16, 1),
"type": tflite.TensorType.INT32,
"buffer": 5,
"quantization": {
"quantized_dimension": 1,
"scale": (1,),
"zero_point": (0,),
},
},
5: {
"shape": (4, 5),
"type": tflite.TensorType.INT16,
"buffer": 6,
"quantization": {
"quantized_dimension": 1,
"scale": (1, 1, 1, 1, 1),
"zero_point": (0, 0, 0, 0, 0),
},
},
6: {
"shape": (5, 4),
"type": tflite.TensorType.INT16,
"buffer": 7,
"quantization": {
"quantized_dimension": 0,
"scale": (1, 1, 1, 1, 1),
"zero_point": (0, 0, 0, 0, 0),
},
},
7: {
"shape": (5, 4),
"type": tflite.TensorType.INT16,
"buffer": 8,
"quantization": {
"quantized_dimension": 0,
"scale": (1,),
"zero_point": (0,),
},
},
8: {
"shape": (16, 1),
"type": tflite.TensorType.UINT8,
"buffer": 9,
},
},
},
},
Expand All @@ -260,6 +307,14 @@ def test_multiple_tables_with_padding(self):
(9, 10, 11, 12),
(13, 14, 15, 16),
(17, 18, 19, 20)), dtype=np.dtype("<i2")),

8: np.array(((1, 2, 3, 4),
(1, 2, 3, 4),
(1, 2, 3, 4),
(1, 2, 3, 4),
(1, 2, 3, 4)), dtype=np.dtype("<i2")),

9: np.array(range(16), dtype=np.dtype("<u1")),
},
}

Expand Down Expand Up @@ -297,6 +352,11 @@ def test_multiple_tables_with_padding(self):
tensor=6,
compression=[spec.LookUpTableCompression(index_bitwidth=2)],
),
spec.Tensor( # spec 6
subgraph=0,
tensor=7,
compression=[spec.LookUpTableCompression(index_bitwidth=2)],
),
]
# yapf: enable

Expand Down Expand Up @@ -362,6 +422,18 @@ def test_invalid_tensor_spec(self):
self.assertRaises(compress.CompressionError,
lambda: compress.compress(self.flatbuffer, specs))

def test_no_axis(self):
  """Expects an error when a tensor has no quantization to infer an axis."""
  # Tensor 8 is the fixture tensor defined without a "quantization" entry.
  unquantized_specs = [
      spec.Tensor(
          subgraph=0,
          tensor=8,
          compression=[spec.LookUpTableCompression(index_bitwidth=4)],
      ),
  ]
  with self.assertRaises(compress.CompressionError):
    compress.compress(self.flatbuffer, unquantized_specs)


class TestLutCompressedArray(tf.test.TestCase):

Expand Down Expand Up @@ -519,8 +591,8 @@ def test_compressed_int32(self):
expected_values = np.array(range(-160_016, -160_000), dtype="<i4")
self.assertAllEqual(values, expected_values)

def test_channel_axis(self):
"""Compression along the NHWC channel axis when no quantization axis."""
def test_axis_1(self):
"""Compression along quantized_dimension == 1."""
bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=5)
self.assertEqual(bitwidth, 2)

Expand All @@ -537,8 +609,8 @@ def test_channel_axis(self):
expected_values = np.array(range(1, 21), dtype=np.dtype("<i2"))
self.assertAllEqual(values, expected_values)

def test_quantization_axis(self):
"""Compression along the quantization axis."""
def test_axis_0(self):
"""Compression along quantized_dimension == 0."""
bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=6)
self.assertEqual(bitwidth, 2)

Expand All @@ -556,6 +628,25 @@ def test_quantization_axis(self):
expected_values = np.array(range(1, 21), dtype=np.dtype("<i2"))
self.assertAllEqual(values, expected_values)

def test_per_tensor(self):
  """Checks compression that uses a single value table for the tensor."""
  width, packed_indices, table = self._get_compressed(subgraph=0, tensor=7)
  self.assertEqual(width, 2)

  # The literal below is kept exactly as written; its lines are not indented
  # so that the string passed to _make_indices is unchanged.
  # yapf: disable
  expected_indices = self._make_indices("""
00 01 10 11
00 01 10 11
00 01 10 11
00 01 10 11
00 01 10 11
""")
  # yapf: enable
  self.assertEqual(packed_indices, expected_indices)

  # One shared table holding the values 1 through 4.
  expected_values = np.arange(1, 5, dtype=np.dtype("<i2"))
  self.assertAllEqual(table, expected_values)


if __name__ == "__main__":
tf.test.main()
12 changes: 7 additions & 5 deletions tensorflow/lite/micro/compression/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,14 @@ def build(model_definition: dict) -> bytearray:
tensor_t.type = tensor["type"]
tensor_t.buffer = tensor["buffer"]

try:
d = tensor["quantization"]["quantized_dimension"]
if "quantization" in tensor:
tensor_t.quantization = tflite.QuantizationParametersT()
tensor_t.quantization.quantizedDimension = d
except KeyError:
tensor_t.quantization = None
tensor_t.quantization.quantizedDimension = \
tensor["quantization"].get("quantized_dimension", None)
tensor_t.quantization.scale = \
tensor["quantization"].get("scale", None)
tensor_t.quantization.zeroPoint = \
tensor["quantization"].get("zero_point", None)

subgraph_t.tensors.append(tensor_t)

Expand Down

0 comments on commit 25cd19b

Please sign in to comment.