Commit 31021d2b by Thierry Moreau Committed by Tianqi Chen

[Runtime] EdgeTPU runtime for Coral Boards (#4698)

parent c7a83199
...@@ -154,6 +154,11 @@ set(USE_TFLITE OFF) ...@@ -154,6 +154,11 @@ set(USE_TFLITE OFF)
# /path/to/tensorflow: tensorflow root path when use tflite library # /path/to/tensorflow: tensorflow root path when use tflite library
set(USE_TENSORFLOW_PATH none) set(USE_TENSORFLOW_PATH none)
# Possible values:
# - OFF: disable tflite support for edgetpu
# - /path/to/edgetpu: use specific path to edgetpu library
set(USE_EDGETPU OFF)
# Whether use CuDNN # Whether use CuDNN
set(USE_CUDNN OFF) set(USE_CUDNN OFF)
......
...@@ -25,6 +25,15 @@ if(NOT USE_TFLITE STREQUAL "OFF") ...@@ -25,6 +25,15 @@ if(NOT USE_TFLITE STREQUAL "OFF")
list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
include_directories(${USE_TENSORFLOW_PATH}) include_directories(${USE_TENSORFLOW_PATH})
# Additional EdgeTPU libs
if (NOT USE_EDGETPU STREQUAL "OFF")
message(STATUS "Build with contrib.edgetpu")
file(GLOB EDGETPU_CONTRIB_SRC src/runtime/contrib/edgetpu/*.cc)
list(APPEND RUNTIME_SRCS ${EDGETPU_CONTRIB_SRC})
include_directories(${USE_EDGETPU}/libedgetpu)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${USE_EDGETPU}/libedgetpu/direct/aarch64/libedgetpu.so.1)
endif()
if (USE_TFLITE STREQUAL "ON") if (USE_TFLITE STREQUAL "ON")
set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib) set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
endif() endif()
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from .._ffi.function import get_global_func from .._ffi.function import get_global_func
from ..rpc import base as rpc_base from ..rpc import base as rpc_base
def create(tflite_model_bytes, ctx): def create(tflite_model_bytes, ctx, runtime_target='cpu'):
"""Create a runtime executor module given a tflite model and context. """Create a runtime executor module given a tflite model and context.
Parameters Parameters
---------- ----------
...@@ -27,16 +27,25 @@ def create(tflite_model_bytes, ctx): ...@@ -27,16 +27,25 @@ def create(tflite_model_bytes, ctx):
ctx : TVMContext ctx : TVMContext
The context to deploy the module. It can be local or remote when there The context to deploy the module. It can be local or remote when there
is only one TVMContext. is only one TVMContext.
runtime_target: str
Execution target of TFLite runtime: either `cpu` or `edge_tpu`.
Returns Returns
------- -------
tflite_runtime : TFLiteModule tflite_runtime : TFLiteModule
Runtime tflite module that can be used to execute the tflite model. Runtime tflite module that can be used to execute the tflite model.
""" """
device_type = ctx.device_type device_type = ctx.device_type
if runtime_target == 'edge_tpu':
runtime_func = "tvm.edgetpu_runtime.create"
else:
runtime_func = "tvm.tflite_runtime.create"
if device_type >= rpc_base.RPC_SESS_MASK: if device_type >= rpc_base.RPC_SESS_MASK:
fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create") fcreate = ctx._rpc_sess.get_function(runtime_func)
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) else:
fcreate = get_global_func("tvm.tflite_runtime.create") fcreate = get_global_func(runtime_func)
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
...@@ -50,12 +59,12 @@ class TFLiteModule(object): ...@@ -50,12 +59,12 @@ class TFLiteModule(object):
Parameters Parameters
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual tflite functions. The internal tvm module that holds the actual tflite functions.
Attributes Attributes
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual tflite functions. The internal tvm module that holds the actual tflite functions.
""" """
def __init__(self, module): def __init__(self, module):
...@@ -63,7 +72,6 @@ class TFLiteModule(object): ...@@ -63,7 +72,6 @@ class TFLiteModule(object):
self._set_input = module["set_input"] self._set_input = module["set_input"]
self._invoke = module["invoke"] self._invoke = module["invoke"]
self._get_output = module["get_output"] self._get_output = module["get_output"]
self._allocate_tensors = module["allocate_tensors"]
def set_input(self, index, value): def set_input(self, index, value):
"""Set inputs to the module via kwargs """Set inputs to the module via kwargs
...@@ -91,12 +99,6 @@ class TFLiteModule(object): ...@@ -91,12 +99,6 @@ class TFLiteModule(object):
""" """
self._invoke() self._invoke()
def allocate_tensors(self):
"""Allocate space for all tensors.
"""
self._allocate_tensors()
def get_output(self, index): def get_output(self, index):
"""Get index-th output to out """Get index-th output to out
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file edgetpu_runtime.cc
*/
#include <tvm/runtime/registry.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <edgetpu.h>
#include "edgetpu_runtime.h"
namespace tvm {
namespace runtime {
void EdgeTPURuntime::Init(const std::string& tflite_model_bytes,
TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
// Load compiled model as a FlatBufferModel
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
// Build resolver
tflite::ops::builtin::BuiltinOpResolver resolver;
// Init EdgeTPUContext object
edgetpu_context_ = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
// Add custom edgetpu ops to resolver
resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
// Build interpreter
TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
// Bind EdgeTPU context with interpreter.
interpreter_->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context_.get());
interpreter_->SetNumThreads(1);
// Allocate tensors
status = interpreter_->AllocateTensors();
CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
ctx_ = ctx;
}
Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes,
TVMContext ctx) {
auto exec = make_object<EdgeTPURuntime>();
exec->Init(tflite_model_bytes, ctx);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = EdgeTPURuntimeCreate(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief EdgeTPU runtime that can run tflite model compiled
* for EdgeTPU containing only tvm PackedFunc.
* \file edgetpu_runtime.h
*/
#ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
#include <string>
#include <memory>
#include "../tflite/tflite_runtime.h"
namespace tvm {
namespace runtime {
/*!
* \brief EdgeTPU runtime.
*
* This runtime can be accessed in various languages via
* the TVM runtime PackedFunc API.
*/
class EdgeTPURuntime : public TFLiteRuntime {
public:
/*!
* \return The type key of the executor.
*/
const char* type_key() const final {
return "EdgeTPURuntime";
}
/*!
* \brief Initialize the edge TPU tflite runtime with tflite model and context.
* \param tflite_model_bytes The tflite model.
* \param ctx The context where the tflite model will be executed on.
*/
void Init(const std::string& tflite_model_bytes,
TVMContext ctx);
private:
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* \file tflite_runtime.cc * \file tflite_runtime.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/dtype.h>
#include <tensorflow/lite/interpreter.h> #include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h> #include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h> #include <tensorflow/lite/model.h>
...@@ -33,37 +32,37 @@ namespace tvm { ...@@ -33,37 +32,37 @@ namespace tvm {
namespace runtime { namespace runtime {
#define TVM_DTYPE_DISPATCH(type, DType, ...) \ #define TVM_DTYPE_DISPATCH(type, DType, ...) \
if (type == DataType::Float(64)) { \ if (type == DataType::Float(64)) { \
typedef double DType; \ typedef double DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Float(32)) { \ } else if (type == DataType::Float(32)) { \
typedef float DType; \ typedef float DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Float(16)) { \ } else if (type == DataType::Float(16)) { \
typedef uint16_t DType; \ typedef uint16_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Int(64)) { \ } else if (type == DataType::Int(64)) { \
typedef int64_t DType; \ typedef int64_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Int(32)) { \ } else if (type == DataType::Int(32)) { \
typedef int32_t DType; \ typedef int32_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Int(16)) { \ } else if (type == DataType::Int(16)) { \
typedef int16_t DType; \ typedef int16_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::Int(8)) { \ } else if (type == DataType::Int(8)) { \
typedef int8_t DType; \ typedef int8_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::UInt(64)) { \ } else if (type == DataType::UInt(64)) { \
typedef uint64_t DType; \ typedef uint64_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::UInt(32)) { \ } else if (type == DataType::UInt(32)) { \
typedef uint32_t DType; \ typedef uint32_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::UInt(16)) { \ } else if (type == DataType::UInt(16)) { \
typedef uint16_t DType; \ typedef uint16_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (type == DataType::UInt(8)) { \ } else if (type == DataType::UInt(8)) { \
typedef uint8_t DType; \ typedef uint8_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else { \ } else { \
...@@ -79,9 +78,9 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { ...@@ -79,9 +78,9 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
case kTfLiteInt64: case kTfLiteInt64:
return DataType::Int(64); return DataType::Int(64);
case kTfLiteInt16: case kTfLiteInt16:
returnDataType::Int(16); return DataType::Int(16);
case kTfLiteInt8: case kTfLiteInt8:
returnDataType::Int(8); return DataType::Int(8);
case kTfLiteUInt8: case kTfLiteUInt8:
return DataType::UInt(8); return DataType::UInt(8);
case kTfLiteFloat16: case kTfLiteFloat16:
...@@ -92,7 +91,6 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { ...@@ -92,7 +91,6 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
} }
} }
void TFLiteRuntime::Init(const std::string& tflite_model_bytes, void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
TVMContext ctx) { TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str(); const char* buffer = tflite_model_bytes.c_str();
...@@ -100,12 +98,14 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes, ...@@ -100,12 +98,14 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
std::unique_ptr<tflite::FlatBufferModel> model = std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
tflite::ops::builtin::BuiltinOpResolver resolver; tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model, resolver)(&interpreter_); // Build interpreter
ctx_ = ctx; TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
} CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
// Allocate tensors
status = interpreter_->AllocateTensors();
CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
void TFLiteRuntime::AllocateTensors() { ctx_ = ctx;
interpreter_->AllocateTensors();
} }
void TFLiteRuntime::Invoke() { void TFLiteRuntime::Invoke() {
...@@ -129,7 +129,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { ...@@ -129,7 +129,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
} }
NDArray TFLiteRuntime::GetOutput(int index) const { NDArray TFLiteRuntime::GetOutput(int index) const {
TfLiteTensor* output = interpreter_->output_tensor(index); TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]);
DataType dtype = TfLiteDType2TVMDType(output->type); DataType dtype = TfLiteDType2TVMDType(output->type);
TfLiteIntArray* dims = output->dims; TfLiteIntArray* dims = output->dims;
int64_t size = 1; int64_t size = 1;
...@@ -167,10 +167,6 @@ PackedFunc TFLiteRuntime::GetFunction( ...@@ -167,10 +167,6 @@ PackedFunc TFLiteRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Invoke(); this->Invoke();
}); });
} else if (name == "allocate_tensors") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->AllocateTensors();
});
} else { } else {
return PackedFunc(); return PackedFunc();
} }
......
...@@ -36,17 +36,18 @@ ...@@ -36,17 +36,18 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
#define CHECK_TFLITE_STATUS(ret) CHECK_EQ(ret, kTfLiteOk)
/*! /*!
* \brief Tflite runtime. * \brief Tflite runtime.
* *
* This runtime can be acccesibly in various language via * This runtime can be accessed in various language via
* TVM runtime PackedFunc API. * TVM runtime PackedFunc API.
*/ */
class TFLiteRuntime : public ModuleNode { class TFLiteRuntime : public ModuleNode {
public: public:
/*! /*!
* \brief Get member function to front-end * \brief Get member function to front-end.
* \param name The name of the function. * \param name The name of the function.
* \param sptr_to_self The pointer to the module node. * \param sptr_to_self The pointer to the module node.
* \return The corresponding member function. * \return The corresponding member function.
...@@ -57,15 +58,11 @@ class TFLiteRuntime : public ModuleNode { ...@@ -57,15 +58,11 @@ class TFLiteRuntime : public ModuleNode {
/*! /*!
* \return The type key of the executor. * \return The type key of the executor.
*/ */
const char* type_key() const final { const char* type_key() const {
return "TFLiteRuntime"; return "TFLiteRuntime";
} }
/*! /*!
* \brief Update allocations for all tenssors. This is relatively expensive.
*/
void AllocateTensors();
/*!
* \brief Invoke the internal tflite interpreter and run the whole model in * \brief Invoke the internal tflite interpreter and run the whole model in
* dependency order. * dependency order.
*/ */
...@@ -100,8 +97,9 @@ class TFLiteRuntime : public ModuleNode { ...@@ -100,8 +97,9 @@ class TFLiteRuntime : public ModuleNode {
*/ */
NDArray GetOutput(int index) const; NDArray GetOutput(int index) const;
private: // TFLite interpreter
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<tflite::Interpreter> interpreter_;
// TVM context
TVMContext ctx_; TVMContext ctx_;
}; };
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tvm
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tflite_runtime.interpreter as tflite
def skipped_test_tflite_runtime():
def get_tflite_model_path(target_edgetpu):
# Return a path to the model
edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
# Obtain mobilenet model from the edgetpu repo path
if target_edgetpu:
model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite")
else:
model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant.tflite")
return model_path
def init_interpreter(model_path, target_edgetpu):
# Initialize interpreter
if target_edgetpu:
edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
libedgetpu = os.path.join(edgetpu_path, "libedgetpu/direct/aarch64/libedgetpu.so.1")
interpreter = tflite.Interpreter(
model_path=model_path,
experimental_delegates=[tflite.load_delegate(libedgetpu)])
else:
interpreter = tflite.Interpreter(model_path=model_path)
return interpreter
def check_remote(target_edgetpu=False):
tflite_model_path = get_tflite_model_path(target_edgetpu)
# inference via tflite interpreter python apis
interpreter = init_interpreter(tflite_model_path, target_edgetpu)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])
# inference via remote tvm tflite runtime
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
ctx = remote.cpu(0)
with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), ctx)
runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)
# Target CPU on coral board
check_remote()
# Target EdgeTPU on coral board
check_remote(target_edgetpu=True)
if __name__ == "__main__":
# skipped_test_tflite_runtime()
pass
...@@ -36,16 +36,14 @@ def skipped_test_tflite_runtime(): ...@@ -36,16 +36,14 @@ def skipped_test_tflite_runtime():
return tflite_model return tflite_model
def check_verify(): def check_local():
tflite_fname = "model.tflite" tflite_fname = "model.tflite"
tflite_model = create_tflite_model() tflite_model = create_tflite_model()
temp = util.tempdir() temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname) tflite_model_path = temp.relpath(tflite_fname)
print(tflite_model_path)
open(tflite_model_path, 'wb').write(tflite_model) open(tflite_model_path, 'wb').write(tflite_model)
# inference via tflite interpreter python apis # inference via tflite interpreter python apis
print('interpreter')
interpreter = tflite.Interpreter(model_path=tflite_model_path) interpreter = tflite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
...@@ -57,11 +55,9 @@ def skipped_test_tflite_runtime(): ...@@ -57,11 +55,9 @@ def skipped_test_tflite_runtime():
interpreter.invoke() interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index']) tflite_output = interpreter.get_tensor(output_details[0]['index'])
print('tvm tflite runtime')
# inference via tvm tflite runtime # inference via tvm tflite runtime
with open(tflite_model_path, 'rb') as model_fin: with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
runtime.allocate_tensors()
runtime.set_input(0, tvm.nd.array(tflite_input)) runtime.set_input(0, tvm.nd.array(tflite_input))
runtime.invoke() runtime.invoke()
out = runtime.get_output(0) out = runtime.get_output(0)
...@@ -95,14 +91,12 @@ def skipped_test_tflite_runtime(): ...@@ -95,14 +91,12 @@ def skipped_test_tflite_runtime():
with open(tflite_model_path, 'rb') as model_fin: with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.allocate_tensors()
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke() runtime.invoke()
out = runtime.get_output(0) out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output) np.testing.assert_equal(out.asnumpy(), tflite_output)
check_local()
check_verify()
check_remote() check_remote()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment