Commit 204ad63b by Tianqi Chen Committed by GitHub

[NNVM] Example NNVM integration. (#182)

parent b1402b37
@@ -97,3 +97,4 @@ build_*
Win32
*.dir
perf
nnvm
@@ -154,7 +154,8 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
cpplint:
python dmlc-core/scripts/lint.py tvm cpp include src verilog examples/extension/src
python dmlc-core/scripts/lint.py tvm cpp include src verilog\
examples/extension/src examples/graph_executor/src
pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
@@ -7,7 +7,12 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/HalideIR/src
PKG_LDFLAGS =-L${TVM_ROOT}/lib
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
PKG_LDFLAGS += -undefined dynamic_lookup
endif
lib/libtvm_ext.so: src/tvm_ext.cc
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -shared -o $@ $^ $(PKG_LDFLAGS) -ltvm
$(CXX) $(PKG_CFLAGS) -shared -o $@ $^ $(PKG_LDFLAGS)
@@ -3,5 +3,5 @@ Example Extension Library
This folder contains an example extension library of TVM.
It demonstrates how other libraries can extend TVM through both the C++ and Python APIs.
- The library extends TVM's functionality by linking libtvm
- The library extends TVM's functionality.
- The Python module loads the new shared library and can interoperate with TVM's Python API.
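A minimal usage sketch of the Python side is shown below. This is illustrative only: the C++ sources that register the `tvm_ext.*` globals are not part of this diff, so the exact semantics of `sym_add` are assumed.

```python
import tvm
import tvm_ext  # importing the module loads libtvm_ext.so and fetches the globals

# sym_add is retrieved via tvm.get_global_func("tvm_ext.sym_add");
# here it is assumed to build the symbolic sum of two variables.
a = tvm.var("a")
b = tvm.var("b")
c = tvm_ext.sym_add(a, b)
```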
@@ -2,17 +2,17 @@
from __future__ import absolute_import
import os
import ctypes
# Import TVM first to get library symbols
import tvm
def load_lib():
"""Load library, the functions will be registered into TVM"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
lib = ctypes.CDLL(os.path.join(curr_path, "../lib/libtvm_ext.so"),
ctypes.RTLD_GLOBAL)
lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so"))
return lib
_LIB = load_lib()
import tvm
# Expose two functions into python
bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add")
# Minimum Makefile for the graph executor package
TVM_ROOT=$(shell cd ../..; pwd)
NNVM_PATH=nnvm
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${TVM_ROOT}/dmlc-core/include\
-I${TVM_ROOT}/dlpack/include\
-I${TVM_ROOT}/HalideIR/src
PKG_LDFLAGS =
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
PKG_LDFLAGS += -undefined dynamic_lookup
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
else
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif
NNVM_CONTRIB_SRC = $(wildcard src/*.cc)
NNVM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(NNVM_CONTRIB_SRC))
ALL_DEP = $(NNVM_CONTRIB_OBJ)
PKG_CFLAGS += -I${NNVM_PATH}/include
ALL_DEP += ${NNVM_PATH}/lib/libnnvm.a
.PHONY: clean all
all: lib/libtvm_graph_exec.so
nnvm:
git clone https://github.com/dmlc/nnvm --recursive
nnvm/lib/libnnvm.a: | nnvm
+ cd nnvm; make ; cd -
build/%.o: src/%.cc | nnvm
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(PKG_CFLAGS) -c $< -o $@
lib/libtvm_graph_exec.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -shared -o $@ $(filter %.o, $^) $(PKG_LDFLAGS) \
-Wl,${WHOLE_ARCH} $(filter %.a, $^) -Wl,${NO_WHOLE_ARCH} $(PKG_LDFLAGS)
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
-include build/*.d
-include build/*/*.d
Example Graph Executor
======================
This folder contains a minimal example of a graph executor library built on top of TVM and NNVM.
It demonstrates how to build a computation-graph compilation and execution framework.
- To build the library, run `make` in this folder; the Makefile clones NNVM into this folder and builds it automatically.
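The intended end-to-end flow, condensed from the test at the bottom of this diff:

```python
import numpy as np
import tvm
import tvm_graph as tg

x = tg.Variable('x')
y = tg.Variable('y')
z = tg.exp(y + x)

shape = (10, 128)
# Compile: annotate shapes, then infer, partition, and fuse the graph.
g = tg.build(z, "llvm", shape={'x': shape, 'y': shape})
# Bind: create a GraphExecutor module on the CPU context.
m = tg.bind(g, tvm.cpu(0))

set_input, run, get_output = m['set_input'], m['run'], m['get_output']
set_input(0, tvm.nd.array(np.ones(shape, dtype='float32')))
set_input(1, tvm.nd.array(np.ones(shape, dtype='float32')))
run()
out = tvm.nd.array(np.zeros(shape, dtype='float32'))
get_output(0, out)
```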
"""The graph build library"""
from __future__ import absolute_import as _abs
import tvm
from . import _base
from nnvm.symbol import *
from . import op_tvm_def
from .build import build, bind
from __future__ import absolute_import as _abs
import os
import sys
if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if hasattr(__builtin__, "NNVM_BASE_PATH"):
assert __builtin__.NNVM_BASE_PATH == curr_path
else:
__builtin__.NNVM_BASE_PATH = curr_path
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
assert __builtin__.NNVM_LIBRARY_NAME == "libtvm_graph_exec"
else:
__builtin__.NNVM_LIBRARY_NAME = "libtvm_graph_exec"
"""Logics related to build."""
import nnvm.graph as graph
import tvm
import json
DTYPE_DICT = {
"float32": 0
}
_create_exec = tvm.get_global_func("tvm_graph._create_executor")
def build(sym, target, shape, dtype="float32"):
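"""Build a fused graph from an NNVM symbol.

Annotates every node entry with shape/dtype and records the target, then
runs the InferShape, InferType, GraphPartition, and GraphFuse passes.
The returned graph can be bound to a device with bind().
"""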
# Do shape inference in python.
g = graph.create(sym)
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
list_shape = [[]] * jnode_row_ptr[-1]
list_dtype = [DTYPE_DICT[dtype]] * jnode_row_ptr[-1]
for k, v in shape.items():
list_shape[jnode_row_ptr[nindex[k]]] = v
g._set_json_attr("shape", list_shape, 'list_shape')
g._set_json_attr("dtype", list_dtype, 'list_int')
g._set_json_attr("target", target, 'str')
g = g.apply("InferShape").apply("InferType")
g = g.apply("GraphPartition").apply("GraphFuse")
return g
def bind(g, ctx):
m = _create_exec(g.handle, ctx)
return m
@tvm.register_func("tvm_graph.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance(
f, (tvm.collections.Array, tuple, list)) else [f]
@tvm.register_func("tvm_graph.build_target")
def _build(funcs, target):
return tvm.build(funcs, target)
"""NNVM operator definitions."""
import tvm
@tvm.register_func("tvm_graph.compute.add")
def compute_add(a, b):
return tvm.compute(a.shape, lambda *i: a(*i) + b(*i))
@tvm.register_func("tvm_graph.compute.exp")
def compute_exp(a):
return tvm.compute(a.shape, lambda *i: tvm.exp(a(*i)))
@tvm.register_func("tvm_graph.schedule.ewise")
def schedule_ewise(outs, target):
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineElemWise(s)
return s
/*!
* Copyright (c) 2017 by Contributors
* \brief NNVM graph executor.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <numeric>
namespace tvm {
namespace contrib {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using nnvm::StorageVector;
using nnvm::ShapeVector;
using nnvm::TShape;
using nnvm::NodeAttrs;
/*! \brief DLPack compatible data types */
using DLTypeVector = std::vector<DLDataType>;
/*! \brief The executor function */
using FOpExec = std::function<void()>;
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) \
<< TVMGetLastError(); \
}
/*! \brief Graph Executor with TVM runtime */
class GraphExecutor final : public runtime::ModuleNode {
public:
const char* type_key() const final {
return "GraphExecutor";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
// Destructor
~GraphExecutor();
// Setup with a given graph
void Init(const nnvm::Graph& g, TVMContext ctx);
// Copy data to index-th input
void SetInput(int index, DLTensor* data_in);
// Copy index-th output to data_out
void GetOutput(int index, DLTensor* data_out);
// Execute the graph.
void Run();
private:
// functions
void SetupStorage();
void SetupOpExecs();
// Create the execution function for a TVM op node
FOpExec CreateTVMOp(const nnvm::NodeAttrs& attrs,
std::vector<DLTensor> args,
size_t num_inputs);
// The graph to be executed.
nnvm::Graph graph_;
// The execution context
TVMContext ctx_;
// Common storage pool
std::vector<DLTensor*> storage_pool_;
// The data entry
std::vector<DLTensor> data_entry_;
// The operation lambda on each node
std::vector<FOpExec> op_execs_;
// The code module.
tvm::runtime::Module module_;
};
PackedFunc GraphExecutor::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
// return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->SetInput(args[0], args[1]);
});
} else if (name == "get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->GetOutput(args[0], args[1]);
});
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Run();
});
} else {
return PackedFunc();
}
}
GraphExecutor::~GraphExecutor() {
for (DLTensor* t : storage_pool_) {
TVM_CCALL(TVMArrayFree(t));
}
}
void GraphExecutor::Run() {
// setup the array and requirements.
for (size_t i = 0; i < op_execs_.size(); ++i) {
if (op_execs_[i]) op_execs_[i]();
}
}
void GraphExecutor::Init(const nnvm::Graph& g, TVMContext ctx) {
graph_ = g;
ctx_ = ctx;
module_ = g.GetAttr<tvm::runtime::Module>("module");
this->SetupStorage();
this->SetupOpExecs();
}
void GraphExecutor::SetInput(int index, DLTensor* data_in) {
const auto& idx = graph_.indexed_graph();
CHECK_LT(static_cast<size_t>(index), idx.input_nodes().size());
uint32_t eid = idx.entry_id(idx.input_nodes()[index], 0);
TVM_CCALL(TVMArrayCopyFromTo(data_in, &data_entry_[eid], nullptr));
}
void GraphExecutor::GetOutput(int index, DLTensor* data_out) {
const auto& idx = graph_.indexed_graph();
CHECK_LT(static_cast<size_t>(index), idx.outputs().size());
uint32_t eid = idx.entry_id(idx.outputs()[index]);
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
}
void GraphExecutor::SetupStorage() {
const auto& idx = graph_.indexed_graph();
// Grab saved optimization plan from graph.
auto vstorage = graph_.MoveCopyAttr<StorageVector>("storage_id");
const auto& vshape = graph_.GetAttr<ShapeVector>("shape");
const auto& vtype = graph_.GetAttr<DLTypeVector>("dltype");
data_entry_.resize(idx.num_node_entries());
// Find the next unused storage id so that inputs get separate storage.
int max_id = 0;
for (size_t i = 0; i < vshape.size(); ++i) {
max_id = std::max(vstorage[i] + 1, max_id);
}
for (const auto& e : idx.input_nodes()) {
vstorage[idx.entry_id(e, 0)] = max_id++;
}
// size of each storage pool entry
std::vector<size_t> pool_entry_bytes;
// Find the maximum space size.
for (size_t i = 0; i < vshape.size(); ++i) {
int storage_id = vstorage[i];
size_t size = vshape[i].Size();
CHECK_GE(storage_id, 0) << "Runtime shape op is not supported";
DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes;
CHECK_EQ(bits % 8U, 0U);
size_t bytes = (bits / 8U) * size;
size_t sid = static_cast<size_t>(storage_id);
if (sid >= pool_entry_bytes.size()) {
pool_entry_bytes.resize(sid + 1, 0);
}
pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
}
// Allocate each pool entry as a 1-D float32 array, rounding bytes up to 4-byte words.
for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {
TShape shape{static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
DLTensor* tensor;
TVM_CCALL(TVMArrayAlloc(
shape.data(), 1, DLDataType{kFloat, 32U, 1U}, ctx_, &tensor));
storage_pool_.push_back(tensor);
}
// Assign the pooled entries.
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = vstorage[i];
data_entry_[i] = *storage_pool_[storage_id];
data_entry_[i].shape = const_cast<int64_t*>(vshape[i].data());
data_entry_[i].ndim = vshape[i].ndim();
data_entry_[i].dtype = vtype[i];
}
}
void GraphExecutor::SetupOpExecs() {
static const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
const auto& idx = graph_.indexed_graph();
op_execs_.resize(idx.num_nodes());
// setup the array and requirements.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
std::vector<DLTensor> args;
for (const auto& e : inode.inputs) {
args.push_back(data_entry_[idx.entry_id(e)]);
}
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
args.push_back(data_entry_[eid]);
}
CHECK_EQ(inode.source->op(), tvm_op)
<< "transform the graph to tvm op";
op_execs_[nid] = CreateTVMOp(
inode.source->attrs, args, inode.inputs.size());
}
}
FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs,
std::vector<DLTensor> args,
size_t num_inputs) {
struct OpArgs {
std::vector<DLTensor> args;
std::vector<TVMValue> arg_values;
std::vector<int> arg_tcodes;
std::vector<int64_t> shape_data;
};
auto it = attrs.dict.find("func_name");
CHECK(it != attrs.dict.end())
<< "tvm_op must need func_name attr";
bool flatten = (attrs.dict.at("flatten_data") == "1");
std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
// setup address.
arg_ptr->args = std::move(args);
if (flatten) {
arg_ptr->shape_data.resize(arg_ptr->args.size());
}
for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
TVMValue v;
DLTensor* t = &(arg_ptr->args[i]);
v.v_handle = t;
arg_ptr->arg_values.push_back(v);
arg_ptr->arg_tcodes.push_back(kArrayHandle);
if (flatten) {
arg_ptr->shape_data[i] = std::accumulate(
t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
t->ndim = 1;
t->shape = &(arg_ptr->shape_data[i]);
}
}
// get compiled function from module.
runtime::PackedFunc pf = module_.GetFunction(it->second, false);
auto fexec = [arg_ptr, pf] () {
runtime::TVMRetValue rv;
runtime::TVMArgs targs(arg_ptr->arg_values.data(),
arg_ptr->arg_tcodes.data(),
static_cast<int>(arg_ptr->arg_values.size()));
pf.CallPacked(targs, &rv);
};
return fexec;
}
// Create executor
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
std::shared_ptr<GraphExecutor> exec =
std::make_shared<GraphExecutor>();
exec->Init(g, ctx);
return tvm::runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
.set_body([](TVMArgs args, TVMRetValue *rv) {
void* graph_handle = args[0];
TVMContext ctx = args[1];
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
*rv = CreateExecutor(g, ctx);
});
// ewise tvm op
NNVM_REGISTER_OP(tvm_op)
.set_num_inputs(-1);
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Extra operator attribute types used by the TVM graph executor.
*/
#ifndef TVM_OP_ATTR_TYPES_H_
#define TVM_OP_ATTR_TYPES_H_
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>
namespace tvm {
namespace contrib {
using runtime::PackedFunc;
using nnvm::StorageVector;
using nnvm::ShapeVector;
using nnvm::DTypeVector;
using nnvm::TShape;
using nnvm::NodeAttrs;
/*! \brief DLPack compatible data types */
using DLTypeVector = std::vector<DLDataType>;
/*!
* \brief Computation description interface
* \param attrs The attribute of the node.
* \param inputs The input tensors (placeholders).
* \return The output tensors.
*/
using FTVMCompute = std::function<
Array<Tensor>
(const NodeAttrs& attrs,
const Array<Tensor>& inputs)>;
/*!
* \brief Build the computation schedule for
* the op whose root is the current op.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = std::function<
Schedule(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target)>;
// The storage result of op
enum OpPatternKind : int {
// Elementwise operation
kElemWise,
// Broadcast operation
kBroadcast,
// Complex operation, can fuse bcast in input/outputs
// but cannot chain another complex op
kComplex,
// Extern operation, cannot fuse anything.
kExtern
};
using TOpPattern = int;
/*!
* \brief Get PackedFunction from global registry and
* report error if it does not exist
* \param name The name of the function.
* \return The created PackedFunc.
*/
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
/*!
* \brief Create a graph execution module from a given graph.
* \param g The graph to be executed.
* \param ctx The device context on which to execute.
* \return The created executor module.
*/
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx);
} // namespace contrib
} // namespace tvm
#endif // TVM_OP_ATTR_TYPES_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Operator declarations.
*/
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include "./op_attr_types.h"
namespace tvm {
namespace contrib {
using namespace nnvm;
inline bool SameShape(const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
for (TShape& pshape : *oshape) {
pshape = (*ishape)[0];
}
for (TShape& pshape : *ishape) {
pshape = (*ishape)[0];
}
return true;
}
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(__add_symbol__)
.describe("add two data together")
.set_num_inputs(2)
.include("ElementwiseOpAttr");
NNVM_REGISTER_OP(exp)
.describe("Take exp")
.set_num_inputs(1)
.include("ElementwiseOpAttr");
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Operator definitions in TVM.
*/
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include "./op_attr_types.h"
namespace tvm {
namespace contrib {
using namespace nnvm;
Array<Tensor>
ComputeAdd(const NodeAttrs& attrs,
const Array<Tensor>& inputs) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.compute.add");
CHECK_EQ(inputs.size(), 2U);
Tensor ret = pf(inputs[0], inputs[1]);
return {ret};
}
Array<Tensor>
ComputeExp(const NodeAttrs& attrs,
const Array<Tensor>& inputs) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.compute.exp");
CHECK_EQ(inputs.size(), 1U);
Tensor ret = pf(inputs[0]);
return {ret};
}
Schedule ScheduleEWise(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.schedule.ewise");
return pf(outs, target);
}
NNVM_REGISTER_OP(__add_symbol__)
.set_attr<FTVMCompute>("FTVMCompute", ComputeAdd)
.set_attr<FTVMSchedule>("FTVMSchedule", ScheduleEWise);
NNVM_REGISTER_OP(exp)
.set_attr<FTVMCompute>("FTVMCompute", ComputeExp)
.set_attr<FTVMSchedule>("FTVMSchedule", ScheduleEWise);
} // namespace contrib
} // namespace tvm
import tvm_graph as tg
import numpy as np
import tvm
def test_compile():
x = tg.Variable('x')
y = tg.Variable('y')
z = tg.exp(y + x)
shape = (10, 128)
dtype = tvm.float32
g = tg.build(z, "llvm",
shape={'x': shape,
'y': shape})
m = tg.bind(g, tvm.cpu(0))
# get member functions
set_input, run, get_output = m['set_input'], m['run'], m['get_output']
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
# set inputs
set_input(0, na)
set_input(1, nb)
# execute
run()
# get outputs
out = tvm.nd.array(np.zeros(shape).astype(dtype))
get_output(0, out)
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
if __name__ == "__main__":
test_compile()
@@ -133,6 +133,9 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
@@ -16,7 +16,7 @@ cdef int tvm_callback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* fhandle):
void* fhandle) with gil:
cdef list pyargs
cdef TVMValue value
cdef int tcode
@@ -140,6 +140,9 @@ cdef inline void make_arg(object arg,
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle
#!/bin/bash
export PYTHONPATH=python:examples/extension
export PYTHONPATH=python:examples/extension/python
export PYTHONPATH=${PYTHONPATH}:examples/graph_executor/python:examples/graph_executor/nnvm/python
export LD_LIBRARY_PATH=lib:${LD_LIBRARY_PATH}
# Test TVM
make cython || exit -1
# Test extension package
cd examples/extension
make || exit -1
cd ../..
python -m nose -v examples/extension/tests || exit -1
# Test TVM
make cython || exit -1
# Test NNVM integration
cd examples/graph_executor
make || exit -1
cd ../..
python -m nose -v examples/graph_executor/tests || exit -1
TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1