Commit 338bfd45 by Tianqi Chen, committed by GitHub

[CODEGEN] More robust llvm intrin handling, remove graph executor (#519)

parent 4468c576
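In short: calls lowered to LLVM intrinsics now carry an explicit signature count. A Call tagged `llvm_intrin` packs the LLVM intrinsic id in `args[0]`, the number of signature arguments in `args[1]`, and the real operands afterwards; only the types of the first `num_signature` operands are handed to `llvm::Intrinsic::getDeclaration` to resolve the overload. This also lets the old `llvm_builtin` path (used for `prefetch`) be folded into `llvm_intrin` with a signature count of 0. A rough sketch of the call layout, using a hypothetical operand `x` (Call/UIntImm live in the tvm::ir namespace):

  // args[0] = intrinsic id, args[1] = num_signature, args[2...] = operands
  // e.g. exp(x) with num_signature = 1 resolves to llvm.exp.<type-of-x>
  Expr call = Call::make(x.type(), "llvm_intrin",
                         {UIntImm::make(UInt(32), ::llvm::Intrinsic::exp),
                          UIntImm::make(UInt(32), 1), x},
                         Call::PureIntrinsic);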
@@ -3,9 +3,9 @@ This folder contains various extension projects using TVM,
they also serve as examples on how to use TVM in your own project.
If you are interested in writing optimized kernels with TVM, check out [TOPI: TVM Operator Inventory](../topi).
If you are interested in end-to-end deep learning model compilation, check out [NNVM Compiler](https://github.com/dmlc/nnvm).
- [extension](extension) How to extend the TVM C++ API along with the Python API.
- [graph_executor](graph_executor) Build nnvm graph executor with TVM.
- [ios_rpc](ios_rpc) iOS RPC server.
- [android_rpc](android_rpc) Android RPC server.
- [howto_deploy](howto_deploy) Tutorial on how to deploy TVM with minimum code dependency.
# Minimum Makefile for the graph executor package
TVM_ROOT=$(shell cd ../..; pwd)
NNVM_PATH=nnvm
DMLC_CORE=${TVM_ROOT}/dmlc-core
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
-I${TVM_ROOT}/dlpack/include\
-I${TVM_ROOT}/HalideIR/src
PKG_LDFLAGS =
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
PKG_LDFLAGS += -undefined dynamic_lookup
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
else
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif
NNVM_CONTRIB_SRC = $(wildcard src/*.cc)
NNVM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(NNVM_CONTRIB_SRC))
include $(DMLC_CORE)/make/dmlc.mk
ALL_DEP = $(NNVM_CONTRIB_OBJ)
PKG_CFLAGS += -I${NNVM_PATH}/include
ALL_DEP += ${DMLC_CORE}/libdmlc.a ${NNVM_PATH}/lib/libnnvm.a
.PHONY: clean all
all: lib/libtvm_graph_exec.so
nnvm:
git clone https://github.com/dmlc/nnvm --recursive
nnvm/lib/libnnvm.a: | nnvm
+ cd nnvm; make ; cd -
$(DMLC_CORE)/libdmlc.a:
+ cd $(DMLC_CORE); make libdmlc.a; cd $(TVM_ROOT)
build/%.o: src/%.cc | nnvm
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(PKG_CFLAGS) -c $< -o $@
lib/libtvm_graph_exec.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -shared -o $@ $(filter %.o, $^) $(PKG_LDFLAGS) \
-Wl,${WHOLE_ARCH} $(filter %.a, $^) -Wl,${NO_WHOLE_ARCH} $(PKG_LDFLAGS)
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
-include build/*.d
-include build/*/*.d
Example Graph Executor
======================
This folder contains a minimal example of a graph executor library based on TVM and NNVM.
It demonstrates how to build a compilation and execution framework for computation graphs.
- To build the library, run `make` in this folder; NNVM is cloned and built inside it automatically (see the usage sketch below).
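A minimal usage sketch (mirroring `tests/test_graph.py` in this folder; assumes the library has been built and both `tvm_graph` and `nnvm` are importable):

```python
import numpy as np
import tvm
import tvm_graph as tg

# declare a small elementwise graph
x = tg.Variable('x')
y = tg.Variable('y')
z = tg.exp(x + y)

# compile for LLVM and bind the fused graph to a CPU context
shape = (10, 128)
g = tg.build(z, "llvm", shape={'x': shape, 'y': shape})
m = tg.bind(g, tvm.cpu(0))

# feed inputs, run, and read back the result
set_input, run, get_output = m['set_input'], m['run'], m['get_output']
set_input('x', tvm.nd.array(np.ones(shape, dtype='float32')))
set_input('y', tvm.nd.array(np.ones(shape, dtype='float32')))
run()
out = tvm.nd.array(np.zeros(shape, dtype='float32'))
get_output(0, out)
```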
"""The graph build library"""
from __future__ import absolute_import as _abs
import tvm
from . import _base
from nnvm.symbol import *
from . import op_tvm_def
from .build import build, bind, save_params, compile_graph, remote_load_exec
from __future__ import absolute_import as _abs
import os
import sys
if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if hasattr(__builtin__, "NNVM_BASE_PATH"):
assert __builtin__.NNVM_BASE_PATH == curr_path
else:
__builtin__.NNVM_BASE_PATH = curr_path
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
assert __builtin__.NNVM_LIBRARY_NAME == "libtvm_graph_exec"
else:
__builtin__.NNVM_LIBRARY_NAME = "libtvm_graph_exec"
"""Logics related to build."""
import nnvm.graph as graph
import tvm
import json
DTYPE_DICT = {
"float32": 0
}
_create_exec = tvm.get_global_func("tvm_graph._create_executor")
def build(sym, target, shape, dtype="float32"):
# Do shape inference in python.
g = graph.create(sym)
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
list_shape = [[]] * jnode_row_ptr[-1]
list_dtype = [DTYPE_DICT[dtype]] * jnode_row_ptr[-1]
for k, v in shape.items():
list_shape[jnode_row_ptr[nindex[k]]] = v
g._set_json_attr("shape", list_shape, 'list_shape')
g._set_json_attr("dtype", list_dtype, 'list_int')
g._set_json_attr("target", target, 'str')
g = g.apply("InferShape").apply("InferType")
g = g.apply("GraphPartition").apply("GraphFuse")
return g
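# Illustrative usage (sketch, mirroring tests/test_graph.py): the returned graph
# carries shape/dtype/target attributes and the fused TVM module, and is then
# bound to a device with bind():
#   g = build(sym, "llvm", shape={'x': (10, 128)})
#   m = bind(g, tvm.cpu(0))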
def bind(g, ctx):
m = _create_exec(g.handle, ctx.device_type, ctx.device_id)
return m
_get_module = tvm.get_global_func("tvm_graph._get_module_from_graph")
def compile_graph(lib_fname, sym, target, shape, dtype="float32"):
g = build(sym, target, shape, dtype)
m = _get_module(g.handle)
m.save(lib_fname)
json_str = g.apply('SaveJSON').json_attr('json')
return json_str
@tvm.register_func("tvm_graph.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
@tvm.register_func("tvm_graph.build_target")
def _build(funcs, target):
return tvm.build(funcs, target=target)
_save_param_dict = tvm.get_global_func("tvm_graph._save_param_dict")
def save_params(fname, params):
args = []
args.append(fname)
args.append(len(params))
for kv in params.items():
args.append(kv[0])
args.append(kv[1])
_save_param_dict(*args)
def remote_load_exec(sess, sym_json, remote_module_name, param_blob, ctx):
"""Load a remote graph executor, with the local files.
Parameters
----------
sess : RPCSession
The remote RPC session to load the executor on.
sym_json : str
The symbol json string of the graph.
remote_module_name : str
The library path relative to the remote temp folder. The
library needs to be uploaded first.
param_blob : bytes or bytearray
The serialized parameters read from the local file.
ctx : TVMContext
The remote context to run the executor on.
Returns
-------
exec : GraphExecutor
The remote graph executor containing remote function.
"""
if "load_executor" not in sess._remote_funcs:
sess._remote_funcs["load_executor"] = sess.get_function("tvm_graph._load_executor")
assert ctx.device_type // tvm.contrib.rpc.RPC_SESS_MASK == sess._tbl_index + 1
device_type = ctx.device_type % tvm.contrib.rpc.RPC_SESS_MASK
return sess._remote_funcs["load_executor"](sym_json,
remote_module_name,
bytearray(param_blob),
device_type,
ctx.device_id)
"""NNVM operator definitions."""
import tvm
@tvm.register_func("tvm_graph.compute.add")
def compute_add(a, b):
return tvm.compute(a.shape, lambda *i: a(*i) + b(*i))
@tvm.register_func("tvm_graph.compute.exp")
def compute_exp(a):
return tvm.compute(a.shape, lambda *i: tvm.exp(a(*i)))
@tvm.register_func("tvm_graph.schedule.ewise")
def schedule_ewise(outs, target):
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineElemWise(s)
return s
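# Hypothetical extension (sketch, not part of this app): a new elementwise op
# would register its compute the same way, paired with an NNVM op registration
# and FTVMCompute/FTVMSchedule attributes on the C++ side:
#   @tvm.register_func("tvm_graph.compute.log")
#   def compute_log(a):
#       return tvm.compute(a.shape, lambda *i: tvm.log(a(*i)))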
/*!
* Copyright (c) 2017 by Contributors
* \file graph_executor.h
*/
#ifndef TVM_GRAPH_EXECUTOR_H_
#define TVM_GRAPH_EXECUTOR_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <numeric>
#include <string>
namespace tvm {
namespace contrib {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using nnvm::StorageVector;
using nnvm::ShapeVector;
using nnvm::TShape;
using nnvm::NodeAttrs;
/*! \brief DLPack compatible data types */
using DLTypeVector = std::vector<DLDataType>;
/*! \brief The executor function */
using FOpExec = std::function<void()>;
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) \
<< TVMGetLastError(); \
}
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief Graph Executor with TVM runtime */
class GraphExecutor : public runtime::ModuleNode {
public:
const char* type_key() const {
return "GraphExecutor";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
// Destructor
~GraphExecutor();
// Setup with a given graph
void Init(const nnvm::Graph& g, TVMContext ctx);
// Get index of variable
int GetIndex(std::string name);
// Copy data to index-th input
void SetInput(int index, DLTensor* data_in);
// Copy index-th output to data_out
void GetOutput(int index, DLTensor* data_out);
// Load parameters from stream
void LoadParams(dmlc::Stream* strm);
// Load parameters from binary file blob
void LoadParamsFromBlob(std::string param_blob);
// Execute the graph.
void Run();
private:
// functions
void SetupNameIndex();
void SetupStorage();
void SetupOpExecs();
// Constructor to create TVM op
FOpExec CreateTVMOp(const nnvm::NodeAttrs& attrs,
std::vector<DLTensor> inputs,
size_t num_inputs);
// The graph to be executed.
nnvm::Graph graph_;
// The execution context
TVMContext ctx_;
// Common storage pool
std::vector<DLTensor*> storage_pool_;
// The data shape
std::vector<TShape> data_shape_;
// The data entry
std::vector<DLTensor> data_entry_;
// The operation lambda on each node
std::vector<FOpExec> op_execs_;
// The code module.
tvm::runtime::Module module_;
std::unordered_map<std::string, size_t> name_idx_;
};
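// Illustrative C++ usage (sketch; mirrors how the Python tests drive the module
// through its packed functions):
//   std::shared_ptr<GraphExecutor> exec = std::make_shared<GraphExecutor>();
//   exec->Init(fused_graph, TVMContext{kDLCPU, 0});  // fused_graph: nnvm::Graph after GraphFuse
//   tvm::runtime::Module mod(exec);
//   mod.GetFunction("set_input")(0, input_tensor);   // DLTensor*
//   mod.GetFunction("run")();
//   mod.GetFunction("get_output")(0, output_tensor);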
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
bool flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs)
.set_default(1);
DMLC_DECLARE_FIELD(num_outputs)
.set_default(1);
DMLC_DECLARE_FIELD(flatten_data)
.set_default(false);
}
};
} // namespace contrib
} // namespace tvm
#endif // TVM_GRAPH_EXECUTOR_H_
/*!
* Copyright (c) 2017 by Contributors
* \file graph_executor_ext.cc
*/
#include "./graph_executor.h"
namespace tvm {
namespace contrib {
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header));
strm->Write(&reserved, sizeof(reserved));
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
int type_size = tensor->dtype.bits / 8;
int64_t size = 1;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
}
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
strm->Write(tensor->data, data_byte_size);
return true;
}
TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string fname = args[0];
int num_params = args[1];
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (int i = 2; i < (2 + 2*num_params); i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i+1].operator DLTensor*());
}
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header));
fo->Write(&reserved, sizeof(reserved));
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz));
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo.get(), arrays[i]);
}
}
});
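// On-disk layout written above (and read back by the executor's LoadParams):
//   uint64 kTVMNDArrayListMagic, uint64 reserved,
//   vector<string> names, uint64 tensor count,
//   then per tensor: uint64 kTVMNDArrayMagic, uint64 reserved,
//   ctx, ndim, dtype, shape[ndim], int64 data_byte_size, raw data bytes.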
// Create executor
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
std::shared_ptr<GraphExecutor> exec =
std::make_shared<GraphExecutor>();
exec->Init(g, ctx);
return tvm::runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
.set_body([](TVMArgs args, TVMRetValue *rv) {
void* graph_handle = args[0];
int device_type = args[1];
int device_id = args[2];
TVMContext ctx{static_cast<DLDeviceType>(device_type), device_id};
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
*rv = CreateExecutor(g, ctx);
});
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
void* graph_handle = args[0];
nnvm::Graph* g = static_cast<nnvm::Graph*>(graph_handle);
*rv = g->MoveCopyAttr<tvm::runtime::Module>("module");
});
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph_handle.cc
*/
#include <tvm/packed_func_ext.h>
#include "./graph_handle.h"
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphHandleNode>([](const GraphHandleNode *op, IRPrinter *p) {
p->stream << "graph-handle("
<< "handle=0x" << std::hex
<< reinterpret_cast<uint64_t>(op->graph_handle) << ")";
});
TVM_REGISTER_NODE_TYPE(GraphHandleNode);
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph.h
* \brief Data structure about computational graph.
*/
#ifndef TVM_GRAPH_HANDLE_H_
#define TVM_GRAPH_HANDLE_H_
#include <string>
#include <tvm/base.h>
namespace tvm {
/*!
* \brief Computational graph handle.
* Use GraphHandle as its container type
*/
struct GraphHandleNode : public Node {
void *graph_handle;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("graph_handle", &graph_handle);
}
static constexpr const char* _type_key = "GraphHandle";
TVM_DECLARE_NODE_TYPE_INFO(GraphHandleNode, Node);
};
/*! \brief Defines graph handle */
TVM_DEFINE_NODE_REF(GraphHandle, GraphHandleNode);
} // namespace tvm
#endif // TVM_GRAPH_HANDLE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Additional attribute types for operators used in TVM graph compilation.
*/
#ifndef TVM_OP_ATTR_TYPES_H_
#define TVM_OP_ATTR_TYPES_H_
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>
namespace tvm {
namespace contrib {
using runtime::PackedFunc;
using nnvm::StorageVector;
using nnvm::ShapeVector;
using nnvm::DTypeVector;
using nnvm::TShape;
using nnvm::NodeAttrs;
/*! \brief DLPack compatible data types */
using DLTypeVector = std::vector<DLDataType>;
/*!
* \brief Computation description interface
* \param attrs The attribute of the node.
* \param inputs The input tensors (placeholders).
* \return The output tensors.
*/
using FTVMCompute = std::function<
Array<Tensor>
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
/*!
* \brief Build the computation schedule for
* an op whose root is the current op.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = std::function<
Schedule(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target)>;
/*! \brief Layout Information. */
using TLayoutInfo = std::string;
/*!
* \brief The producer-consumer layout function of a node.
* \param attrs The attribute of the node.
* \param ilayouts The input layouts that the node requests.
* \param olayouts The output layouts that the node produces.
* \return bool The success flag.
*/
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
std::vector<TLayoutInfo> *ilayouts,
std::vector<TLayoutInfo> *olayouts)>;
/*! \brief The default layout. */
const TLayoutInfo& GetDefaultLayout();
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
/*! \brief Transform from normal operator to vectorized operator */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node*)>;
// The storage result of op
enum OpPatternKind : int {
// Elementwise operation
kElemWise,
// Broadcast operation
kBroadcast,
// Complex operation, can fuse bcast in input/outputs
// but cannot chain another complex op
kComplex,
// Extern operation, cannot fuse anything.
kExtern
};
using TOpPattern = int;
/*!
* \brief Get PackedFunction from global registry and
* report error if it does not exist
* \param name The name of the function.
* \return The created PackedFunc.
*/
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
/*!
* \brief Create a graph execution module from a given fused graph.
* \param g The graph to be executed; its "module" attribute holds the generated code module.
* \param ctx The context the executor runs on.
* \return The created executor module.
*/
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx);
} // namespace contrib
} // namespace tvm
#endif // TVM_OP_ATTR_TYPES_H_
/*!
* Copyright (c) 2017 by Contributors
* \file Operator Declarations.
*/
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include "./op_attr_types.h"
namespace tvm {
namespace contrib {
using namespace nnvm;
inline bool SameShape(const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
for (TShape& pshape : *oshape) {
pshape = (*ishape)[0];
}
for (TShape& pshape : *ishape) {
pshape = (*ishape)[0];
}
return true;
}
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(__add_symbol__)
.describe("add two data together")
.set_num_inputs(2)
.include("ElementwiseOpAttr");
NNVM_REGISTER_OP(exp)
.describe("Take exp")
.set_num_inputs(1)
.include("ElementwiseOpAttr");
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file Operator definitions in TVM.
*/
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include "./op_attr_types.h"
namespace tvm {
namespace contrib {
using namespace nnvm;
Array<Tensor>
ComputeAdd(const NodeAttrs& attrs,
const Array<Tensor>& inputs) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.compute.add");
CHECK_EQ(inputs.size(), 2U);
Tensor ret = pf(inputs[0], inputs[1]);
return {ret};
}
Array<Tensor>
ComputeExp(const NodeAttrs& attrs,
const Array<Tensor>& inputs) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.compute.exp");
CHECK_EQ(inputs.size(), 1U);
Tensor ret = pf(inputs[0]);
return {ret};
}
Schedule ScheduleEWise(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target) {
static const PackedFunc& pf = GetPackedFunc("tvm_graph.schedule.ewise");
return pf(outs, target);
}
NNVM_REGISTER_OP(__add_symbol__)
.set_attr<FTVMCompute>("FTVMCompute", ComputeAdd)
.set_attr<FTVMSchedule>("FTVMSchedule", ScheduleEWise);
NNVM_REGISTER_OP(exp)
.set_attr<FTVMCompute>("FTVMCompute", ComputeExp)
.set_attr<FTVMSchedule>("FTVMSchedule", ScheduleEWise);
} // namespace contrib
} // namespace tvm
import tvm_graph as tg
import numpy as np
import tvm
def test_compile():
x = tg.Variable('x')
y = tg.Variable('y')
z = tg.exp(y + x)
shape = (10, 128)
dtype = tvm.float32
g = tg.build(z, "llvm",
shape={'x': shape,
'y': shape})
m = tg.bind(g, tvm.cpu(0))
# get member functions
set_input, run, get_output = m['set_input'], m['run'], m['get_output']
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
# set inputs
set_input('x', na)
set_input('y', nb)
# execute
run()
# get outputs
out = tvm.nd.array(np.zeros(shape).astype(dtype))
get_output(0, out)
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
if __name__ == "__main__":
test_compile()
import tvm
from tvm.contrib import util, rpc
import tvm_graph as tg
import numpy as np
import os
def test_rpc_executor():
host = 'localhost'
port = 9091
server = rpc.Server(host, port)
tmp = util.tempdir()
sym_fname = tmp.relpath('net.json')
lib_fname = tmp.relpath('net.o')
param_fname = tmp.relpath('net.param')
x = tg.Variable('x')
y = tg.Variable('y')
sym = tg.exp(y + x) + tg.exp(x + y)
shape = (10, 128)
dtype = tvm.float32
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
tg.save_params(param_fname, {'x': na, 'y': nb})
remote = rpc.connect(host, port)
ctx = remote.cpu(0)
target = "llvm"
shapes = {'x': shape, 'y': shape}
sym_json = tg.compile_graph(lib_fname, sym, target, shapes)
remote.upload(lib_fname)
param_blob = bytearray(open(param_fname, "rb").read())
rm = tg.remote_load_exec(remote,
sym_json,
os.path.basename(lib_fname),
param_blob,
ctx)
run, get_output = rm['run'], rm['get_output']
nc = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
run()
get_output(0, nc)
npa = na.asnumpy()
npb = nb.asnumpy()
np.testing.assert_allclose(nc.asnumpy(),
np.exp(npa + npb) + np.exp(npb + npa))
server.terminate()
if __name__ == "__main__":
test_rpc_executor()
import tvm_graph as tg
import numpy as np
import tvm
def test_save_load():
shape = (10, 128)
dtype = tvm.float32
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
x = tg.Variable('x')
y = tg.Variable('y')
z = tg.exp(y + x)
g = tg.build(z, "llvm", shape={'x': shape, 'y': shape})
m0 = tg.bind(g, tvm.cpu(0))
set_input0, run0, get_output0 = m0['set_input'], m0['run'], m0['get_output']
set_input0(0, na)
set_input0(1, nb)
run0()
out0 = tvm.nd.array(np.zeros(shape).astype(dtype))
get_output0(0, out0)
tg.save_params('test.params', {'x': na, 'y': nb})
# create another executor
m1 = tg.bind(g, tvm.cpu(0))
load_params1 = m1['load_params']
load_params1(bytearray(open('test.params', 'rb').read()))
run1, get_output1 = m1['run'], m1['get_output']
run1()
out1 = tvm.nd.array(np.zeros(shape).astype(dtype))
get_output1(0, out1)
np.testing.assert_allclose(out0.asnumpy(), out1.asnumpy())
if __name__ == "__main__":
test_save_load()
@@ -525,27 +525,20 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
   if (op->is_intrinsic("llvm_intrin")) {
-    CHECK_GE(op->args.size(), 1U);
+    CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
         op->args[0].as<UIntImm>()->value);
+    uint64_t num_signature = op->args[1].as<UIntImm>()->value;
     std::vector<llvm::Value*> arg_value;
-    std::vector<llvm::Type*> arg_type;
-    for (size_t i = 1; i < op->args.size(); ++i) {
+    std::vector<llvm::Type*> sig_type;
+    for (size_t i = 2; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
-      arg_type.push_back(arg_value.back()->getType());
+      if (i - 2 < num_signature) {
+        sig_type.push_back(arg_value.back()->getType());
+      }
     }
     llvm::Function* f = llvm::Intrinsic::getDeclaration(
-        module_.get(), id, arg_type);
-    return builder_->CreateCall(f, arg_value);
-  } else if (op->is_intrinsic("llvm_builtin")) {
-    CHECK_GE(op->args.size(), 1U);
-    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
-        op->args[0].as<UIntImm>()->value);
-    std::vector<llvm::Value*> arg_value;
-    for (size_t i = 1; i < op->args.size(); ++i) {
-      arg_value.push_back(MakeValue(op->args[i]));
-    }
-    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
+        module_.get(), id, sig_type);
     return builder_->CreateCall(f, arg_value);
   } else if (op->is_intrinsic(Call::bitwise_and)) {
     return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
......
@@ -16,25 +16,8 @@ namespace llvm {
 using namespace ir;
-template<unsigned id>
-inline void DispatchLLVMBuildin(const TVMArgs& targs, TVMRetValue* rv) {
-  Expr e = targs[0];
-  const Call* call = e.as<Call>();
-  CHECK(call != nullptr);
-  Array<Expr> cargs;
-  // intrin id.
-  cargs.push_back(UIntImm::make(UInt(32), id));
-  for (Expr arg : call->args) {
-    cargs.push_back(arg);
-  }
-  *rv = Call::make(
-      call->type, "llvm_builtin", cargs, Call::Intrinsic);
-}
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
-.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>);
-template<unsigned id>
+// num_signature means number of arguments used to query signature
+template<unsigned id, int num_signature>
 inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
   const Call* call = e.as<Call>();
@@ -42,6 +25,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Array<Expr> cargs;
   // intrin id.
   cargs.push_back(UIntImm::make(UInt(32), id));
+  cargs.push_back(UIntImm::make(UInt(32), num_signature));
   for (Expr arg : call->args) {
     cargs.push_back(arg);
   }
@@ -49,7 +34,7 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
       call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
 }
-template<unsigned id>
+template<unsigned id, int num_signature>
 inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
   const Call* call = e.as<Call>();
@@ -57,6 +42,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Array<Expr> cargs;
   // intrin id.
   cargs.push_back(UIntImm::make(UInt(32), id));
+  cargs.push_back(UIntImm::make(UInt(32), num_signature));
   for (Expr arg : call->args) {
     cargs.push_back(arg);
   }
@@ -64,20 +50,23 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
       call->type, "llvm_intrin", cargs, Call::Intrinsic);
 }
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
+.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
 } // namespace llvm
 } // namespace codegen
......
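With the signature count made explicit in the dispatch helpers, wiring up another overloaded LLVM intrinsic becomes a two-line registration. A hypothetical example (not part of this commit) for fabs:

  TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
  .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);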
 #!/bin/bash
 export PYTHONPATH=python:apps/extension/python
-export PYTHONPATH=${PYTHONPATH}:apps/graph_executor/python:apps/graph_executor/nnvm/python
 export LD_LIBRARY_PATH=lib:${LD_LIBRARY_PATH}
 rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
@@ -14,12 +13,6 @@ make || exit -1
 cd ../..
 python -m nose -v apps/extension/tests || exit -1
-# Test NNVM integration
-cd apps/graph_executor
-make || exit -1
-cd ../..
-python -m nose -v apps/graph_executor/tests || exit -1
 TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
 TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1
 TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1
......