Commit 31021d2b by Thierry Moreau, committed by Tianqi Chen

[Runtime] EdgeTPU runtime for Coral Boards (#4698)

parent c7a83199
......@@ -154,6 +154,11 @@ set(USE_TFLITE OFF)
# /path/to/tensorflow: tensorflow root path when use tflite library
set(USE_TENSORFLOW_PATH none)
# Possible values:
# - OFF: disable tflite support for edgetpu
# - /path/to/edgetpu: use specific path to edgetpu library
set(USE_EDGETPU OFF)
# Whether to use CuDNN
set(USE_CUDNN OFF)
......
......@@ -25,6 +25,15 @@ if(NOT USE_TFLITE STREQUAL "OFF")
list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
include_directories(${USE_TENSORFLOW_PATH})
# Additional EdgeTPU libs
if (NOT USE_EDGETPU STREQUAL "OFF")
message(STATUS "Build with contrib.edgetpu")
file(GLOB EDGETPU_CONTRIB_SRC src/runtime/contrib/edgetpu/*.cc)
list(APPEND RUNTIME_SRCS ${EDGETPU_CONTRIB_SRC})
include_directories(${USE_EDGETPU}/libedgetpu)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${USE_EDGETPU}/libedgetpu/direct/aarch64/libedgetpu.so.1)
endif()
if (USE_TFLITE STREQUAL "ON")
set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
endif()
......
......@@ -18,7 +18,7 @@
from .._ffi.function import get_global_func
from ..rpc import base as rpc_base
-def create(tflite_model_bytes, ctx):
+def create(tflite_model_bytes, ctx, runtime_target='cpu'):
"""Create a runtime executor module given a tflite model and context.
Parameters
----------
......@@ -27,16 +27,25 @@ def create(tflite_model_bytes, ctx):
ctx : TVMContext
The context to deploy the module. It can be local or remote when there
is only one TVMContext.
    runtime_target : str
Execution target of TFLite runtime: either `cpu` or `edge_tpu`.
Returns
-------
tflite_runtime : TFLiteModule
Runtime tflite module that can be used to execute the tflite model.
"""
    device_type = ctx.device_type
+    if runtime_target == 'edge_tpu':
+        runtime_func = "tvm.edgetpu_runtime.create"
+    else:
+        runtime_func = "tvm.tflite_runtime.create"
    if device_type >= rpc_base.RPC_SESS_MASK:
-        fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create")
-        return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
-    fcreate = get_global_func("tvm.tflite_runtime.create")
+        fcreate = ctx._rpc_sess.get_function(runtime_func)
+    else:
+        fcreate = get_global_func(runtime_func)
    return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
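For illustration, a minimal usage sketch of the new runtime_target argument. The model file name and input shape below are hypothetical; running with 'edge_tpu' assumes a runtime built with USE_EDGETPU and a model compiled for the EdgeTPU:

    # Hypothetical usage sketch of tflite_runtime.create with runtime_target.
    import numpy as np
    import tvm
    from tvm.contrib import tflite_runtime

    with open("mobilenet_quant_edgetpu.tflite", "rb") as f:  # hypothetical model file
        model_bytes = f.read()

    ctx = tvm.cpu(0)
    runtime = tflite_runtime.create(model_bytes, ctx, runtime_target='edge_tpu')
    runtime.set_input(0, tvm.nd.array(np.zeros((1, 224, 224, 3), "uint8"), ctx))
    runtime.invoke()
    out = runtime.get_output(0)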
......@@ -50,12 +59,12 @@ class TFLiteModule(object):
Parameters
----------
module : Module
-        The interal tvm module that holds the actual tflite functions.
+        The internal tvm module that holds the actual tflite functions.
Attributes
----------
module : Module
-        The interal tvm module that holds the actual tflite functions.
+        The internal tvm module that holds the actual tflite functions.
"""
def __init__(self, module):
......@@ -63,7 +72,6 @@ class TFLiteModule(object):
self._set_input = module["set_input"]
self._invoke = module["invoke"]
self._get_output = module["get_output"]
-        self._allocate_tensors = module["allocate_tensors"]
def set_input(self, index, value):
"""Set inputs to the module via kwargs
......@@ -91,12 +99,6 @@ class TFLiteModule(object):
"""
self._invoke()
-    def allocate_tensors(self):
-        """Allocate space for all tensors.
-        """
-        self._allocate_tensors()
def get_output(self, index):
"""Get index-th output to out
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file edgetpu_runtime.cc
*/
#include <tvm/runtime/registry.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <edgetpu.h>
#include "edgetpu_runtime.h"
namespace tvm {
namespace runtime {
void EdgeTPURuntime::Init(const std::string& tflite_model_bytes,
TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
// Load compiled model as a FlatBufferModel
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
// Build resolver
tflite::ops::builtin::BuiltinOpResolver resolver;
// Init EdgeTPUContext object
edgetpu_context_ = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
// Add custom edgetpu ops to resolver
resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
// Build interpreter
TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
// Bind EdgeTPU context with interpreter.
interpreter_->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context_.get());
interpreter_->SetNumThreads(1);
// Allocate tensors
status = interpreter_->AllocateTensors();
CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
ctx_ = ctx;
}
Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes,
TVMContext ctx) {
auto exec = make_object<EdgeTPURuntime>();
exec->Init(tflite_model_bytes, ctx);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = EdgeTPURuntimeCreate(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
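The TVM_REGISTER_GLOBAL call above is what the Python wrapper resolves when runtime_target='edge_tpu'. A minimal sketch of fetching the packed function directly over TVM's FFI, assuming model_bytes holds a tflite flatbuffer compiled for the EdgeTPU:

    # Hypothetical sketch: call the registered global directly.
    import tvm
    from tvm._ffi.function import get_global_func

    fcreate = get_global_func("tvm.edgetpu_runtime.create")
    module = fcreate(bytearray(model_bytes), tvm.cpu(0))
    module["invoke"]()  # PackedFuncs exposed through GetFunction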
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \brief EdgeTPU runtime that can run tflite models compiled
 * for the EdgeTPU, exposed through tvm PackedFuncs.
* \file edgetpu_runtime.h
*/
#ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
#include <string>
#include <memory>
#include "../tflite/tflite_runtime.h"
namespace tvm {
namespace runtime {
/*!
* \brief EdgeTPU runtime.
*
* This runtime can be accessed in various languages via
* the TVM runtime PackedFunc API.
*/
class EdgeTPURuntime : public TFLiteRuntime {
public:
/*!
* \return The type key of the executor.
*/
const char* type_key() const final {
return "EdgeTPURuntime";
}
/*!
 * \brief Initialize the EdgeTPU tflite runtime with a tflite model and context.
 * \param tflite_model_bytes The tflite model.
 * \param ctx The context on which the tflite model will be executed.
*/
void Init(const std::string& tflite_model_bytes,
TVMContext ctx);
private:
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
......@@ -21,7 +21,6 @@
* \file tflite_runtime.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/dtype.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
......@@ -33,37 +32,37 @@ namespace tvm {
namespace runtime {
#define TVM_DTYPE_DISPATCH(type, DType, ...)        \
  if (type == DataType::Float(64)) {                \
    typedef double DType;                           \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Float(32)) {         \
    typedef float DType;                            \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Float(16)) {         \
    typedef uint16_t DType;                         \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Int(64)) {           \
    typedef int64_t DType;                          \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Int(32)) {           \
    typedef int32_t DType;                          \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Int(16)) {           \
    typedef int16_t DType;                          \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::Int(8)) {            \
    typedef int8_t DType;                           \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::UInt(64)) {          \
    typedef uint64_t DType;                         \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::UInt(32)) {          \
    typedef uint32_t DType;                         \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::UInt(16)) {          \
    typedef uint16_t DType;                         \
    {__VA_ARGS__}                                   \
  } else if (type == DataType::UInt(8)) {           \
    typedef uint8_t DType;                          \
    {__VA_ARGS__}                                   \
  } else {                                          \
......@@ -79,9 +78,9 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
case kTfLiteInt64:
return DataType::Int(64);
case kTfLiteInt16:
-      returnDataType::Int(16);
+      return DataType::Int(16);
case kTfLiteInt8:
-      returnDataType::Int(8);
+      return DataType::Int(8);
case kTfLiteUInt8:
return DataType::UInt(8);
case kTfLiteFloat16:
......@@ -92,7 +91,6 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
}
}
void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
......@@ -100,12 +98,14 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
-  ctx_ = ctx;
-}
+  // Build interpreter
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
+  CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
+  // Allocate tensors
+  status = interpreter_->AllocateTensors();
+  CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
-void TFLiteRuntime::AllocateTensors() {
-  interpreter_->AllocateTensors();
+  ctx_ = ctx;
}
void TFLiteRuntime::Invoke() {
......@@ -129,7 +129,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
}
NDArray TFLiteRuntime::GetOutput(int index) const {
-  TfLiteTensor* output = interpreter_->output_tensor(index);
+  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]);
DataType dtype = TfLiteDType2TVMDType(output->type);
TfLiteIntArray* dims = output->dims;
int64_t size = 1;
......@@ -167,10 +167,6 @@ PackedFunc TFLiteRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Invoke();
});
} else if (name == "allocate_tensors") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->AllocateTensors();
});
} else {
return PackedFunc();
}
......
......@@ -36,17 +36,18 @@
namespace tvm {
namespace runtime {
+#define CHECK_TFLITE_STATUS(ret) CHECK_EQ(ret, kTfLiteOk)
/*!
* \brief Tflite runtime.
*
- * This runtime can be acccesibly in various language via
+ * This runtime can be accessed in various languages via
* TVM runtime PackedFunc API.
*/
class TFLiteRuntime : public ModuleNode {
public:
/*!
-   * \brief Get member function to front-end
+   * \brief Get member function to front-end.
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
......@@ -57,15 +58,11 @@ class TFLiteRuntime : public ModuleNode {
/*!
* \return The type key of the executor.
*/
-  const char* type_key() const final {
+  const char* type_key() const {
return "TFLiteRuntime";
}
-  /*!
-   * \brief Update allocations for all tenssors. This is relatively expensive.
-   */
-  void AllocateTensors();
/*!
* \brief Invoke the internal tflite interpreter and run the whole model in
* dependency order.
*/
......@@ -100,8 +97,9 @@ class TFLiteRuntime : public ModuleNode {
*/
NDArray GetOutput(int index) const;
private:
+  // TFLite interpreter
std::unique_ptr<tflite::Interpreter> interpreter_;
+  // TVM context
TVMContext ctx_;
};
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tvm
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tflite_runtime.interpreter as tflite
def skipped_test_tflite_runtime():
    def get_tflite_model_path(target_edgetpu):
        # Return a path to the model
        edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
        # Obtain mobilenet model from the edgetpu repo path
        if target_edgetpu:
            model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite")
        else:
            model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant.tflite")
        return model_path

    def init_interpreter(model_path, target_edgetpu):
        # Initialize interpreter
        if target_edgetpu:
            edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
            libedgetpu = os.path.join(edgetpu_path, "libedgetpu/direct/aarch64/libedgetpu.so.1")
            interpreter = tflite.Interpreter(
                model_path=model_path,
                experimental_delegates=[tflite.load_delegate(libedgetpu)])
        else:
            interpreter = tflite.Interpreter(model_path=model_path)
        return interpreter

    def check_remote(target_edgetpu=False):
        tflite_model_path = get_tflite_model_path(target_edgetpu)

        # inference via tflite interpreter python apis
        interpreter = init_interpreter(tflite_model_path, target_edgetpu)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        input_shape = input_details[0]['shape']
        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
        interpreter.set_tensor(input_details[0]['index'], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])

        # inference via remote tvm tflite runtime
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        # pick the runtime module that matches the model
        runtime_target = 'edge_tpu' if target_edgetpu else 'cpu'
        with open(tflite_model_path, 'rb') as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), ctx, runtime_target=runtime_target)
            runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)

    # Target CPU on coral board
    check_remote()
    # Target EdgeTPU on coral board
    check_remote(target_edgetpu=True)

if __name__ == "__main__":
    # skipped_test_tflite_runtime()
    pass
......@@ -36,16 +36,14 @@ def skipped_test_tflite_runtime():
        return tflite_model

-    def check_verify():
+    def check_local():
        tflite_fname = "model.tflite"
        tflite_model = create_tflite_model()
        temp = util.tempdir()
        tflite_model_path = temp.relpath(tflite_fname)
-        print(tflite_model_path)
        open(tflite_model_path, 'wb').write(tflite_model)

        # inference via tflite interpreter python apis
-        print('interpreter')
        interpreter = tflite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
......@@ -57,11 +55,9 @@ def skipped_test_tflite_runtime():
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])

-        print('tvm tflite runtime')
        # inference via tvm tflite runtime
        with open(tflite_model_path, 'rb') as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
-            runtime.allocate_tensors()
            runtime.set_input(0, tvm.nd.array(tflite_input))
            runtime.invoke()
            out = runtime.get_output(0)
......@@ -95,14 +91,12 @@ def skipped_test_tflite_runtime():
        with open(tflite_model_path, 'rb') as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
-            runtime.allocate_tensors()
            runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)

-    check_verify()
+    check_local()
    check_remote()
if __name__ == "__main__":
......