Unverified Commit 8599f7c6 by Thierry Moreau Committed by GitHub

[TFLite] Model importer to be compatible with tflite 2.1.0 (#5497)

parent 360027d2
......@@ -18,7 +18,6 @@
import os
from tvm import relay
from tvm.contrib.download import download_testdata
import tflite.Model
################################################
......@@ -49,7 +48,12 @@ model_file = os.path.join(model_dir, "mobilenet_v2_1.4_224.tflite")
# get TFLite model from buffer
tflite_model_buf = open(model_file, "rb").read()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
try:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
##############################
......
......@@ -2524,7 +2524,7 @@ def from_tflite(model, shape_dict, dtype_dict):
Parameters
----------
model:
tflite.Model.Model
tflite.Model or tflite.Model.Model (depending on tflite version)
shape_dict : dict of str to int list/tuple
Input shapes of the model.
......@@ -2541,12 +2541,18 @@ def from_tflite(model, shape_dict, dtype_dict):
The parameter dict to be used by relay
"""
try:
import tflite.Model
import tflite.SubGraph
import tflite.BuiltinOperator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(model, tflite.Model.Model)
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite
assert isinstance(model, tflite.Model)
except TypeError:
import tflite.Model
assert isinstance(model, tflite.Model.Model)
# keep the same as tflite
assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
......
......@@ -76,14 +76,16 @@ def get_real_image(im_height, im_width):
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None):
""" Generic function to compile on relay and execute on tvm """
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except ImportError:
raise ImportError("The tflite package must be installed")
# get TFLite model from buffer
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
......
......@@ -21,25 +21,12 @@ Compile TFLite Models
This article is an introductory tutorial to deploy TFLite models with Relay.
To get started, Flatbuffers and TFLite package needs to be installed as prerequisites.
A quick solution is to install Flatbuffers via pip
To get started, TFLite package needs to be installed as prerequisite.
.. code-block:: bash
pip install flatbuffers --user
To install TFlite packages, you could use our prebuilt wheel:
.. code-block:: bash
# For python3:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl
pip3 install -U tflite-1.13.1-py3-none-any.whl --user
# For python2:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl
pip install -U tflite-1.13.1-py2-none-any.whl --user
# install tflite
pip install tflite=2.1.0 --user
or you could generate TFLite package yourself. The steps are the following:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment