This document will walk you through steps to enable new operators in tflite2onnx
.
It's highly recommended to read the blog which introduces the
background of tflite2onnx
and general contribution guide.
I am sure it will help when enabling new operators.
Also, make sure that the operator is has not been enabled,
i.e. not included in Operator Support Status.
This is pretty simple:
pip install -r requirements.txt
First of all, we need a TensorFlow Lite model (model.tflite
) to get started.
Currently, to generate a TFLite model, we build a TensorFlow or Keras model,
and convert it into TFLite model.
Below is an example of generating a TFLite model which contains Concat
operator
only.
import tensorflow as tf
# operator inputs
a = tf.keras.Input(dtype='float32', name='a', shape=(1, 2, 3, 1))
b = tf.keras.Input(dtype='float32', name='b', shape=(1, 2, 3, 2))
c = tf.keras.Input(dtype='float32', name='c', shape=(1, 2, 3, 3))
# operator
concat = tf.keras.layers.Concatenate(axis=-1, name='output')([a, b, c])
# build Keras model
model = tf.keras.Model(inputs=[a, b, c], outputs=[concat])
# convert to TFLite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# save it
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
Usually, tensor sizes are kept small to generate small model, as some of
them will be hosted in tflite2onnx
repository. In addition, we gave
each dimension different extent, such that layout errors can be easily
identified.
Once generated, it's recommended to use visualization tool such as Netron to verify if the tensors and operator are what you expect.
tflite2onnx
requires test for every operator to ensure that functionality
is not broken across development. The test for the operator is also very
helpful when enabling new operators.
Copy the newly generated model.tflite
to tflite2onnx
repository, put it
in ${tflite2onnx}/assets/tests
. Naming convesion is {operator}.{data type}.tflite
,
for example concat.float32.tflite
. The pattern {operator}.{data type}
will be used in our test. Also, {operator}
doesn't necessarily to be
operator type only, check files in ${tflite2onnx}/assets/tests
for details.
Add the pattern {operator}.{data type}
into operator test in
${tflite2onnx}/tests/test_ops.py
, depending the data layout attribution of
the operator (if you don't know which sub test shall the op goes, check the blog).
It would help to comment out all other operators and tests when trying around.
Invoke the test python tests/test_ops.py
. You should be able to see errors
like below, which indicates that one operator has not been supported.
wzh@Mac[✓]tflite2onnx.git (master*) $ python tests/test_ops.py
2020-11-09 20:51:00,439 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite
2020-11-09 20:51:00,439 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx
2020-11-09 20:51:00,439 D [tflite2onnx][model.py:21] Parsing the Model...
2020-11-09 20:51:00,439 D [tflite2onnx][graph.py:58] Parsing the Graph...
2020-11-09 20:51:00,439 D [tflite2onnx][graph.py:61] Parsing operator: 0
Traceback (most recent call last):
File "tests/test_ops.py", line 85, in <module>
test_ops_post_propagation()
File "tests/test_ops.py", line 64, in test_ops_post_propagation
end2end_test(op, 'NHWC')
File "tests/test_ops.py", line 16, in end2end_test
t2o.convert(tflm_path, onnx_name)
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/convert.py", line 44, in convert
model.convert(explicit_layouts)
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 39, in convert
self.parse()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 31, in parse
g.parse()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/graph.py", line 62, in parse
op = self.OPCFactory.create(i)
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/common.py", line 151, in create
raise NotImplementedError("Unsupported TFLite OP: {}".format(opcode))
NotImplementedError: Unsupported TFLite OP: 2
The 2
of NotImplementedError: Unsupported TFLite OP: 2
indicates which operator
has not been enabled yet. It is CONCATENATION
in
tflite.BuiltinOperator
.
With this, we can really start to write some code.
To start with, we add a operator converter class to handle
converting of this operator.
In this example, we created tflite.op.Concat
initially as:
class Concat(Operator):
TypeMapping = {
tflite.BuiltinOperator.CONCATENATION: 'Concat',
}
def __init__(self, TFactory, index):
super().__init__(TFactory, index)
self.attrs['axis'] = -1 # operator attribute
self.setInited()
@property
def type(self):
return 'Concat'
def parse(self):
logger.debug("Parsing %s...", self.type)
op = self.tflite
opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
assert(opcode is tflite.BuiltinOperator.CONCATENATION)
assert(op.InputsLength() >= 1)
assert(op.OutputsLength() == 1)
# TODO: parse tensors
self.setParsed()
def propagatableTensors(self):
return list()
def transform(self):
# TODO: handle layout transform
pass
This can be done by copying an existing similar operator, and make several modifications.
Operator.TypeMapping
maps TFLite operator type to ONNX operator type. Different TFLite operator may map to same ONNX operator type. An operator converter may be able to handle many TFLite operators.Operator.__init__()
collects the TFLite objects and initializes attributes of the operator. You can take ONNX Operator Schemas as reference. When it's done, the object switches to statusINITIALIZED
.Operator.type
is the operator type of ONNX. It's a string which you can find in operator examples in ONNX Operator Schemas - usually simply the operator type name, e.g.Concat
in our example. The type may be mapped from the TFLite operator type viaOperator.TypeMapping
or other information sometimes depending on the implementation.Operator.parse()
parses the tensors used by the operator, attributes of the operator. Let's left it to be done in next section. After finished, set object to statusPARSED
. Mostly, an object should not been parsed multiple times.Operator.propagatableTensors()
describes which tensors of this operator are layout propagatable. This is a bit tricky, please look into Data layout semantic and converting procedure.Operator.transform()
transforms operator attributes that are sensitive to layout. This is sort tricky which requires serious consideration. Leave it empty currently.
Now, let's integrate the operator converter class into framework.
This is simple (as we are trying to make it easy to extend :) ).
Import and register the operator converter class. In ${tflite2onnx}/op/__init__.py
, add code below.
from tflite2onnx.op.concat import Concat
# ...
OpFactory.register(Concat)
That's it! Simple!
Now let's try it. You may see errors like below - take it easy,
as we have not finish our jobs. But we can see that the Concat
class is parsing something (nothing so far). That means we have
enabled basic workflow for the operator.
2020-11-09 21:04:45,344 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite
2020-11-09 21:04:45,345 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx
2020-11-09 21:04:45,345 D [tflite2onnx][model.py:21] Parsing the Model...
2020-11-09 21:04:45,345 D [tflite2onnx][graph.py:58] Parsing the Graph...
2020-11-09 21:04:45,345 D [tflite2onnx][graph.py:61] Parsing operator: 0
2020-11-09 21:04:45,345 D [tflite2onnx][concat.py:27] Parsing [None](Concat)...
Traceback (most recent call last):
File "tests/test_ops.py", line 85, in <module>
test_ops_post_propagation()
File "tests/test_ops.py", line 64, in test_ops_post_propagation
end2end_test(op, 'NHWC')
File "tests/test_ops.py", line 16, in end2end_test
t2o.convert(tflm_path, onnx_name)
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/convert.py", line 44, in convert
model.convert(explicit_layouts)
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 39, in convert
self.parse()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 31, in parse
g.parse()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/graph.py", line 63, in parse
op.parse()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/concat.py", line 47, in parse
self.setParsed()
File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/common.py", line 104, in setParsed
self.name = self.outputs[0].name if self.name is None else self.name
IndexError: list index out of range
TFLite and ONNX operator semantic are sometimes different. Make sure to review TFLite API documentation for operator option, and ONNX documents for operator attributes. To be noted, some operator option or attribute of one, could be described by input tensor in another.
For this Concat
example, it accepts several inputs and generate one output in both
TFLite and ONNX.
Unfortuanately, TFLite doesn't provide rich documents about operators, we may check
Compatible operations of TensorFlow and TFLite
and sometimes even the source code.
For options or attributes, we can check the
OperatorOption of TFLite.
In our Concat
example, it has two:
Axis
indicates concatenating on which dimension. This attribute is sensitive to how we handle layout issue. For example, if TFLite concatenates on axis-1
and has aNHWC
data layout - which means it's concatenating onC
dimension. While ONNX uses layoutNCHW
, ONNX version needs to concatenates on axis1
for it's dimensionC
in ONNX. This is needed whenConcat
feeds to aConv
. The interesting part is, if the model contains noConv
, for example has only oneConcat
in our case, we'd better keep it unchanged. We will discuss this more in dedicated document.FusedActivationFunction
describes which activation function has been fused into this operator. This is commen in operators likeConv
andFullyConnected
.
Parsing input and output tensors is simple, and we have provided well wrapped helpers to make this easy.
In our case, Concat
may has multiple inputs and one output, so just
for i in range(op.InputsLength()):
self.parseInput(i)
self.parseOutput(0)
Now, if you invoke the test, you may see that it completes without erros. You may also catch the tensors and graph log (well, we have many debug log to make investigation easier):
wzh@Mac[✗]tflite2onnx.git (master*) $ python tests/test_ops.py
2020-11-09 21:11:24,589 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite
2020-11-09 21:11:24,589 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx
2020-11-09 21:11:24,589 D [tflite2onnx][model.py:21] Parsing the Model...
2020-11-09 21:11:24,589 D [tflite2onnx][graph.py:58] Parsing the Graph...
2020-11-09 21:11:24,589 D [tflite2onnx][graph.py:61] Parsing operator: 0
2020-11-09 21:11:24,590 D [tflite2onnx][concat.py:27] Parsing [None](Concat)...
2020-11-09 21:11:24,590 D [tflite2onnx][tensor.py:103] Parsing a...
2020-11-09 21:11:24,590 D [tflite2onnx][tensor.py:103] Parsing b...
2020-11-09 21:11:24,591 D [tflite2onnx][tensor.py:103] Parsing c...
2020-11-09 21:11:24,591 D [tflite2onnx][tensor.py:103] Parsing Identity...
2020-11-09 21:11:24,591 D [tflite2onnx][model.py:40] Converting...
2020-11-09 21:11:24,591 D [tflite2onnx][graph.py:91] Converting...
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:93] Handling data layout...
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:130] Propragating layout across graph...
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:141] Propagation: 4 tensors in total, 0 to walk, 4 at wild
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:164] Propagation: wild tensors 4, ignored tensors 0
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:104] Translating quantization semantic...
2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:112] Graph:
[OP] [Identity](Concat) attr{'axis': -1}: ['a', 'b', 'c'] -> ['Identity']
[Input] <a>(float32,[1, 1, 2, 3, 1]): {[]} -> {['[Identity](Concat)']}
[Input] <b>(float32,[1, 1, 2, 3, 2]): {[]} -> {['[Identity](Concat)']}
[Input] <c>(float32,[1, 1, 2, 3, 3]): {[]} -> {['[Identity](Concat)']}
[Output] <Identity>(float32,[1, 1, 2, 3, 6]): {['[Identity](Concat)']} -> {[]}
[Value Info] <a>(float32,[1, 1, 2, 3, 1]): {[]} -> {['[Identity](Concat)']}
[Value Info] <Identity>(float32,[1, 1, 2, 3, 6]): {['[Identity](Concat)']} -> {[]}
[Value Info] <b>(float32,[1, 1, 2, 3, 2]): {[]} -> {['[Identity](Concat)']}
[Value Info] <c>(float32,[1, 1, 2, 3, 3]): {[]} -> {['[Identity](Concat)']}
2020-11-09 21:11:24,592 D [tflite2onnx][common.py:111] Converting [Identity](Concat)...
2020-11-09 21:11:24,592 D [tflite2onnx][tensor.py:142] Converting <a>(float32,[1, 1, 2, 3, 1])...
2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting <b>(float32,[1, 1, 2, 3, 2])...
2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting <c>(float32,[1, 1, 2, 3, 3])...
2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting <Identity>(float32,[1, 1, 2, 3, 6])...
2020-11-09 21:11:24,593 D [tflite2onnx][graph.py:118] Making ONNX...
2020-11-09 21:11:24,594 D [tflite2onnx][model.py:56] saving model as concat.float32.onnx
2020-11-09 21:11:24,614 I [tflite2onnx][convert.py:46] Converted ONNX model: concat.float32.onnx
2020-11-09 21:11:24,622 D [shrub][onnx.py:55] running concat.float32.onnx
2020-11-09 21:11:24,709 D [shrub][onnx.py:22] parsing concat.float32.onnx
2020-11-09 21:11:25,384 D [tensorflow][tpu_cluster_resolver.py:34] Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
2020-11-09 21:11:26,565 D [shrub][tflite.py:98] running /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite
2020-11-09 21:11:26,568 D [shrub][tflite.py:104] Inputs: [{'name': 'a', 'index': 0, 'shape': array([1, 1, 2, 3, 1], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'b', 'index': 1, 'shape': array([1, 1, 2, 3, 2], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'c', 'index': 2, 'shape': array([1, 1, 2, 3, 3], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
2020-11-09 21:11:26,569 D [shrub][tflite.py:105] Outputs: [{'name': 'Identity', 'index': 3, 'shape': array([1, 1, 2, 3, 6], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 6], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtyp
But, that's not the end necessarily, keep going!
TFLite model stores operator option with dedicated class per operator, which needs to be handled seperately.
Taking Concat
example, the options are aviable to obtain after a option
object has init from memory. See below.
op_opt = op.BuiltinOptions()
option = tflite.ConcatenationOptions()
option.Init(op_opt.Bytes, op_opt.Pos)
self.attrs['axis'] = option.Axis()
Each operator option has a funtion to extract the information, please refer
to the TFLite parser API. And all ONNX attributes are collected
in Operator.attrs
.
A TFLite operator option doesn't necessarily have a peer ONNX operator attribute, vice verse. A TFLite operator option may become ONNX operator input, or implicit ONNX operator semantic. Please do take care consideration for these functionalities. If you are not sure, take existing operator converter as reference, or open issue to ask.
Among all the options, fused activation function is one special, for which
we need to add one more ONNX operator to the graph. But, don't worry, it can be
handled by simply calling handleFusedActivation(self, option, ot)
, if that
operator has a FusedActivationFunction()
(Concat
example)
method of its option class. If that is the case, please don't add output tensor
of the operator directly, but do something like below.
handleFusedActivation(self, option, ot)
Do remeber to initialize ONNX attributes in Operator.__init__()
.
And do NOT miss any when creating ONNX operator.
Given that the data layout of TFLite models and ONNX models are NHWC and NCHW respective, some additional efforts are needed when enabling operators. If you have not read the blog or the layout handling story, not it's the time.
Operator.propagatableTensors()
describes which tensors of this operator
are layout propagatable. For most case like Concat
, all tensors are
propagatable, so we can write it like this.
def propagatableTensors(self):
return self.inputs + self.outputs
But for operators like Conv
, none of it's tensors is propagatable.
Be carefull for this part as it may require significant effort to debug
if it's not correctly coded at the begining.
Operator.transform()
transforms operator attributes that are sensitive
to layout. For most case, this function can be left as empty. But, just like
Operator.propagatableTensors()
, we need to check what should be done.
For Concat
, which requires attribute transform, we get the layout
description (source layout and target layout) from output, and transform
attribute axis
accordingly.
def transform(self):
logger.debug("Transforming %s...", self.shorty)
layout = self.outputs[0].layout
if layout is not None:
axis = self.attrs['axis']
axis = axis if axis >= 0 else (axis + len(layout.perm))
self.attrs['axis'] = layout.perm.index(axis)
To be noted, the layout issue is case by case, this document only shows
Concat
as an example. You may find other operator converter classes
as a hint for the operator you are trying to enable.
If you have any question, just open issue to discuss.
Congratulation! You have basically finished the implementation of a new operator. If everthing looks good, open pull request. Let your work empower the community.
Thank you for your contribution!
Cheers!