Commit 145b3d0f by Tianqi Chen, committed by GitHub

[RUNTIME] Minimum graph runtime (#484)

* [RUNTIME] Minimum graph runtime

* update docs
parent 65f87264
......@@ -23,7 +23,8 @@ endif()
tvm_option(USE_CUDA "Build with CUDA" OFF)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_METAL "Build with Metal" OFF)
tvm_option(USE_RPC "Build with RPC" OFF)
tvm_option(USE_RPC "Build with RPC" ON)
tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" ON)
tvm_option(USE_MSVC_MT "Build with MT" OFF)
......
......@@ -90,6 +90,7 @@ stage('Build') {
echo USE_OPENCL=1 >> config.mk
echo LLVM_CONFIG=llvm-config-4.0 >> config.mk
echo USE_RPC=1 >> config.mk
echo USE_GRAPH_RUNTIME=1 >> config.mk
echo USE_BLAS=openblas >> config.mk
rm -f lib/libtvm_runtime.so lib/libtvm.so
"""
......
......@@ -53,6 +53,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
# Objectives
......@@ -63,6 +64,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
CONTRIB_OBJ =
......@@ -124,6 +126,10 @@ ifeq ($(USE_RPC), 1)
RUNTIME_DEP += $(RPC_OBJ)
endif
ifeq ($(USE_GRAPH_RUNTIME), 1)
RUNTIME_DEP += $(GRAPH_OBJ)
endif
include make/contrib/cblas.mk
include make/contrib/nnpack.mk
include make/contrib/cudnn.mk
......
......@@ -22,6 +22,11 @@ tvm.contrib.rpc
.. automodule:: tvm.contrib.rpc
:members:
tvm.contrib.graph_runtime
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.graph_runtime
:members:
tvm.contrib.util
~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.util
......
......@@ -45,7 +45,10 @@ USE_OPENCL = 0
USE_METAL = 0
# Whether to enable RPC during compile
USE_RPC = 0
USE_RPC = 1
# Whether to enable the tiny embedded graph runtime.
USE_GRAPH_RUNTIME = 1
# Whether to build with LLVM support
# Requires LLVM version >= 4.0
......
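With USE_RPC and USE_GRAPH_RUNTIME switched on above, graph_runtime.cc below registers a creation function in the global function table. As a quick sanity check that the runtime was compiled in (a sketch; depending on the FFI version a missing function may raise rather than return None):

from tvm._ffi.function import get_global_func

# If the library was built with USE_GRAPH_RUNTIME=1, the creation
# function registered by graph_runtime.cc should be discoverable.
fcreate = get_global_func("tvm.graph_runtime.create")
print("graph runtime available:", fcreate is not None)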
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
from . import rpc
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .. import ndarray as nd
def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
----------
graph_json_str : str or graph class
The graph to be deployed, in JSON format as output by nnvm.
The graph can only contain the single operator type tvm_op,
which refers by name to a PackedFunc in libmod.
libmod : tvm.Module
The module of the corresponding function
ctx : TVMContext
The context to deploy the module, can be local or remote.
Returns
-------
graph_module : GraphModule
Runtime graph module that can be used to execute the graph.
"""
if not isinstance(graph_json_str, string_types):
try:
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
device_type = ctx.device_type
device_id = ctx.device_id
if device_type >= rpc.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
hmod = rpc._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = device_type % rpc.RPC_SESS_MASK
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
class GraphModule(object):
"""Wrapper runtime module.
This is a thin wrapper of the underlying TVM module.
you can also directly call set_input, run, and get_output
of underlying module functions
Parameters
----------
module : Module
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
Attributes
----------
module : Module
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
"""
def __init__(self, module, ctx):
self.module = module
self._set_input = module["set_input"]
self._run = module["run"]
self._get_output = module["get_output"]
self.ctx = ctx
def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
Parameters
----------
key : int or str
The input key
value : array-like
The input value.
params : dict of str to array-like
Additional input parameters.
"""
if key is not None:
self._set_input(key, nd.array(value, ctx=self.ctx))
for k, v in params.items():
self._set_input(k, nd.array(v, ctx=self.ctx))
return self
def run(self, **input_dict):
"""Run forward execution of the graph
Parameters
----------
input_dict : dict of str to NDArray
Dictionary of input values to be fed to the graph.
"""
if input_dict:
self.set_input(**input_dict)
self._run()
def get_output(self, index, out):
"""Get index-th output to out
Parameters
----------
index : int
The output index
out : NDArray
The output array container
"""
self._get_output(index, out)
return out
def __getitem__(self, key):
"""Get internal module function
Parameters
----------
key : str
The key to the module.
"""
return self.module[key]
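Putting the Python API together, a typical local session looks like the following minimal sketch (graph is assumed to be the JSON string and mlib the compiled tvm.Module, as constructed in the unit test at the bottom of this commit):

import numpy as np
import tvm
from tvm.contrib import graph_runtime

# graph: JSON string describing the tvm_op graph (see the test below)
# mlib:  tvm.Module holding the compiled "myadd" PackedFunc
mod = graph_runtime.create(graph, mlib, tvm.cpu(0))
mod.set_input("x", np.random.uniform(size=(4,)).astype("float32"))
mod.run()
out = mod.get_output(0, tvm.nd.empty((4,), "float32"))
print(out.asnumpy())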
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
#include "./graph_runtime.h"
namespace tvm {
namespace runtime {
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) \
<< TVMGetLastError(); \
}
/*!
* \brief Tiny graph runtime.
*
* This runtime can be accessed from various languages via the
* TVM runtime PackedFunc API.
*/
class GraphRuntime : public ModuleNode {
public:
~GraphRuntime() {
for (DLTensor* t : storage_pool_) {
TVM_CCALL(TVMArrayFree(t));
}
}
/*!
* \brief Get member function to front-end
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
/*!
* \return The type key of the executor.
*/
const char* type_key() const final {
return "GraphRuntime";
}
void Run() {
// Execute each op in topological order.
for (size_t i = 0; i < op_execs_.size(); ++i) {
if (op_execs_[i]) op_execs_[i]();
}
}
/*!
* \brief Initialize the graph executor with graph and context.
* \param graph The execution graph.
* \param module The module containing the compiled functions.
* \param ctx The context on which the graph executes.
*/
void Init(const std::string& graph_json,
tvm::runtime::Module module,
TVMContext ctx) {
std::istringstream is(graph_json);
dmlc::JSONReader reader(&is);
this->Load(&reader);
module_ = module;
ctx_ = ctx;
this->SetupStorage();
this->SetupOpExecs();
}
/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
* \return The index of input.
*/
int GetInputIndex(const std::string& name) {
for (size_t i = 0; i< input_nodes_.size(); ++i) {
uint32_t nid = input_nodes_[i];
if (nodes_[nid].name == name) {
return static_cast<int>(i);
}
}
LOG(FATAL) << "cannot find " << name << " among input";
return -1;
}
/*!
* \brief Set the index-th input to the graph.
* \param index The input index.
* \param data_in The input data.
*/
void SetInput(int index, DLTensor* data_in) {
CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
uint32_t eid = this->entry_id(input_nodes_[index], 0);
TVM_CCALL(TVMArrayCopyFromTo(data_in, &data_entry_[eid], nullptr));
}
/*!
* \brief Copy index-th output to data_out.
* \param index The output index.
* \param data_out the output data.
*/
void GetOutput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), outputs_.size());
uint32_t eid = this->entry_id(outputs_[index]);
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
}
/*!
* \brief Load parameters from binary stream
* \param strm The input stream.
*/
void LoadParams(dmlc::Stream* strm);
/*!
* \brief Load parameters from parameter blob.
* \param param_blob A binary blob of parameter.
*/
void LoadParams(const std::string& param_blob) {
dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
this->LoadParams(&strm);
}
private:
// Node entry
struct NodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};
// Node
struct Node {
// operator type in string
std::string op_type;
// name of the op
std::string name;
// parameters
TVMOpParam param;
// inputs
std::vector<NodeEntry> inputs;
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
std::unordered_map<std::string, std::string> dict;
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "op") {
reader->Read(&op_type);
bitmask |= 1;
} else if (key == "name") {
reader->Read(&name);
bitmask |= 2;
} else if (key == "inputs") {
reader->Read(&inputs);
bitmask |= 4;
} else if (key == "attr" || key == "attrs") {
reader->Read(&dict);
param.Init(dict);
} else if (key == "control_deps") {
reader->Read(&control_deps);
} else {
LOG(FATAL) << "do not support key " << key;
}
}
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
struct GraphAttr {
size_t storage_num_not_allocated{0};
std::vector<int> storage_id;
std::vector<std::string> dltype;
std::vector<std::vector<int64_t> > shape;
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key, type;
while (reader->NextObjectItem(&key)) {
if (key == "dltype") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_str");
CHECK(reader->NextArrayItem());
reader->Read(&dltype);
CHECK(!reader->NextArrayItem());
bitmask |= 1;
} else if (key == "storage_id") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_int");
CHECK(reader->NextArrayItem());
reader->Read(&storage_id);
CHECK(!reader->NextArrayItem());
bitmask |= 2;
} else if (key == "shape") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_shape");
CHECK(reader->NextArrayItem());
reader->Read(&shape);
CHECK(!reader->NextArrayItem());
bitmask |= 4;
} else {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
if (type == "list_int") {
CHECK(reader->NextArrayItem());
std::vector<int> temp;
reader->Read(&temp);
} else if (type == "size_t") {
CHECK(reader->NextArrayItem());
size_t temp;
reader->Read(&temp);
} else {
LOG(FATAL) << "cannot skip graph attr " << key;
}
CHECK(!reader->NextArrayItem());
}
}
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
// Graph-level JSON loader.
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "nodes") {
reader->Read(&nodes_);
bitmask |= 1;
} else if (key == "arg_nodes") {
reader->Read(&input_nodes_);
bitmask |= 2;
} else if (key == "node_row_ptr") {
reader->Read(&node_row_ptr_);
bitmask |= 4;
} else if (key == "heads") {
reader->Read(&outputs_);
bitmask |= 8;
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
}
}
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
}
bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
/*! \brief Setup the temporary storage */
void SetupStorage();
/*! \brief Setup the executors */
void SetupOpExecs();
/*!
* \brief Create an execution function given input.
* \param attrs The node attributes
* \param args The arguments to the functor, including inputs and outputs.
* \param num_inputs Number of inputs
* \return The created executor.
*/
std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
const std::vector<DLTensor>& args,
size_t num_inputs);
// Get node entry index.
uint32_t entry_id(uint32_t nid, uint32_t index) const {
return node_row_ptr_[nid] + index;
}
// Get node entry index.
uint32_t entry_id(const NodeEntry& e) const {
return entry_id(e.node_id, e.index);
}
// Number of node entries
uint32_t num_node_entries() const {
return node_row_ptr_.back();
}
// Number of nodes.
uint32_t num_nodes() const {
return static_cast<uint32_t>(nodes_.size());
}
// The graph nodes.
std::vector<Node> nodes_;
// The argument nodes.
std::vector<uint32_t> input_nodes_;
// used for quick entry indexing
std::vector<uint32_t> node_row_ptr_;
// output entries
std::vector<NodeEntry> outputs_;
// Additional graph attributes
GraphAttr attrs_;
/*! \brief The code module */
tvm::runtime::Module module_;
/*! \brief execution context */
TVMContext ctx_;
/*! \brief common storage pool */
std::vector<DLTensor*> storage_pool_;
/*! \brief data entry of each node */
std::vector<DLTensor> data_entry_;
/*! \brief operator on each node */
std::vector<std::function<void()> > op_execs_;
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype)))
<< "Invalid DLTensor file format";
int ndim = tensor->ndim;
CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim))
<< "Invalid DLTensor file format";
int64_t size = 1;
int type_size = tensor->dtype.bits / 8;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
<< "Invalid DLTensor file format";
CHECK(strm->Read(tensor->data, type_size * size))
<< "Invalid DLTensor file format";
return true;
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
std::vector<std::string> names;
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
uint32_t in_idx = GetInputIndex(names[i]);
CHECK(LoadDLTensor(strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)]))
<< "Invalid parameters file format";
}
}
void GraphRuntime::SetupStorage() {
// Grab saved optimization plan from graph.
std::vector<TVMType> vtype;
for (const std::string& s_type : attrs_.dltype) {
vtype.push_back(tvm::runtime::String2TVMType(s_type));
}
data_entry_.resize(num_node_entries());
// Find the maximum storage id.
int max_id = 0;
for (size_t i = 0; i < attrs_.shape.size(); ++i) {
max_id = std::max(attrs_.storage_id[i] + 1, max_id);
}
for (uint32_t nid : input_nodes_) {
attrs_.storage_id[this->entry_id(nid, 0)] = max_id++;
}
// size of each storage pool entry
std::vector<size_t> pool_entry_bytes;
// Compute the maximum byte size needed for each storage id.
for (size_t i = 0; i < attrs_.shape.size(); ++i) {
int storage_id = attrs_.storage_id[i];
size_t size = 1;
for (int64_t sz : attrs_.shape[i]) {
size *= static_cast<size_t>(sz);
}
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes;
CHECK_EQ(bits % 8U, 0U);
size_t bytes = (bits / 8U) * size;
size_t sid = static_cast<size_t>(storage_id);
if (sid >= pool_entry_bytes.size()) {
pool_entry_bytes.resize(sid + 1, 0);
}
pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
}
// Allocate the space.
for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {
int64_t shape[] = {static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
DLTensor* tensor;
TVM_CCALL(TVMArrayAlloc(
shape, 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
storage_pool_.push_back(tensor);
}
// Assign the pooled entries.
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
data_entry_[i] = *storage_pool_[storage_id];
data_entry_[i].shape = const_cast<int64_t*>(attrs_.shape[i].data());
data_entry_[i].ndim = static_cast<int>(attrs_.shape[i].size());
data_entry_[i].dtype = vtype[i];
}
}
void GraphRuntime::SetupOpExecs() {
op_execs_.resize(this->num_nodes());
// setup an execution closure for each op node.
for (uint32_t nid = 0; nid < this->num_nodes(); ++nid) {
const auto& inode = nodes_[nid];
if (inode.op_type == "null") continue;
std::vector<DLTensor> args;
for (const auto& e : inode.inputs) {
args.push_back(data_entry_[this->entry_id(e)]);
}
for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
uint32_t eid = this->entry_id(nid, index);
args.push_back(data_entry_[eid]);
}
CHECK_EQ(inode.op_type, "tvm_op")
<< "Can only take tvm_op as op";
op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
}
}
std::function<void()> GraphRuntime::CreateTVMOp(
const TVMOpParam& param,
const std::vector<DLTensor>& args,
size_t num_inputs) {
struct OpArgs {
std::vector<DLTensor> args;
std::vector<TVMValue> arg_values;
std::vector<int> arg_tcodes;
std::vector<int64_t> shape_data;
};
std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
// setup address.
arg_ptr->args = std::move(args);
if (param.flatten_data) {
arg_ptr->shape_data.resize(arg_ptr->args.size());
}
for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
TVMValue v;
DLTensor* t = &(arg_ptr->args[i]);
v.v_handle = t;
arg_ptr->arg_values.push_back(v);
arg_ptr->arg_tcodes.push_back(kArrayHandle);
if (param.flatten_data) {
arg_ptr->shape_data[i] = std::accumulate(
t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
t->ndim = 1;
t->shape = &(arg_ptr->shape_data[i]);
}
}
// get compiled function from module.
tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false);
CHECK(pf != nullptr) << "no such function in module: " << param.func_name;
auto fexec = [arg_ptr, pf] () {
TVMRetValue rv;
TVMArgs targs(arg_ptr->arg_values.data(),
arg_ptr->arg_tcodes.data(),
static_cast<int>(arg_ptr->arg_values.size()));
pf.CallPacked(targs, &rv);
};
return fexec;
}
PackedFunc GraphRuntime::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
// return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->SetInput(this->GetInputIndex(args[0]), args[1]);
} else {
this->SetInput(args[0], args[1]);
}
});
} else if (name == "get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->GetOutput(args[0], args[1]);
});
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Run();
});
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else {
return PackedFunc();
}
}
Module GraphRuntimeCreate(std::string sym_json,
tvm::runtime::Module m,
int device_type,
int device_id) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
exec->Init(sym_json, m, ctx);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[1];
*rv = GraphRuntimeCreate(args[0],
*static_cast<tvm::runtime::Module*>(mhandle),
args[2], args[3]);
});
} // namespace runtime
} // namespace tvm
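A note on indexing: node_row_ptr_ is a prefix sum over the per-node output counts, so entry_id(nid, index) = node_row_ptr_[nid] + index maps a (node, output) pair to a flat index into data_entry_. A worked example in Python for the two-node graph used in the test below:

# node 0: input "x" (1 output); node 1: "add" tvm_op (1 output)
node_row_ptr = [0, 1, 2]  # prefix sum of output counts

def entry_id(nid, index):
    # mirrors GraphRuntime::entry_id
    return node_row_ptr[nid] + index

assert entry_id(0, 0) == 0  # data entry of input x
assert entry_id(1, 0) == 1  # data entry of add's output
# num_node_entries() == node_row_ptr[-1] == 2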
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Tiny graph runtime that can run graph
* containing only tvm PackedFunc.
* \file graph_runtime.h
*/
#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#include <dmlc/parameter.h>
#include <string>
namespace tvm {
namespace runtime {
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief operator attributes about tvm op */
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
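For reference, LoadParams above implies the following little-endian blob layout (dmlc::Stream serializes a string vector as a uint64 count followed by length-prefixed strings). The writer below is a hypothetical sketch, limited to float32 CPU tensors whose names match input names in the graph JSON:

import struct

kTVMNDArrayMagic = 0xDD5E40F096B4A13F
kTVMNDArrayListMagic = 0xF7E58D4F05049CB7

def save_params(params):
    """Hypothetical writer for the blob LoadParams consumes.

    params : dict of str -> float32 numpy arrays (CPU only);
    keys must match input names in the graph JSON.
    """
    names = list(params.keys())
    out = struct.pack("<QQ", kTVMNDArrayListMagic, 0)      # header, reserved
    out += struct.pack("<Q", len(names))                   # name count
    for name in names:
        b = name.encode("utf-8")
        out += struct.pack("<Q", len(b)) + b               # length-prefixed name
    out += struct.pack("<Q", len(names))                   # tensor count (sz)
    for name in names:
        arr = params[name]
        out += struct.pack("<QQ", kTVMNDArrayMagic, 0)     # per-tensor header
        out += struct.pack("<ii", 1, 0)                    # ctx: kDLCPU, device 0
        out += struct.pack("<i", arr.ndim)                 # ndim
        out += struct.pack("<BBH", 2, 32, 1)               # dtype: float32
        out += struct.pack("<%dq" % arr.ndim, *arr.shape)  # shape
        data = arr.tobytes()
        out += struct.pack("<q", len(data)) + data         # byte size + raw data
    return out

Assuming the FFI converts a Python bytearray into the byte-array argument that args[0].operator std::string() accepts, the blob could then be fed through the module's load_params function, e.g. mod["load_params"](bytearray(save_params({"x": a}))).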
import tvm
import numpy as np
import json
from tvm.contrib import rpc, util, graph_runtime
def test_graph_simple():
n = 4
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
node0 = {"op": "null", "name": "x", "inputs": []}
node1 = {"op": "tvm_op", "name": "add",
"inputs": [[0, 0, 0]],
"attrs": {"func_name": "myadd",
"flatten_data": "1",
"num_inputs" : "1",
"num_outputs" : "1"}}
nodes = [node0, node1]
arg_nodes = [0]
node_row_ptr = [0, 1, 2]
outputs = [[1, 0, 0]]
shape = (4,)
attrs = {
"shape" : ["list_shape", [shape, shape]],
"dltype" : ["list_str", ["float32", "float32"]],
"storage_id" : ["list_int", [0, 1]],
}
graph = {"nodes": nodes,
"arg_nodes": arg_nodes,
"node_row_ptr": node_row_ptr,
"heads": outputs,
"attrs": attrs}
graph = json.dumps(graph)
def check_verify():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
mod = graph_runtime.create(graph, mlib, tvm.cpu(0))
a = np.random.uniform(size=(n,)).astype(A.dtype)
mod.run(x=a)
out = mod.get_output(0, tvm.nd.empty((n,)))
np.testing.assert_equal(out.asnumpy(), a + 1)
def check_remote():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
temp = util.tempdir()
ctx = remote.cpu(0)
path_dso = temp.relpath("dev_lib.so")
mlib.export_library(path_dso)
remote.upload(path_dso)
mlib = remote.load_module("dev_lib.so")
mod = graph_runtime.create(graph, mlib, remote.cpu(0))
a = np.random.uniform(size=(n,)).astype(A.dtype)
mod.run(x=tvm.nd.array(a, ctx))
out = tvm.nd.empty((n,), ctx=ctx)
out = mod.get_output(0, out)
np.testing.assert_equal(out.asnumpy(), a + 1)
check_verify()
check_remote()
if __name__ == "__main__":
test_graph_simple()