Unverified Commit 8599f7c6 by Thierry Moreau Committed by GitHub

[TFLite] Model importer to be compatible with tflite 2.1.0 (#5497)

parent 360027d2
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
import os import os
from tvm import relay from tvm import relay
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
import tflite.Model
################################################ ################################################
...@@ -49,7 +48,12 @@ model_file = os.path.join(model_dir, "mobilenet_v2_1.4_224.tflite") ...@@ -49,7 +48,12 @@ model_file = os.path.join(model_dir, "mobilenet_v2_1.4_224.tflite")
# get TFLite model from buffer # get TFLite model from buffer
tflite_model_buf = open(model_file, "rb").read() tflite_model_buf = open(model_file, "rb").read()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) try:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
############################## ##############################
......
...@@ -2524,7 +2524,7 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -2524,7 +2524,7 @@ def from_tflite(model, shape_dict, dtype_dict):
Parameters Parameters
---------- ----------
model: model:
tflite.Model.Model tflite.Model or tflite.Model.Model (depending on tflite version)
shape_dict : dict of str to int list/tuple shape_dict : dict of str to int list/tuple
Input shapes of the model. Input shapes of the model.
...@@ -2541,12 +2541,18 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -2541,12 +2541,18 @@ def from_tflite(model, shape_dict, dtype_dict):
The parameter dict to be used by relay The parameter dict to be used by relay
""" """
try: try:
import tflite.Model
import tflite.SubGraph import tflite.SubGraph
import tflite.BuiltinOperator import tflite.BuiltinOperator
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(model, tflite.Model.Model)
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite
assert isinstance(model, tflite.Model)
except TypeError:
import tflite.Model
assert isinstance(model, tflite.Model.Model)
# keep the same as tflite # keep the same as tflite
assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
......
...@@ -76,14 +76,16 @@ def get_real_image(im_height, im_width): ...@@ -76,14 +76,16 @@ def get_real_image(im_height, im_width):
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None): out_names=None):
""" Generic function to compile on relay and execute on tvm """ """ Generic function to compile on relay and execute on tvm """
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try: try:
import tflite.Model import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
# get TFLite model from buffer
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
input_data = convert_to_list(input_data) input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node) input_node = convert_to_list(input_node)
......
...@@ -21,25 +21,12 @@ Compile TFLite Models ...@@ -21,25 +21,12 @@ Compile TFLite Models
This article is an introductory tutorial to deploy TFLite models with Relay. This article is an introductory tutorial to deploy TFLite models with Relay.
To get started, Flatbuffers and TFLite package needs to be installed as prerequisites. To get started, TFLite package needs to be installed as prerequisite.
A quick solution is to install Flatbuffers via pip
.. code-block:: bash .. code-block:: bash
pip install flatbuffers --user # install tflite
pip install tflite=2.1.0 --user
To install TFlite packages, you could use our prebuilt wheel:
.. code-block:: bash
# For python3:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl
pip3 install -U tflite-1.13.1-py3-none-any.whl --user
# For python2:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl
pip install -U tflite-1.13.1-py2-none-any.whl --user
or you could generate TFLite package yourself. The steps are the following: or you could generate TFLite package yourself. The steps are the following:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment