Commit 24713bde by ziheng, committed by Tianqi Chen

[CONTRIB] TFLite Runtime (#4439)

parent f2143644
@@ -63,6 +63,8 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when using TFLite" none)
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
@@ -257,6 +259,7 @@ include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
if(NOT MSVC)
include(CheckCXXCompilerFlag)
......
@@ -145,6 +145,15 @@ set(USE_RANDOM OFF)
# Whether use NNPack
set(USE_NNPACK OFF)
# Possible values:
# - ON: enable tflite with cmake's find search
# - OFF: disable tflite
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
set(USE_TFLITE OFF)
# /path/to/tensorflow: TensorFlow root path when using the tflite library
set(USE_TENSORFLOW_PATH none)
# Whether use CuDNN
set(USE_CUDNN OFF)
......
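For illustration, a config.cmake set up to use this runtime might contain the following; a minimal sketch, assuming a TensorFlow checkout at /path/to/tensorflow (the path is a placeholder, not part of this commit):

# Enable the TFLite runtime. USE_TFLITE may alternatively point at a
# prebuilt libtensorflow-lite.a instead of ON.
set(USE_TFLITE ON)
# Root of the TensorFlow source tree, used for headers and, with ON above,
# for locating the built TFLite static library.
set(USE_TENSORFLOW_PATH /path/to/tensorflow)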
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(NOT USE_TFLITE STREQUAL "OFF")
message(STATUS "Build with contrib.tflite")
if (USE_TENSORFLOW_PATH STREQUAL "none")
set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow)
endif()
file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc)
list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
include_directories(${USE_TENSORFLOW_PATH})
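  # When USE_TFLITE is ON rather than an explicit path, fall back to the
  # library directory layout produced by the TFLite Makefile build
  # (tensorflow/lite/tools/make).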
if (USE_TFLITE STREQUAL "ON")
set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
endif()
find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB})
list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers)
endif()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TFLite runtime that load and run tflite models."""
from .._ffi.function import get_global_func
from ..rpc import base as rpc_base
def create(tflite_model_bytes, ctx):
"""Create a runtime executor module given a tflite model and context.
Parameters
----------
tflite_model_byte : bytes
The tflite model to be deployed in bytes string format.
ctx : TVMContext
The context to deploy the module. It can be local or remote when there
is only one TVMContext.
Returns
-------
tflite_runtime : TFLiteModule
Runtime tflite module that can be used to execute the tflite model.
"""
device_type = ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK:
fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create")
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
fcreate = get_global_func("tvm.tflite_runtime.create")
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
class TFLiteModule(object):
"""Wrapper runtime module.
This is a thin wrapper of the underlying TVM module.
you can also directly call set_input, run, and get_output
of underlying module functions
Parameters
----------
module : Module
The interal tvm module that holds the actual tflite functions.
Attributes
----------
module : Module
The interal tvm module that holds the actual tflite functions.
"""
def __init__(self, module):
self.module = module
self._set_input = module["set_input"]
self._invoke = module["invoke"]
self._get_output = module["get_output"]
self._allocate_tensors = module["allocate_tensors"]
def set_input(self, index, value):
"""Set inputs to the module via kwargs
Parameters
----------
key : int or str
The input key
value : the input value.
The input key
params : dict of str to NDArray
Additonal arguments
"""
self._set_input(index, value)
def invoke(self):
"""Invoke forward execution of the model
Parameters
----------
input_dict: dict of str to NDArray
List of input values to be feed to
"""
self._invoke()
def allocate_tensors(self):
"""Allocate space for all tensors.
"""
self._allocate_tensors()
def get_output(self, index):
"""Get index-th output to out
Parameters
----------
index : int
The output index
"""
return self._get_output(index)
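A minimal usage sketch of this Python API (the model path and input shape are placeholders; assumes TVM was built with USE_TFLITE enabled):

import numpy as np
import tvm
from tvm.contrib import tflite_runtime

# Read a serialized TFLite model from disk (the path is illustrative).
with open("model.tflite", "rb") as f:
    model_bytes = f.read()

# Create the runtime on the local CPU, then run one inference.
runtime = tflite_runtime.create(model_bytes, tvm.cpu(0))
runtime.allocate_tensors()
runtime.set_input(0, tvm.nd.array(np.ones((2,), dtype="float32")))
runtime.invoke()
print(runtime.get_output(0).asnumpy())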
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tflite_runtime.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/dtype.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include "tflite_runtime.h"
namespace tvm {
namespace runtime {
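// Dispatch on a TVM DataType and bind the matching C++ scalar type to DType.
// Note that Float(16) is bound to uint16_t: the copy loops below move raw
// bits, and no native half-precision type is assumed here.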
#define TVM_DTYPE_DISPATCH(type, DType, ...) \
if (type == Float(64)) { \
typedef double DType; \
{__VA_ARGS__} \
} else if (type == Float(32)) { \
typedef float DType; \
{__VA_ARGS__} \
} else if (type == Float(16)) { \
typedef uint16_t DType; \
{__VA_ARGS__} \
} else if (type == Int(64)) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if (type == Int(32)) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if (type == Int(16)) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if (type == Int(8)) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(64)) { \
typedef uint64_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(32)) { \
typedef uint32_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(16)) { \
typedef uint16_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(8)) { \
typedef uint8_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "unknown data type " << type; \
}
DataType TfLiteDType2TVMDType(TfLiteType dtype) {
switch (dtype) {
case kTfLiteFloat32:
return Float(32);
case kTfLiteInt32:
return Int(32);
case kTfLiteInt64:
return Int(64);
case kTfLiteInt16:
return Int(16);
case kTfLiteInt8:
return Int(8);
case kTfLiteUInt8:
return UInt(8);
case kTfLiteFloat16:
return Float(16);
default:
LOG(FATAL) << "tflite data type not support yet: " << dtype;
return Float(32);
}
}
void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
                         TVMContext ctx) {
  // Keep our own copy of the serialized model: the interpreter references
  // the flatbuffer in place rather than copying it, so the buffer must
  // stay alive for the lifetime of interpreter_.
  flatbuffers_buffer_ = tflite_model_bytes;
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(flatbuffers_buffer_.c_str(),
                                               flatbuffers_buffer_.size());
  tflite::ops::builtin::BuiltinOpResolver resolver;
  // Build the interpreter with the builtin op resolver.
  tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
  ctx_ = ctx;
}
void TFLiteRuntime::AllocateTensors() {
interpreter_->AllocateTensors();
}
void TFLiteRuntime::Invoke() {
interpreter_->Invoke();
}
void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
DataType dtype(data_in->dtype);
TVM_DTYPE_DISPATCH(dtype, DType, {
DType* dest = interpreter_->typed_input_tensor<DType>(index);
DType* src = static_cast<DType*>(data_in->data);
CHECK(data_in->strides == NULL);
int64_t size = 1;
for (int64_t i = 0; i < data_in->ndim; ++i) {
size *= data_in->shape[i];
}
for (int64_t i = 0; i < size; ++i) {
dest[i] = src[i];
}
});
}
NDArray TFLiteRuntime::GetOutput(int index) const {
TfLiteTensor* output = interpreter_->output_tensor(index);
DataType dtype = TfLiteDType2TVMDType(output->type);
TfLiteIntArray* dims = output->dims;
int64_t size = 1;
std::vector<int64_t> shape;
for (int i = 0; i < dims->size; ++i) {
shape.push_back(dims->data[i]);
size *= dims->data[i];
}
NDArray ret = NDArray::Empty(shape, dtype, ctx_);
TVM_DTYPE_DISPATCH(dtype, DType, {
DType* dest = static_cast<DType*>(ret->data);
DType* src = interpreter_->typed_output_tensor<DType>(index);
for (int64_t i = 0; i < size; ++i) {
dest[i] = src[i];
}
});
return ret;
}
PackedFunc TFLiteRuntime::GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = args[0];
CHECK_GE(in_idx, 0);
this->SetInput(in_idx, args[1]);
});
} else if (name == "get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetOutput(args[0]);
});
} else if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Invoke();
});
} else if (name == "allocate_tensors") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->AllocateTensors();
});
} else {
return PackedFunc();
}
}
Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes,
TVMContext ctx) {
auto exec = make_object<TFLiteRuntime>();
exec->Init(tflite_model_bytes, ctx);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = TFLiteRuntimeCreate(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
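As a usage sketch from native code (not part of this commit): once this file is linked into the TVM runtime, the registered global function can be looked up through the registry roughly as follows. The helper name below is hypothetical.

#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <string>

// Sketch: create a TFLite runtime module via the registered global function.
tvm::runtime::Module CreateTFLiteModule(const std::string& model_bytes,
                                        TVMContext ctx) {
  const tvm::runtime::PackedFunc* fcreate =
      tvm::runtime::Registry::Get("tvm.tflite_runtime.create");
  CHECK(fcreate != nullptr) << "tvm.tflite_runtime.create is not registered";
  return (*fcreate)(model_bytes, ctx);
}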
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \brief TFLite runtime that loads and runs a TFLite model
 * via the TVM runtime Module interface.
* \file tflite_runtime.h
*/
#ifndef TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
#include <dlpack/dlpack.h>
#include <tensorflow/lite/interpreter.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

#include <vector>
#include <string>
#include <memory>
namespace tvm {
namespace runtime {
/*!
* \brief Tflite runtime.
*
 * This runtime can be accessed in various languages via the
 * TVM runtime PackedFunc API.
*/
class TFLiteRuntime : public ModuleNode {
public:
/*!
* \brief Get member function to front-end
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
*/
const char* type_key() const final {
return "TFLiteRuntime";
}
/*!
 * \brief Update allocations for all tensors. This is relatively expensive.
*/
void AllocateTensors();
/*!
* \brief Invoke the internal tflite interpreter and run the whole model in
* dependency order.
*/
void Invoke();
/*!
* \brief Initialize the tflite runtime with tflite model and context.
* \param tflite_model_bytes The tflite model.
* \param ctx The context where the tflite model will be executed on.
*/
void Init(const std::string& tflite_model_bytes,
TVMContext ctx);
/*!
 * \brief Set the index-th input to the model.
* \param index The input index.
* \param data_in The input data.
*/
void SetInput(int index, DLTensor* data_in);
/*!
* \brief Return NDArray for given input index.
* \param index The input index.
*
* \return NDArray corresponding to given input node index.
*/
NDArray GetInput(int index) const;
/*!
* \brief Return NDArray for given output index.
* \param index The output index.
*
* \return NDArray corresponding to given output node index.
*/
NDArray GetOutput(int index) const;
 private:
  // TFLite interpreter that executes the model.
  std::unique_ptr<tflite::Interpreter> interpreter_;
  // Copy of the serialized flatbuffer model. The interpreter references
  // this buffer in place, so it must outlive interpreter_.
  std::string flatbuffers_buffer_;
  TVMContext ctx_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tensorflow as tf
# import tflite_runtime.interpreter as tflite
def skipped_test_tflite_runtime():
def create_tflite_model():
root = tf.Module()
root.const = tf.constant([1., 2.], tf.float32)
root.f = tf.function(lambda x: root.const * x)
input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
concrete_func = root.f.get_concrete_function(input_signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
return tflite_model
def check_verify():
tflite_fname = "model.tflite"
tflite_model = create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
print(tflite_model_path)
        with open(tflite_model_path, 'wb') as model_fout:
            model_fout.write(tflite_model)
# inference via tflite interpreter python apis
print('interpreter')
interpreter = tflite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])
print('tvm tflite runtime')
# inference via tvm tflite runtime
with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
runtime.allocate_tensors()
runtime.set_input(0, tvm.nd.array(tflite_input))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)
def check_remote():
tflite_fname = "model.tflite"
tflite_model = create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
        with open(tflite_model_path, 'wb') as model_fout:
            model_fout.write(tflite_model)
# inference via tflite interpreter python apis
interpreter = tflite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])
# inference via remote tvm tflite runtime
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
ctx = remote.cpu(0)
        remote.upload(tflite_model_path)
with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.allocate_tensors()
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)
check_verify()
check_remote()
if __name__ == "__main__":
# skipped_test_tflite_runtime()
pass