Commit 1de52bb0 by Andrew Tulloch Committed by Tianqi Chen

[RFC] [Contrib] Minimal runtime (~12kb .text on ARMv7/x86) for subset of TVM models (#3567)

This is an alternative implementation of a subset of the TVM runtime API (and
graph runtime) that focuses entirely on reducing code size, at the expense of
functionality (no tvm.extern(..) calls via PackedFunc, CPU only, etc). It might
be worth incrementally expanding the surface area if there's interest.

The motivation for this work was seeing what the minimal useful subset of the
TVM runtime is. This is relevant for e.g. super code-size constrained
applications in e.g. embedded/mobile. The current runtime is more like O(100KiB)
or so, so this might be compelling for some users.

The smaller surface area for auditing might make this relevant for
https://github.com/dmlc/tvm/issues/3159, or the usecases I was thinking about in
https://github.com/dmlc/tvm/issues/2523#issuecomment-459165815 re: the Rust
runtime.

The symbols in the tvm::minimalruntime space (i.e. excluding std:: and
picojson::) are about 5KiB, so I think there's a bunch of room here (i.e. we
could replace picojson:: with [`jsmn`](https://zserge.com/jsmn.html) or
something, and we could replace more of the `std::unordered_map` usage, etc with
custom primitives as well (similar to the `DynArray`).
parent 4e2d707f
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
## Notes
`picojson.h` is derived from https://github.com/kazuho/picojson.
...@@ -46,6 +46,7 @@ tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") ...@@ -46,6 +46,7 @@ tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include") tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include")
tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include") tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include")
tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options # Contrib library options
tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_BLAS "The blas library to be linked" none)
...@@ -57,6 +58,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) ...@@ -57,6 +58,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_SORT "Build with sort support" OFF)
tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
# include directories # include directories
...@@ -66,6 +68,7 @@ include_directories(${DLPACK_PATH}) ...@@ -66,6 +68,7 @@ include_directories(${DLPACK_PATH})
include_directories(${DMLC_PATH}) include_directories(${DMLC_PATH})
include_directories(${RANG_PATH}) include_directories(${RANG_PATH})
include_directories(${COMPILER_RT_PATH}) include_directories(${COMPILER_RT_PATH})
include_directories(${PICOJSON_PATH})
# initial variables # initial variables
set(TVM_LINKER_LIBS "") set(TVM_LINKER_LIBS "")
...@@ -239,6 +242,7 @@ include(cmake/modules/Micro.cmake) ...@@ -239,6 +242,7 @@ include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake) include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake) include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/HybridDump.cmake)
......
...@@ -91,6 +91,9 @@ set(USE_GRAPH_RUNTIME_DEBUG OFF) ...@@ -91,6 +91,9 @@ set(USE_GRAPH_RUNTIME_DEBUG OFF)
# Whether enable additional vm profiler functions # Whether enable additional vm profiler functions
set(USE_VM_PROFILER OFF) set(USE_VM_PROFILER OFF)
# Whether enable uTVM standalone runtime
set(USE_MICRO_STANDALONE_RUNTIME ON)
# Whether build with LLVM support # Whether build with LLVM support
# Requires LLVM version >= 4.0 # Requires LLVM version >= 4.0
# #
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(USE_MICRO_STANDALONE_RUNTIME)
message(STATUS "Build with micro.standalone_runtime")
file(GLOB MICRO_STANDALONE_RUNTIME_SRC src/runtime/micro/standalone/*.cc)
list(APPEND RUNTIME_SRCS ${MICRO_STANDALONE_RUNTIME_SRC})
add_definitions(-DUSE_MICRO_STANDALONE_RUNTIME=1)
endif(USE_MICRO_STANDALONE_RUNTIME)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_
#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_
#include <stddef.h>
#include <stdint.h>
#define TVM_MICRO_RUNTIME_API_API extern "C" __attribute__((visibility("default")))
TVM_MICRO_RUNTIME_API_API void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module);
TVM_MICRO_RUNTIME_API_API void UTVMRuntimeDestroy(void* handle);
TVM_MICRO_RUNTIME_API_API void UTVMRuntimeSetInput(void* handle, int index, void* tensor);
TVM_MICRO_RUNTIME_API_API void UTVMRuntimeRun(void* handle);
TVM_MICRO_RUNTIME_API_API void UTVMRuntimeGetOutput(void* handle, int index, void* tensor);
TVM_MICRO_RUNTIME_API_API void* UTVMRuntimeDSOModuleCreate(const char* so, size_t so_len);
TVM_MICRO_RUNTIME_API_API void UTVMRuntimeDSOModuleDestroy(void* module);
#undef TVM_MICRO_RUNTIME_API_API
#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
## A replacement implementation of the TVM runtime, focused on a minimal subset of the overall runtime.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_MICRO_STANDALONE_MINIMAL_VECTOR_H_
#define TVM_RUNTIME_MICRO_STANDALONE_MINIMAL_VECTOR_H_
#include <algorithm>
#include <cassert>
#include <memory>
namespace tvm {
namespace micro {
// A minimal wrapper, derived from https://github.com/Robbepop/dynarray/, that
// supports a minimal subset of the std::vector API with a minimized code size.
template <typename T>
struct DynArray {
using value_type = T;
using size_type = size_t;
using difference_type = std::ptrdiff_t;
using reference = value_type&;
using const_reference = value_type const&;
using pointer = value_type*;
using const_pointer = value_type const*;
using iterator = pointer;
using const_iterator = const_pointer;
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
explicit DynArray(size_type size = 0) { resize(size); }
DynArray(const DynArray& other) {
resize(other.size());
std::copy(other.begin(), other.end(), begin());
}
DynArray& operator=(const DynArray& other) {
resize(other.size());
std::copy(other.begin(), other.end(), begin());
return *this;
}
void resize(size_type size) {
if (size > 0) {
data_.reset(new T[size]);
} else {
data_.reset();
}
size_ = size;
}
size_type size() const { return size_; }
reference operator[](size_type pos) { return data_[pos]; }
const_reference operator[](size_type pos) const { return data_[pos]; }
pointer data() { return data_.get(); }
const_pointer data() const { return data_.get(); }
iterator begin() { return data_.get(); }
const_iterator begin() const { return data_.get(); }
const_iterator cbegin() const { return data_.get(); }
iterator end() { return data_.get() + size_; }
const_iterator end() const { return data_.get() + size_; }
const_iterator cend() const { return data_.get() + size_; }
reference front() { return data_[0]; }
const_reference front() const { return data_[0]; }
reference back() { return data_[size_ - 1]; }
const_reference back() const { return data_[size_ - 1]; }
private:
std::unique_ptr<T[]> data_;
size_type size_;
};
} // namespace micro
} // namespace tvm
#endif // TVM_RUNTIME_MICRO_STANDALONE_MINIMAL_VECTOR_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_RUNTIME_H_
#include <dlpack/dlpack.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "minimal_vector.h"
#include "utvm_runtime_api.h"
namespace tvm {
namespace micro {
typedef int (*BackendPackedCFunc)(void* args, int* type_codes, int num_args);
// dlopen/dlsym/dlclose abstraction.
class DSOModule {
public:
explicit DSOModule(const std::string& name);
~DSOModule();
BackendPackedCFunc GetFunction(const std::string& name) const;
private:
void* GetSymbol(const char* name) const;
void* lib_handle_{nullptr};
};
// The graph attribute fields.
struct GraphAttr {
DynArray<int> storage_id;
DynArray<std::string> dltype;
DynArray<DynArray<int64_t>> shape;
};
// Memory pool entry.
struct PoolEntry {
size_t size;
int device_type;
};
// Node entry
struct NodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
};
// Operator attributes about TVMOp
struct TVMOpParam {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
};
// Node
struct Node {
// operator type in string
std::string op_type;
// name of the op
std::string name;
// parameters
TVMOpParam param;
// inputs
DynArray<NodeEntry> inputs;
};
// Minimal NDArray abstraction
class NDArray {
public:
// initialize NDArray with shape/dtype/ctx
static NDArray Empty(const DynArray<int64_t>& shape, DLDataType dtype, DLContext ctx);
// create a view of the NDArray storage, with the given shape/dtype
NDArray CreateView(const DynArray<int64_t>& shape, DLDataType dtype);
// Copy into the internal storage.
void CopyFrom(DLTensor* src);
// Copy out of the internal storage
void CopyTo(DLTensor* dst) const;
// View `this` as a DLTensor
DLTensor ToDLTensor();
~NDArray();
private:
// reference-counted storage
std::shared_ptr<void> storage_;
// tensor shape
DynArray<int64_t> shape_;
// tensor dtype
DLDataType dtype_;
// tensor context
DLContext ctx_;
};
// Minimal GraphRuntime implementation
class MicroGraphRuntime {
public:
// Construct a GraphRuntime with the given graph and DSOModule.
MicroGraphRuntime(const std::string& graph_json, DSOModule* module);
~MicroGraphRuntime();
// Run the graph
void Run();
// Set the input at `index` to a copy of the tensor `data_in`
void SetInput(int index, DLTensor* data_in);
// Copy the output at `index` into `data_out`
void CopyOutputTo(int index, DLTensor* data_out);
private:
void SetupStorage();
void SetupOpExecs();
uint32_t num_node_entries() const { return node_row_ptr_.back(); }
uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); }
DSOModule* module_;
// TODO(tulloch): these are essentially unused after construction.
// The graph nodes
DynArray<Node> nodes_;
// The argument noes
DynArray<uint32_t> input_nodes_;
// Used for quick entry indexing
DynArray<uint32_t> node_row_ptr_;
// Output entries
DynArray<NodeEntry> outputs_;
// Additional graph attributes
GraphAttr attrs_;
// Execution context
DLContext ctx_{kDLCPU, 0};
// Common storage pool
DynArray<NDArray> storage_pool_;
// Data entry for each node
DynArray<NDArray> data_entry_;
// Operator for each node
DynArray<std::function<void()>> op_execs_;
};
} // namespace micro
} // namespace tvm
#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_RUNTIME_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <cassert>
#include "tvm/runtime/micro/standalone/utvm_runtime.h"
#include "utvm_graph_runtime.h"
void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) {
return new tvm::micro::MicroGraphRuntime(
std::string(json, json + json_len),
reinterpret_cast<tvm::micro::DSOModule*>(module));
}
void UTVMRuntimeDestroy(void* handle) {
delete reinterpret_cast<tvm::micro::MicroGraphRuntime*>(handle);
}
void UTVMRuntimeSetInput(void* handle, int index, void* tensor) {
reinterpret_cast<tvm::micro::MicroGraphRuntime*>(handle)->SetInput(
index, reinterpret_cast<DLTensor*>(tensor));
}
void UTVMRuntimeRun(void* handle) {
reinterpret_cast<tvm::micro::MicroGraphRuntime*>(handle)->Run();
}
void UTVMRuntimeGetOutput(void* handle, int index, void* tensor) {
reinterpret_cast<tvm::micro::MicroGraphRuntime*>(handle)->CopyOutputTo(
index, reinterpret_cast<DLTensor*>(tensor));
}
void* UTVMRuntimeDSOModuleCreate(const char* so, size_t so_len) {
return new tvm::micro::DSOModule(std::string(so, so + so_len));
}
void UTVMRuntimeDSOModuleDestroy(void* module) {
delete reinterpret_cast<tvm::micro::DSOModule*>(module);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "utvm_runtime_api.h"
#include <stdlib.h>
#include <cassert>
#include <string>
void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
int dtype_bits_hint) {
void* ptr = nullptr;
assert(nbytes > 0);
#ifdef __ANDROID__
ptr = memalign(64, nbytes);
#else
const int ret = posix_memalign(&ptr, 64, nbytes);
(void)ret;
assert(ret == 0);
#endif
return ptr;
}
int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
free(ptr);
return 0;
}
static thread_local std::string g_last_error;
void TVMAPISetLastError(const char* msg) { g_last_error = msg; }
const char* TVMGetLastError(void) { return g_last_error.c_str(); }
int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
TVMParallelGroupEnv env;
env.num_task = 1;
flambda(0, &env, cdata);
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_
#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_
#include <stdint.h>
#include <stdlib.h>
#include <cassert>
// The subset of the TVM runtime API that is implemented by the minimal runtime API.
#define TVM_MICRO_RUNTIME_API_BACKEND_API extern "C" __attribute__((weak, visibility("default")))
TVM_MICRO_RUNTIME_API_BACKEND_API int TVMBackendFreeWorkspace(int device_type, int device_id,
void* ptr);
TVM_MICRO_RUNTIME_API_BACKEND_API void* TVMBackendAllocWorkspace(int device_type, int device_id,
uint64_t nbytes,
int dtype_code_hint,
int dtype_bits_hint);
typedef struct {
void* sync_handle;
int32_t num_task;
} TVMParallelGroupEnv;
typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
TVM_MICRO_RUNTIME_API_BACKEND_API int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
void* cdata, int num_task);
TVM_MICRO_RUNTIME_API_BACKEND_API void TVMAPISetLastError(const char* msg);
TVM_MICRO_RUNTIME_API_BACKEND_API const char* TVMGetLastError(void);
#undef TVM_MICRO_RUNTIME_API_BACKEND_API
#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <random>
#include <dlpack/dlpack.h>
#include <gtest/gtest.h>
#include <map>
#include <vector>
#ifdef USE_MICRO_STANDALONE_RUNTIME
// Use system(..), `gcc -shared -fPIC`, thus restrict the test to OS X for now.
#if defined(__APPLE__) && defined(__MACH__)
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/operation.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/micro/standalone/utvm_runtime.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <spawn.h>
#include <sys/wait.h>
TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
TEST(MicroStandaloneRuntime, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
pA[i] = i;
pB[i] = i + 1;
pC[i] = i + 2;
}
// get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
auto s_i = tvm::runtime::Registry::Get("test.sch");
if (!reg) {
LOG(FATAL) << "no _Register";
}
if (!s_i) {
LOG(FATAL) << "no test_sch";
}
(*reg)("add", "FTVMSchedule", *s_i, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
std::string o_fname = std::tmpnam(nullptr);
std::string so_fname = std::tmpnam(nullptr);
mod->SaveToFile(o_fname, "o");
const std::vector<std::string> args = {"gcc", "-shared", "-fPIC", "-o", so_fname, o_fname};
std::stringstream s;
for (auto& c : args) {
s << c << " ";
}
const auto ss = s.str();
const auto ret = system(ss.c_str());
ASSERT_EQ(ret, 0);
// Now, execute the minimal runtime.
auto* dsoModule = UTVMRuntimeDSOModuleCreate(so_fname.c_str(), so_fname.size());
ASSERT_NE(dsoModule, nullptr);
auto* handle = UTVMRuntimeCreate(json.c_str(), json.size(), dsoModule);
ASSERT_NE(handle, nullptr);
UTVMRuntimeSetInput(handle, 0, &A.ToDLPack()->dl_tensor);
UTVMRuntimeSetInput(handle, 1, &B.ToDLPack()->dl_tensor);
UTVMRuntimeSetInput(handle, 2, &C.ToDLPack()->dl_tensor);
UTVMRuntimeRun(handle);
auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
UTVMRuntimeGetOutput(handle, 0, &Y.ToDLPack()->dl_tensor);
auto* pY = (float*)Y.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
UTVMRuntimeDestroy(handle);
UTVMRuntimeDSOModuleDestroy(dsoModule);
}
#endif
#endif
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment