Unverified Commit 450f7163 by Liangfu Chen Committed by GitHub

[Runtime] MISRA-C compliant TVM runtime (#3934)

* implement of MISRA-C compliant TVM runtime;

* working on bundle_deploy_c demo

* move header files into include dir

* fix compatibility issues

* fix compatibility issues

* resolve most of the warnings and errros

* implement c_backend_api

* introduce bridge

* working well

* move to header files and bundle.c into src/runtime/crt

* clean up

* satisfy linter

* clean up

* test with the cat image

* remove synset

* refactoring

* refactoring

* refactoring

* initial crt_runtime_api.c

* improved compatibility with g++

* using exposed API in c_runtime_api.h

* call from c_runtime_api.h

* clean up

* lint

* merge into apps/bundle_deploy directory

Change-Id: I51904db81b8589e65d107d8ca77b47452e3812b5

* make the demo runs in ci

Change-Id: I2c24f8b592508833d3555311c2b24d1931f19385

* address review comments

Change-Id: I027ddff15c31fb4da0bd0e461427dce619de1f93

* release

Change-Id: I5ad5bb8426468aac9fc8d074e56ddea358a7fd91

* fix ci testing

Change-Id: Ic2e82fb3051b6c254ef32a964f976b61e3e5fe4d

* add test case for misra c runtime

Change-Id: Ie0dfd0ade6be4665b4384db7d260a6c69b35010f

* fread files in testing to avoid calling xxd

Change-Id: Ie7fbc16b4b0b9509918d986a841f443900813bef
parent 5b4cf5df
......@@ -17,40 +17,80 @@
# Makefile Example to bundle TVM modules.
# Setup build environment
TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
PKG_CXXFLAGS = -std=c++14 -O2 -fPIC \
-I${TVM_ROOT}/include \
-I${DMLC_CORE}/include \
-I${TVM_ROOT}/3rdparty/dlpack/include
PKG_CFLAGS = -std=c99 -O2 -fPIC \
-I${TVM_ROOT}/include \
-I${DMLC_CORE}/include \
-I${TVM_ROOT}/3rdparty/dlpack/include
PKG_LDFLAGS = -pthread
build_dir := build
test: $(build_dir)/demo $(build_dir)/bundle.so
$(build_dir)/demo $(build_dir)/bundle.so
demo: $(build_dir)/demo $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/cat.bin
TVM_NUM_THREADS=1 $(build_dir)/demo $(build_dir)/bundle.so $(build_dir)/cat.bin
TVM_NUM_THREADS=1 $(build_dir)/demo $(build_dir)/bundle_c.so $(build_dir)/cat.bin
test: $(build_dir)/test $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin
TVM_NUM_THREADS=1 $(build_dir)/test $(build_dir)/test_bundle.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin
TVM_NUM_THREADS=1 $(build_dir)/test $(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin
$(build_dir)/demo: demo.cc ${build_dir}/graph.json.c ${build_dir}/params.bin.c
@mkdir -p $(@D)
g++ $(PKG_CXXFLAGS) -o $@ demo.cc -ldl
$(build_dir)/demo: demo.cc
$(build_dir)/test: test.cc ${build_dir}/test_graph.json ${build_dir}/test_params.bin
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -o $@ $^ -ldl
g++ $(PKG_CXXFLAGS) -o $@ test.cc -ldl
# Serialize our graph.json file.
$(build_dir)/graph.json.cc: $(build_dir)/graph.json
$(build_dir)/graph.json.c: $(build_dir)/graph.json
xxd -i $^ > $@
# Serialize our params.bin file.
$(build_dir)/params.bin.cc: $(build_dir)/params.bin
$(build_dir)/params.bin.c: $(build_dir)/params.bin
xxd -i $^ > $@
$(build_dir)/model.o $(build_dir)/graph.json $(build_dir)/params.bin: build_model.py
# # Serialize our test_graph.json file.
# $(build_dir)/test_graph.json.c: $(build_dir)/test_graph.json
# xxd -i $^ > $@
#
# # Serialize our test_params.bin file.
# $(build_dir)/test_params.bin.c: $(build_dir)/test_params.bin
# xxd -i $^ > $@
$(build_dir)/model.o $(build_dir)/graph.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py
python3 $< -o $(build_dir)
# Build our bundle against the serialized bundle.cc API, the runtime.cc API, and
$(build_dir)/test_model.o $(build_dir)/test_graph.json $(build_dir)/test_params.bin $(build_dir)/test_data.bin $(build_dir)/test_output.bin: build_model.py
python3 $< -o $(build_dir) --test
# Build our bundle against the serialized bundle.c API, the runtime.cc API, and
# the serialized graph.json and params.bin
$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model.o $(build_dir)/graph.json.cc $(build_dir)/params.bin.cc
$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model.o
@mkdir -p $(@D)
$(CXX) -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
$(build_dir)/bundle_c.so: bundle.c runtime.c $(build_dir)/model.o
@mkdir -p $(@D)
gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
$(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model.o
@mkdir -p $(@D)
g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
$(build_dir)/test_bundle_c.so: bundle.c runtime.c $(build_dir)/test_model.o
@mkdir -p $(@D)
gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
clean:
rm -r $(build_dir)
rm -rf $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so
cleanall:
rm -rf $(build_dir)
......@@ -45,9 +45,10 @@ make demo
This will:
- Download the mobilenet0.25 model from the MXNet Gluon Model Zoo
- Compile the model with NNVM
- Compile the model with Relay
- Build a `bundle.so` shared object containing the model specification and
parameters
- Build a `demo` executable that `dlopen`'s `bundle.so`, instantiates the
contained graph runtime, and invokes the `GraphRuntime::Run` function on a
random input, then prints the output tensor to `stderr`.
- Build a `demo` executable that `dlopen`'s `bundle.so` (or `bundle_c.so` in
terms of the MISRA-C runtime), instantiates the contained graph runtime,
and invokes the `GraphRuntime::Run` function on a cat image, then prints
the output results.
......@@ -22,15 +22,9 @@ from tvm import relay
import tvm
from tvm import te
import logging
import json
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--out-dir', default='.')
opts = parser.parse_args()
def build_module(opts):
dshape = (1, 3, 224, 224)
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('mobilenet0.25', pretrained=True)
......@@ -53,6 +47,69 @@ def main():
with open(os.path.join(build_dir, 'params.bin'), 'wb') as f_params:
f_params.write(relay.save_param_dict(params))
def build_test_module(opts):
import numpy as np
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
z = relay.add(x, y)
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data}
graph, lib, params = relay.build(
tvm.IRModule.from_expr(func), "llvm --system-lib", params=params)
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)
lib.save(os.path.join(build_dir, 'test_model.o'))
with open(os.path.join(build_dir, 'test_graph.json'), 'w') as f_graph_json:
f_graph_json.write(graph)
with open(os.path.join(build_dir, 'test_params.bin'), 'wb') as f_params:
f_params.write(relay.save_param_dict(params))
with open(os.path.join(build_dir, "test_data.bin"), "wb") as fp:
fp.write(x_data.astype(np.float32).tobytes())
x_output = x_data + y_data
with open(os.path.join(build_dir, "test_output.bin"), "wb") as fp:
fp.write(x_output.astype(np.float32).tobytes())
def build_inputs(opts):
from tvm.contrib import download
from PIL import Image
import numpy as np
build_dir = os.path.abspath(opts.out_dir)
# Download test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
image_fn = os.path.join(build_dir, "cat.png")
download.download(image_url, image_fn)
image = Image.open(image_fn).resize((224, 224))
def transform_image(image):
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image
x = transform_image(image)
print('x', x.shape)
with open(os.path.join(build_dir, "cat.bin"), "wb") as fp:
fp.write(x.astype(np.float32).tobytes())
if __name__ == '__main__':
main()
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--out-dir', default='.')
parser.add_argument('-t', '--test', action='store_true')
opts = parser.parse_args()
if opts.test:
build_test_module(opts)
else:
build_module(opts)
build_inputs(opts)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/c_runtime_api.h>
#include <stdio.h>
#include <stdlib.h>
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
do { \
int ret = (func); \
if (ret != 0) { \
fprintf(stderr, "%s: %d: error: %s\n", __FILE__, __LINE__, TVMGetLastError()); \
exit(ret); \
} \
} while (0)
TVM_DLL void * tvm_runtime_create(const char * json_data,
const char * params_data,
const uint64_t params_size) {
int64_t device_type = kDLCPU;
int64_t device_id = 0;
TVMByteArray params;
params.data = params_data;
params.size = params_size;
TVMContext ctx;
ctx.device_type = (DLDeviceType)device_type;
ctx.device_id = device_id;
// declare pointers
TVMModuleHandle (*SystemLibraryCreate)();
TVMModuleHandle (*TVMGraphRuntimeCreate)(const char *, const TVMModuleHandle, const TVMContext *);
int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char *, const uint32_t);
// get pointers
TVM_CCALL(TVMFuncGetGlobal("runtime.SystemLib", (TVMFunctionHandle*)&SystemLibraryCreate));
TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.create", (TVMFunctionHandle*)&TVMGraphRuntimeCreate));
// run modules
TVMModuleHandle mod_syslib = SystemLibraryCreate();
TVMModuleHandle mod = TVMGraphRuntimeCreate(json_data, mod_syslib, &ctx);
TVM_CCALL(TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams));
TVMGraphRuntime_LoadParams(mod, params.data, params.size);
return mod;
}
TVM_DLL void tvm_runtime_destroy(void * runtime) {
void (*TVMGraphRuntimeRelease)(TVMModuleHandle *);
TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease));
TVMGraphRuntimeRelease(&runtime);
}
TVM_DLL void tvm_runtime_set_input(void * runtime, const char * name, DLTensor * tensor) {
void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char *, DLTensor*);
TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.set_input", (TVMFunctionHandle*)&TVMGraphRuntime_SetInput));
TVMGraphRuntime_SetInput(runtime, name, tensor);
}
TVM_DLL void tvm_runtime_run(void * runtime) {
void (*TVMGraphRuntime_Run)(TVMModuleHandle runtime);
TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.run", (TVMFunctionHandle*)&TVMGraphRuntime_Run));
TVMGraphRuntime_Run(runtime);
}
TVM_DLL void tvm_runtime_get_output(void * runtime, int32_t index, DLTensor * tensor) {
int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor *);
TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.get_output", (TVMFunctionHandle*)&TVMGraphRuntime_GetOutput));
TVMGraphRuntime_GetOutput(runtime, index, tensor);
}
......@@ -21,16 +21,14 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
extern unsigned char build_graph_json[];
extern unsigned int build_graph_json_len;
extern unsigned char build_params_bin[];
extern unsigned int build_params_bin_len;
#define TVM_BUNDLE_FUNCTION __attribute__((visibility("default")))
extern "C" {
TVM_BUNDLE_FUNCTION void *tvm_runtime_create() {
TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json,
const char * build_params_bin,
const uint64_t build_params_bin_len) {
const int build_graph_json_len = strlen(build_graph_json);
const std::string json_data(&build_graph_json[0],
&build_graph_json[0] + build_graph_json_len);
tvm::runtime::Module mod_syslib =
......
......@@ -17,13 +17,17 @@
* under the License.
*/
#include "tvm/runtime/c_runtime_api.h"
#include <tvm/runtime/c_runtime_api.h>
#include <assert.h>
#include <dlfcn.h> //dlopen
#include <dlpack/dlpack.h>
#include <iostream>
#include <random>
#include <vector>
#include <sys/time.h>
#include "build/graph.json.c"
#include "build/params.bin.c"
template <typename F> auto getFunc(void *bundle, const char *name) {
dlerror();
......@@ -34,39 +38,50 @@ template <typename F> auto getFunc(void *bundle, const char *name) {
}
int main(int argc, char **argv) {
assert(argc == 2 && "Usage: demo <bundle.so>");
assert(argc == 3 && "Usage: demo <bundle.so> <cat.bin>");
auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
assert(bundle);
auto *handle = getFunc<void *()>(bundle, "tvm_runtime_create")();
char * json_data = reinterpret_cast<char*>(build_graph_json);
char * params_data = reinterpret_cast<char*>(build_params_bin);
uint64_t params_size = build_params_bin_len;
std::vector<float> input_storage(1 * 3 * 224 * 224);
std::mt19937 gen(0);
for (auto &e : input_storage) {
e = std::uniform_real_distribution<float>(0.0, 1.0)(gen);
}
struct timeval t0, t1, t2, t3, t4, t5;
gettimeofday(&t0, 0);
auto *handle = getFunc<void *(char*, char*, int)>(bundle, "tvm_runtime_create")(
json_data, params_data, params_size);
gettimeofday(&t1, 0);
float input_storage[1 * 3 * 224 * 224];
FILE * fp = fopen(argv[2], "rb");
fread(input_storage, 3 * 224 * 224, 4, fp);
fclose(fp);
std::vector<int64_t> input_shape = {1, 3, 224, 224};
DLTensor input;
input.data = input_storage.data();
input.data = input_storage;
input.ctx = DLContext{kDLCPU, 0};
input.ndim = 4;
input.dtype = DLDataType{kDLFloat, 32, 1};
input.shape = input_shape.data();
input.strides = nullptr;
input.byte_offset = 0;
getFunc<void(void *, const char *, void *)>(bundle, "tvm_runtime_set_input")(
handle, "data", &input);
gettimeofday(&t2, 0);
auto *ftvm_runtime_run =
(auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run");
assert(!dlerror());
ftvm_runtime_run(handle);
gettimeofday(&t3, 0);
std::vector<float> output_storage(1000);
float output_storage[1000];
std::vector<int64_t> output_shape = {1, 1000};
DLTensor output;
output.data = output_storage.data();
output.data = output_storage;
output.ctx = DLContext{kDLCPU, 0};
output.ndim = 2;
output.dtype = DLDataType{kDLFloat, 32, 1};
......@@ -76,10 +91,30 @@ int main(int argc, char **argv) {
getFunc<void(void *, int, void *)>(bundle, "tvm_runtime_get_output")(
handle, 0, &output);
for (auto i = 0; i < output_storage.size(); ++i) {
std::cerr << "output[" << i << "]: " << output_storage[i] << std::endl;
gettimeofday(&t4, 0);
float max_iter = -std::numeric_limits<float>::max();
int32_t max_index = -1;
for (auto i = 0; i < 1000; ++i) {
if (output_storage[i] > max_iter) {
max_iter = output_storage[i];
max_index = i;
}
}
getFunc<void(void *)>(bundle, "tvm_runtime_destroy")(handle);
gettimeofday(&t5, 0);
printf("The maximum position in output vector is: %d, with max-value %f.\n",
max_index, max_iter);
printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
"%.2f ms (get_output), %.2f ms (destroy)\n",
(t1.tv_sec-t0.tv_sec)*1000000 + (t1.tv_usec-t0.tv_usec)/1000.f,
(t2.tv_sec-t1.tv_sec)*1000000 + (t2.tv_usec-t1.tv_usec)/1000.f,
(t3.tv_sec-t2.tv_sec)*1000000 + (t3.tv_usec-t2.tv_usec)/1000.f,
(t4.tv_sec-t3.tv_sec)*1000000 + (t4.tv_usec-t3.tv_usec)/1000.f,
(t5.tv_sec-t4.tv_sec)*1000000 + (t5.tv_usec-t4.tv_usec)/1000.f);
dlclose(bundle);
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/* Explicitly declare posix_memalign function */
#if _POSIX_C_SOURCE < 200112L
#undef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 200809L
#endif
/*! Support low-level debugging in MISRA-C runtime */
#define TVM_CRT_DEBUG 0
/*! Maximum supported dimension in NDArray */
#define TVM_CRT_MAX_NDIM 6
/*! Maximum supported arguments in generated functions */
#define TVM_CRT_MAX_ARGS 10
/*! Maximum inputs in a GraphRuntimeNode */
#define GRAPH_RUNTIME_NODE_MAX_INPUTS 300
/*! Maximum supported contexts in a GraphRuntime */
#define GRAPH_RUNTIME_MAX_CONTEXTS 1
/*! Maximum supported nodes in a GraphRuntime */
#define GRAPH_RUNTIME_MAX_NODES 400
/*! Maximum input nodes in a GraphRuntime */
#define GRAPH_RUNTIME_MAX_INPUT_NODES 300
/*! Maximum nodes in a GraphRuntime for quick entry indexing */
#define GRAPH_RUNTIME_MAX_NODE_ROW_PTR 300
/*! Maximum output entries in a GraphRuntime */
#define GRAPH_RUNTIME_MAX_OUTPUTS 300
#include "../../src/runtime/crt/crt_runtime_api.c"
#include "../../src/runtime/crt/crt_backend_api.c"
#include "../../src/runtime/crt/graph_runtime.c"
#include "../../src/runtime/crt/load_json.c"
#include "../../src/runtime/crt/ndarray.c"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/c_runtime_api.h>
#include <assert.h>
#include <dlfcn.h> //dlopen
#include <iostream>
#include <random>
#include <vector>
#include <sys/time.h>
#include <sys/stat.h>
template <typename F> auto getFunc(void *bundle, const char *name) {
dlerror();
auto *f =
reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
assert(!dlerror());
return f;
}
int main(int argc, char **argv) {
assert(argc == 6 && "Usage: test <bundle.so> <data.bin> <output.bin> <graph.json> <params.bin>");
auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
assert(bundle);
struct stat st;
char * json_data;
char * params_data;
uint64_t params_size;
FILE * fp = fopen(argv[4], "rb");
stat(argv[4], &st);
json_data = (char*)malloc(st.st_size);
fread(json_data, st.st_size, 1, fp);
fclose(fp);
fp = fopen(argv[5], "rb");
stat(argv[5], &st);
params_data = (char*)malloc(st.st_size);
fread(params_data, st.st_size, 1, fp);
params_size = st.st_size;
fclose(fp);
struct timeval t0, t1, t2, t3, t4, t5;
gettimeofday(&t0, 0);
auto *handle = getFunc<void *(char*, char*, int)>(bundle, "tvm_runtime_create")(
json_data, params_data, params_size);
gettimeofday(&t1, 0);
float input_storage[10 * 5];
fp = fopen(argv[2], "rb");
fread(input_storage, 10 * 5, 4, fp);
fclose(fp);
float result_storage[10 * 5];
fp = fopen(argv[3], "rb");
fread(result_storage, 10 * 5, 4, fp);
fclose(fp);
std::vector<int64_t> input_shape = {10, 5};
DLTensor input;
input.data = input_storage;
input.ctx = DLContext{kDLCPU, 0};
input.ndim = 2;
input.dtype = DLDataType{kDLFloat, 32, 1};
input.shape = input_shape.data();
input.strides = nullptr;
input.byte_offset = 0;
getFunc<void(void *, const char *, void *)>(bundle, "tvm_runtime_set_input")(
handle, "x", &input);
gettimeofday(&t2, 0);
auto *ftvm_runtime_run =
(auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run");
assert(!dlerror());
ftvm_runtime_run(handle);
gettimeofday(&t3, 0);
float output_storage[10 * 5];
std::vector<int64_t> output_shape = {10, 5};
DLTensor output;
output.data = output_storage;
output.ctx = DLContext{kDLCPU, 0};
output.ndim = 2;
output.dtype = DLDataType{kDLFloat, 32, 1};
output.shape = output_shape.data();
output.strides = nullptr;
output.byte_offset = 0;
getFunc<void(void *, int, void *)>(bundle, "tvm_runtime_get_output")(
handle, 0, &output);
gettimeofday(&t4, 0);
for (auto i = 0; i < 10 * 5; ++i) {
assert(fabs(output_storage[i] - result_storage[i]) < 1e-5f);
if (fabs(output_storage[i] - result_storage[i]) >= 1e-5f) {
printf("got %f, expected %f\n", output_storage[i], result_storage[i]);
}
}
getFunc<void(void *)>(bundle, "tvm_runtime_destroy")(handle);
gettimeofday(&t5, 0);
printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
"%.2f ms (get_output), %.2f ms (destroy)\n",
(t1.tv_sec-t0.tv_sec)*1000000 + (t1.tv_usec-t0.tv_usec)/1000.f,
(t2.tv_sec-t1.tv_sec)*1000000 + (t2.tv_usec-t1.tv_usec)/1000.f,
(t3.tv_sec-t2.tv_sec)*1000000 + (t3.tv_usec-t2.tv_usec)/1000.f,
(t4.tv_sec-t3.tv_sec)*1000000 + (t4.tv_usec-t3.tv_usec)/1000.f,
(t5.tv_sec-t4.tv_sec)*1000000 + (t5.tv_usec-t4.tv_usec)/1000.f);
free(json_data);
free(params_data);
dlclose(bundle);
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/c_backend_api.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
int dtype_bits_hint) {
void* ptr = 0;
assert(nbytes > 0);
unsigned int dtype_bytes = dtype_bits_hint / 8;
#ifdef __ANDROID__
ptr = memalign(64, nbytes * dtype_bytes);
#else
const int ret = posix_memalign(&ptr, 64, nbytes * dtype_bytes);
(void)ret;
assert(ret == 0);
#endif
return ptr;
}
int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
free(ptr);
return 0;
}
int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
TVMParallelGroupEnv env;
env.num_task = 1;
flambda(0, &env, cdata);
return 0;
}
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
snprintf(g_fexecs[g_fexecs_count].name, sizeof(g_fexecs[g_fexecs_count].name), name);
g_fexecs[g_fexecs_count].fexec = ptr;
g_fexecs_count++;
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/c_runtime_api.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include "ndarray.h"
#include "graph_runtime.h"
#include "packed_func.h"
// Handle internal errors
static char g_last_error[1024];
void TVMAPISetLastError(const char* msg) {
assert(strlen(msg) < sizeof(g_last_error));
snprintf(g_last_error, sizeof(g_last_error), "%s", msg);
}
const char* TVMGetLastError(void) { return g_last_error; }
// Manipulate NDArray on target device
int TVMArrayAlloc(const tvm_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out) {
DLDataType dtype;
dtype.code = dtype_code;
dtype.bits = dtype_bits;
dtype.lanes = dtype_lanes;
DLContext ctx;
ctx.device_type = (DLDeviceType)device_type;
ctx.device_id = device_id;
TVMNDArray arr = TVMNDArray_Empty(ndim, shape, dtype, ctx);
**out = arr.dl_tensor;
return 0;
}
int TVMArrayFree(TVMArrayHandle handle) {
TVMNDArray arr;
arr.dl_tensor = *handle;
return TVMNDArray_Release(&arr);
}
void * SystemLibraryCreate() {
return 0;
}
int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *out) {
int status = 0;
if (!strcmp(func_name, "load_params")) {
*out = &TVMGraphRuntime_LoadParams;
} else {
status -1;
}
return status;
}
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
int status = 0;
if (!strcmp(name, "tvm.graph_runtime.create")) {
*out = &TVMGraphRuntimeCreate;
} else if (!strcmp(name, "tvm.graph_runtime.set_input")) {
*out = &TVMGraphRuntime_SetInput;
} else if (!strcmp(name, "tvm.graph_runtime.run")) {
*out = &TVMGraphRuntime_Run;
} else if (!strcmp(name, "tvm.graph_runtime.get_output")) {
*out = &TVMGraphRuntime_GetOutput;
} else if (!strcmp(name, "tvm.graph_runtime.release")) {
*out = &TVMGraphRuntimeRelease;
} else if (!strcmp(name, "runtime.SystemLib")) {
*out = &SystemLibraryCreate;
} else {
char msg[200];
snprintf(msg, sizeof(msg), "fail to get global: name=%s", name);
TVMAPISetLastError(msg);
status = -1;
}
return status;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file graph_runtime.h
* \brief Tiny graph runtime that can run graph containing only tvm PackedFunc.
*/
#ifndef TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
#include <dlpack/dlpack.h>
#include "load_json.h"
#include "ndarray.h"
#include "packed_func.h"
#include "module.h"
/*! \brief operator attributes about tvm op */
typedef struct TVMOpParam {
char func_name[120];
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
} TVMOpParam;
// Memory pool entry.
typedef struct TVMGraphRuntimePoolEntry {
size_t size;
int device_type;
} TVMGraphRuntimePoolEntry;
// Node entry
typedef struct TVMGraphRuntimeNodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
// JSON Loader
void (*Load)(JSONReader *reader);
} TVMGraphRuntimeNodeEntry;
// Node
typedef struct TVMGraphRuntimeNode {
// operator type in string
char op_type[16];
// name of the op
char name[120];
// parameters
TVMOpParam param;
// inputs
TVMGraphRuntimeNodeEntry inputs[GRAPH_RUNTIME_NODE_MAX_INPUTS];
size_t inputs_count;
// control deps
uint32_t control_deps[200];
// JSON Loader
void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param);
// JSON Loader
int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader);
} TVMGraphRuntimeNode;
// Graph attribute
typedef struct TVMGraphRuntimeGraphAttr {
uint32_t storage_num_not_alloctaed;
uint32_t storage_id[GRAPH_RUNTIME_MAX_NODES];
uint32_t device_index[GRAPH_RUNTIME_MAX_NODES];
char dltype[GRAPH_RUNTIME_MAX_NODES][10]; // "int8", "int16", "float32"
uint32_t dltype_count;
int64_t shape[GRAPH_RUNTIME_MAX_NODES][TVM_CRT_MAX_NDIM];
uint32_t ndim[GRAPH_RUNTIME_MAX_NODES];
uint32_t shape_count;
} TVMGraphRuntimeGraphAttr;
typedef DLTensor* DLTensorPtr;
/*!
* \brief Tiny graph runtime.
*
* This runtime can be acccesibly in various language via
* TVM runtime PackedFunc API.
*/
/* class GraphRuntime : public ModuleNode { */
typedef struct TVMGraphRuntime {
void (*Run)(struct TVMGraphRuntime * runtime);
/*!
* \brief Initialize the graph executor with graph and context.
* \param graph_json The execution graph.
* \param module The module containing the compiled functions for the host
* processor.
* \param ctxs The context of the host and devices where graph nodes will be
* executed on.
*/
void (*Init)(struct TVMGraphRuntime * runtime,
const char * graph_json,
const TVMModule * module,
const TVMContext * ctxs);
/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
* \return The index of input.
*/
int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name);
/*!
* \brief set index-th input to the graph.
* \param index The input index.
* \param data_in The input data.
*/
void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
/*!
* \brief Return NDArray for given output index.
* \param index The output index.
*
* \return NDArray corresponding to given output node index.
*/
int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out);
/*!
* \brief Load parameters from parameter blob.
* \param param_blob A binary blob of parameter.
*/
int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob,
const uint32_t param_size);
// The graph attribute fields.
int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader);
/*! \brief Setup the temporal storage */
void (*SetupStorage)(struct TVMGraphRuntime * runtime);
/*! \brief Setup the executors. */
int (*SetupOpExecs)(struct TVMGraphRuntime * runtime);
/*!
* \brief Create an execution function given input.
* \param attrs The node attributes.
* \param args The arguments to the functor, including inputs and outputs.
* \param num_inputs Number of inputs.
* \return The created executor.
*/
int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs,
DLTensorPtr * args, const uint32_t args_count,
uint32_t num_inputs, TVMPackedFunc * pf);
// Get node entry index.
uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index);
// /*! \brief The graph nodes. */
/* GraphRuntimeNode nodes_[GRAPH_RUNTIME_MAX_NODES]; */
TVMGraphRuntimeNode nodes[GRAPH_RUNTIME_MAX_NODES];
uint32_t nodes_count;
/*! \brief The argument nodes. */
uint32_t input_nodes[GRAPH_RUNTIME_MAX_INPUT_NODES];
uint32_t input_nodes_count;
/*! \brief Used for quick entry indexing. */
uint32_t node_row_ptr[GRAPH_RUNTIME_MAX_NODE_ROW_PTR];
uint32_t node_row_ptr_count;
/*! \brief Output entries. */
TVMGraphRuntimeNodeEntry outputs[GRAPH_RUNTIME_MAX_OUTPUTS];
uint32_t outputs_count;
/*! \brief Additional graph attributes. */
TVMGraphRuntimeGraphAttr attrs;
/*! \brief The code module that contains both host and device code. */
TVMModule module;
/*! \brief Execution context of all devices including the host. */
TVMContext ctxs[GRAPH_RUNTIME_MAX_CONTEXTS];
uint32_t ctxs_count;
/*! \brief Common storage pool for all devices. */
TVMNDArray storage_pool[GRAPH_RUNTIME_MAX_NODES];
uint32_t storage_pool_count;
/*! \brief Data entry of each node. */
TVMNDArray data_entry[GRAPH_RUNTIME_MAX_NODES];
uint32_t data_entry_count;
/*! \brief Operator on each node. */
TVMPackedFunc op_execs[GRAPH_RUNTIME_MAX_NODES];
uint32_t op_execs_count;
} TVMGraphRuntime;
// public functions
TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m,
const TVMContext * ctxs);
void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime);
// private functions
void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob,
const uint32_t param_size);
void TVMGraphRuntime_Run(TVMGraphRuntime * runtime);
int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out);
#endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file load_json.h
* \brief Lightweight JSON Reader that read save into C++ data structs.
*/
#ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_
#define TVM_RUNTIME_CRT_LOAD_JSON_H_
#include <stdio.h>
#include <ctype.h>
enum {
JSON_READ_TYPE_U8 = 1,
JSON_READ_TYPE_S8 = 2,
JSON_READ_TYPE_U16 = 3,
JSON_READ_TYPE_S16 = 4,
JSON_READ_TYPE_U32 = 5,
JSON_READ_TYPE_S32 = 6,
JSON_READ_TYPE_F32 = 7,
JSON_READ_TYPE_F64 = 8,
JSON_READ_TYPE_GRAPH_RUNTIME_NODE = 9,
JSON_READ_TYPE_GRAPH_RUNTIME_NODE_ENTRY = 10,
JSON_READ_TYPE_GRAPH_RUNTIME_GRAPH_ATTR = 11
};
typedef struct Seq {
uint32_t * data;
uint64_t allocated;
uint32_t size;
void (*push_back)(struct Seq * seq, uint32_t src);
uint32_t * (*back)(struct Seq * seq);
void (*pop_back)(struct Seq * seq);
} Seq;
/*!
* \brief Lightweight JSON Reader to read any STL compositions and structs.
* The user need to know the schema of the
*/
typedef struct JSONReader {
/*! \brief internal reader string */
char * is_;
char * isptr;
/*! \brief "\\r" counter */
size_t line_count_r_;
/*! \brief "\\n" counter */
size_t line_count_n_;
/*!
* \brief record how many element processed in
* current array/object scope.
*/
Seq * scope_counter_;
char (*NextChar)(struct JSONReader * reader);
char (*NextNonSpace)(struct JSONReader * reader);
char (*PeekNextChar)(struct JSONReader * reader);
char (*PeekNextNonSpace)(struct JSONReader * reader);
int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value);
int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value);
int (*ReadString)(struct JSONReader * reader, char * out_value);
void (*BeginArray)(struct JSONReader * reader);
void (*BeginObject)(struct JSONReader * reader);
uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key);
uint8_t (*NextArrayItem)(struct JSONReader * reader);
} JSONReader;
/*!
* \brief Constructor of JSONReader class
* \param is the input source.
*/
JSONReader JSONReader_Create(const char * is);
void JSONReader_Release(JSONReader * reader);
#endif // TVM_RUNTIME_CRT_LOAD_JSON_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/crt/module.h
* \brief Runtime container of the functions
*/
#ifndef TVM_RUNTIME_CRT_MODULE_H_
#define TVM_RUNTIME_CRT_MODULE_H_
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
struct TVMPackedFunc;
typedef struct TVMPackedFunc TVMPackedFunc;
/*!
* \brief Module container of TVM.
*/
typedef struct TVMModule {
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param pf The result function.
*
* This function will return PackedFunc(nullptr) if function do not exist.
*/
void (*GetFunction)(const char * name, TVMPackedFunc * pf);
} TVMModule;
#endif // TVM_RUNTIME_CRT_MODULE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ndarray.c
* \brief NDArray container infratructure.
*/
#include "ndarray.h"
TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape,
DLDataType dtype, DLContext ctx) {
TVMNDArray ret;
memset(&ret, 0, sizeof(TVMNDArray));
ret.dl_tensor.ndim = ndim;
ret.dl_tensor.shape = (int64_t*)malloc(sizeof(int64_t)*ndim); // NOLINT(*)
memcpy(ret.dl_tensor.shape, shape, sizeof(int64_t)*ndim);
ret.dl_tensor.dtype = dtype;
ret.dl_tensor.ctx = ctx;
ret.dl_tensor.data = 0;
return ret;
}
TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape,
DLDataType dtype, DLContext ctx) {
TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (dtype.bits + 7) / 8;
uint32_t idx;
for (idx = 0; idx < ret.dl_tensor.ndim; ++idx) {
num_elems *= shape[idx];
}
ret.dl_tensor.data = TVMBackendAllocWorkspace(kDLCPU, 0, num_elems, dtype.code, dtype.bits);
memset(ret.dl_tensor.data, 0, num_elems * elem_bytes);
return ret;
}
int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) {
int32_t status = 0;
uint64_t header, reserved;
header = ((uint64_t*)*strm)[0]; *strm += sizeof(header); // NOLINT(*)
if (header != kTVMNDArrayMagic) {
fprintf(stderr, "Invalid DLTensor file format\n");
status = -1;
}
reserved = ((uint64_t*)*strm)[0]; *strm += sizeof(reserved); // NOLINT(*)
DLContext ctx;
uint32_t ndim;
DLDataType dtype;
ctx = ((DLContext*)*strm)[0]; *strm += sizeof(ctx); // NOLINT(*)
ndim = ((uint32_t*)*strm)[0]; *strm += sizeof(ndim); // NOLINT(*)
dtype = ((DLDataType*)*strm)[0]; *strm += sizeof(dtype); // NOLINT(*)
if ((ndim <= 0) || (ndim > TVM_CRT_MAX_NDIM)) {
fprintf(stderr, "Invalid ndim=%d: expected to be 1 ~ %d.\n", ndim, TVM_CRT_MAX_NDIM);
status = -1;
}
if (ctx.device_type != kDLCPU) {
fprintf(stderr, "Invalid DLTensor context: can only save as CPU tensor\n");
status = -1;
}
int64_t shape[TVM_CRT_MAX_NDIM];
uint32_t idx;
if (ndim != 0) {
for (idx = 0; idx < ndim; idx++) {
shape[idx] = ((int64_t*)*strm)[0]; *strm += sizeof(shape[idx]); // NOLINT(*)
}
}
*ret = TVMNDArray_Empty(ndim, shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dl_tensor.dtype.bits + 7) / 8;
for (idx = 0; idx < ret->dl_tensor.ndim; ++idx) {
num_elems *= ret->dl_tensor.shape[idx];
}
int64_t data_byte_size;
data_byte_size = ((int64_t*)*strm)[0]; *strm += sizeof(data_byte_size); // NOLINT(*)
if (!(data_byte_size == num_elems * elem_bytes)) {
fprintf(stderr, "invalid DLTensor file format: data_byte_size=%ld, "
"while num_elems*elem_bytes=%ld\n",
data_byte_size, (num_elems * elem_bytes));
status = -1;
}
memcpy(ret->dl_tensor.data, *strm, data_byte_size);
*strm += data_byte_size;
return status;
}
TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape,
uint32_t ndim, DLDataType dtype) {
TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.ctx);
ret.dl_tensor.data = arr->dl_tensor.data;
return ret;
}
int TVMNDArray_Release(TVMNDArray * arr) {
free(arr->dl_tensor.data);
free(arr->dl_tensor.shape);
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/crt/ndarray.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_CRT_NDARRAY_H_
#define TVM_RUNTIME_CRT_NDARRAY_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <dlpack/dlpack.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
/*! \brief Magic number for NDArray file */
static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
static const uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
typedef struct TVMNDArray {
DLTensor dl_tensor;
} TVMNDArray;
TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape,
DLDataType dtype, DLContext ctx);
TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape,
DLDataType dtype, DLContext ctx);
int TVMNDArray_Load(TVMNDArray * ret, const char ** strm);
TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape,
uint32_t ndim, DLDataType dtype);
int TVMNDArray_Release(TVMNDArray * arr);
#endif // TVM_RUNTIME_CRT_NDARRAY_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_
#define TVM_RUNTIME_CRT_PACKED_FUNC_H_
#include <tvm/runtime/c_runtime_api.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include "module.h"
static inline DLDataType String2DLDataType(const char * s) {
DLDataType t;
// handle None type
if (strlen(s) == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (!strncmp(s, "int", 3)) {
t.code = kDLInt; scan = s + 3;
} else if (!strncmp(s, "uint", 4)) {
t.code = kDLUInt; scan = s + 4;
} else if (!strncmp(s, "float", 5)) {
t.code = kDLFloat; scan = s + 5;
} else if (!strncmp(s, "handle", 6)) {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s + 6;
} else if (!strcmp(s, "bool")) {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else {
scan = s;
fprintf(stderr, "unknown type %s\n", s);
}
char* xdelim;
uint8_t bits = (uint8_t)(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = (uint16_t)(strtoul(xdelim + 1, &endpt, 10));
}
if (!(endpt == s + strlen(s))) {
fprintf(stderr, "unknown type %s\n", s);
}
return t;
}
typedef struct TVMArgs {
TVMValue values[TVM_CRT_MAX_ARGS];
int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */
uint32_t values_count;
} TVMArgs;
static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) {
uint32_t idx;
TVMArgs args;
memset(&args, 0, sizeof(args));
for (idx = 0; idx < values_count; idx++) {
memcpy(args.values + idx, values + idx, sizeof(TVMValue));
args.tcodes[idx] = tcodes[idx];
}
args.values_count = values_count;
return args;
}
static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args,
TVMRetValueHandle ret, void * res) {
return 0;
}
typedef struct TVMPackedFunc {
char name[200];
TVMPackedCFunc fexec;
TVMArgs args;
void (*Call)(struct TVMPackedFunc * pf);
void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args);
} TVMPackedFunc;
static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) {
pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0);
}
static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) {
memcpy(&(pf->args), args, sizeof(TVMArgs));
}
TVMPackedFunc g_fexecs[GRAPH_RUNTIME_MAX_NODES];
uint32_t g_fexecs_count = 0;
void TVMPackedFunc_SetupExecs();
// Implement TVMModule::GetFunction
// Put implementation in this file so we have seen the TVMPackedFunc
static inline void TVMModule_GetFunction(const char * name, TVMPackedFunc * pf) {
int idx;
memset(pf, 0, sizeof(TVMPackedFunc));
assert(strlen(name) <= sizeof(pf->name));
snprintf(pf->name, strlen(name), "%s", name);
pf->Call = TVMPackedFunc_Call;
pf->SetArgs = TVMPackedFunc_SetArgs;
pf->fexec = &TVMNoOperation;
for (idx = 0; idx < GRAPH_RUNTIME_MAX_NODES; idx++) {
if (!strcmp(g_fexecs[idx].name, name)) {
pf->fexec = g_fexecs[idx].fexec;
break;
}
}
if (idx == GRAPH_RUNTIME_MAX_NODES) {
fprintf(stderr, "function handle for %s not found\n", name);
}
}
#endif // TVM_RUNTIME_CRT_PACKED_FUNC_H_
......@@ -19,7 +19,7 @@
set -e
set -u
export PYTHONPATH=python:topi/python:apps/extension/python
export PYTHONPATH=`pwd`/python:`pwd`/topi/python:`pwd`/apps/extension/python
export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
export TVM_BIND_THREADS=0
export TVM_NUM_THREADS=2
......@@ -30,6 +30,12 @@ find . -type f -path "*.pyc" | xargs rm -f
# Test TVM
make cython3
# Test MISRA-C runtime
cd apps/bundle_deploy
rm -rf build
make test
cd ../..
# Test extern package
cd apps/extension
rm -rf lib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment