Commit 7c3ec7df by Zhi Committed by Tianqi Chen

Heterogeneous Runtime (#1695)

parent 7beafddd
......@@ -384,8 +384,14 @@ def build(sch,
target=None,
target_host=None,
name="default_function",
binds=None):
"""Build a function with arguments as signiture.
binds=None,
postpone_host_codegen=False):
"""Build a function with arguments as signature. Code will be generated
for a device specified by the target. For homogeneous execution, a module
that contains both host and device code is returned. For heterogeneous
execution, a list of lowered functions for the host and a module containing
device code are returned, but actual code generation for the host module is
postponed until code generation is finished for all devices.
Parameters
----------
......@@ -414,10 +420,18 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
postpone_host_codegen : bool, optional
A bool value that indicates whether code generation for the host module
should be postponed. Set it to True for heterogeneous execution; it
defaults to False.
Returns
-------
f : Function, or pair of functions
The result function.
ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple
A module that combines both host and device code is returned when
postpone_host_codegen is not set. Otherwise, a list of lowered
functions for the host and a module that contains only device code are
returned.
Note
----
......@@ -498,9 +512,15 @@ def build(sch,
fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mhost = codegen.build_module(fhost, str(target_host))
# Return the host functions and the device module when host code generation
# is postponed. All device modules will be imported into the host module
# after all of them are collected.
mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
if postpone_host_codegen:
return fhost, mdev
mhost = codegen.build_module(fhost, str(target_host))
if fdevice:
mdev = codegen.build_module(fdevice, str(target_device))
mhost.import_module(mdev)
return mhost
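For context, the two return forms of build can be sketched as follows. This is a minimal illustration assuming a simple element-wise workload built for the "llvm" target only (so the device module comes back as None); the heterogeneous tests at the bottom of this change show the full multi-target flow.

import tvm

# Hypothetical element-wise workload; shape and names mirror the tests below.
shape = (4,)
A = tvm.placeholder(shape, name="A")
B = tvm.placeholder(shape, name="B")
C = tvm.compute(shape, lambda *i: A(*i) + B(*i), name="elemwise_add")
s = tvm.create_schedule(C.op)
lowered = tvm.lower(s, [A, B, C], name="elemwise_add")

# Default path: one module that already combines host and device code.
mhost = tvm.build(lowered, target="llvm", name="elemwise_add")

# Postponed path: the lowered host functions and a device-only module are
# returned separately (None here, since "llvm" emits no device code), so host
# functions from several builds can be combined later with
# tvm.codegen.build_module, as the heterogeneous tests do.
host_funcs, mdev = tvm.build(lowered, target="llvm", name="elemwise_add",
                             postpone_host_codegen=True)
combined = tvm.codegen.build_module(host_funcs, "llvm")
if mdev:
    combined.import_module(mdev)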
......@@ -3,26 +3,24 @@ import numpy as np
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
from .. import ndarray as nd
def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
----------
graph_json_str : str or graph class
The graph to be deployed in json format output by nnvm graph.
The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod.
libmod : tvm.Module
The module of the corresponding function
ctx : TVMContext
The context to deploy the module, can be local or remote.
ctx : TVMContext or list of TVMContext
The context to deploy the module. It can be local or remote when there
is only one TVMContext. Otherwise, the first context in the list is
used as the primary/fallback context. All contexts should be given for
heterogeneous execution.
Returns
-------
graph_module : GraphModule
......@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx):
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
device_type = ctx.device_type
device_id = ctx.device_id
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
if isinstance(ctx, TVMContext):
ctx = [ctx]
elif not isinstance(ctx, (list, tuple)):
raise ValueError("ctx has to be the type of TVMContext or a list of "
"TVMCTVMContext")
for cur_ctx in ctx:
if not isinstance(cur_ctx, TVMContext):
raise ValueError("ctx has to be the type of TVMContext or a list "
"of TVMContext")
# device_type_id[0], device_type_id[1] are used as the primary/fallback
# context type and id. All other entries are used as device contexts for
# heterogeneous execution.
num_rpc_ctx = 0
device_type_id = []
for cur_ctx in ctx:
device_type = cur_ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(
libmod) == cur_ctx._rpc_sess._tbl_index
num_rpc_ctx += 1
device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
device_type_id.append(device_type)
device_type_id.append(cur_ctx.device_id)
if 0 < num_rpc_ctx < len(ctx):
raise ValueError("Either all or none of the contexts should be rpc.")
if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = device_type % rpc_base.RPC_SESS_MASK
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
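A hedged sketch of how the updated interface might be driven from user code; graph_json, libmod, and params are assumed to come from a build step such as the one in the tests below, and the contexts are illustrative.

import tvm
from tvm.contrib import graph_runtime

def run_hetero(graph_json, libmod, params, shape=(4,)):
    # The first context is the primary/fallback one; all contexts used by the
    # graph must be listed for heterogeneous execution.
    ctxs = [tvm.cpu(0), tvm.gpu(0)]
    mod = graph_runtime.create(graph_json, libmod, ctxs)
    mod.set_input(**params)
    mod.run()
    return mod.get_output(0, tvm.nd.empty(shape))

# A single TVMContext still works for the homogeneous case:
#   mod = graph_runtime.create(graph_json, libmod, tvm.cpu(0))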
class GraphModule(object):
......@@ -58,18 +81,13 @@ class GraphModule(object):
module : Module
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
Attributes
----------
module : Module
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
"""
def __init__(self, module, ctx):
def __init__(self, module):
self.module = module
self._set_input = module["set_input"]
self._run = module["run"]
......@@ -81,7 +99,6 @@ class GraphModule(object):
except AttributeError:
pass
self._load_params = module["load_params"]
self.ctx = ctx
def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
......@@ -98,14 +115,14 @@ class GraphModule(object):
Additional arguments
"""
if key:
self._set_input(key, nd.array(value, ctx=self.ctx))
self._get_input(key).copyfrom(value)
if params:
# upload big arrays first to avoid memory issue in rpc mode
keys = list(params.keys())
keys.sort(key=lambda x: -np.prod(params[x].shape))
for k in keys:
self._set_input(k, nd.array(params[k], ctx=self.ctx))
self._get_input(k).copyfrom(params[k])
def run(self, **input_dict):
"""Run forward execution of the graph
......@@ -177,7 +194,8 @@ class GraphModule(object):
if hasattr(self, '_debug_get_output'):
self._debug_get_output(node, out)
else:
raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
raise RuntimeError(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
return out
def load_params(self, params_bytes):
......
......@@ -2,22 +2,26 @@
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
*/
#include "graph_runtime.h"
#include <dlpack/dlpack.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
#include <tvm/runtime/serializer.h>
#include <algorithm>
#include <vector>
#include <functional>
#include "graph_runtime.h"
#include <numeric>
#include <vector>
namespace tvm {
namespace runtime {
/*! \brief macro to do C API call */
/*! \brief Macro to do C API call. */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
......@@ -34,7 +38,7 @@ namespace runtime {
class GraphRuntime : public ModuleNode {
public:
/*!
* \brief Get member function to front-end
* \brief Get member function to front-end.
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
......@@ -58,12 +62,13 @@ class GraphRuntime : public ModuleNode {
/*!
* \brief Initialize the graph executor with graph and context.
* \param graph_json The execution graph.
* \param module The module containing the compiled functions.
* \param ctx The context where the graph should sit on
* \param module The module containing the compiled functions for the host
* processor.
* \param ctxs The context of the host and devices where graph nodes will be
* executed on.
*/
void Init(const std::string& graph_json,
tvm::runtime::Module module,
TVMContext ctx) {
void Init(const std::string& graph_json, const tvm::runtime::Module& module,
const std::vector<TVMContext>& ctxs) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::istringstream is(graph_json);
#else
......@@ -72,10 +77,11 @@ class GraphRuntime : public ModuleNode {
dmlc::JSONReader reader(&is);
this->Load(&reader);
module_ = module;
ctx_ = ctx;
ctxs_ = ctxs;
this->SetupStorage();
this->SetupOpExecs();
}
/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
......@@ -92,7 +98,7 @@ class GraphRuntime : public ModuleNode {
return -1;
}
/*!
* \brief set index-th input to the graph.
* \brief Set index-th input to the graph.
* \param index The input index.
* \param data_in The input data.
*/
......@@ -134,7 +140,7 @@ class GraphRuntime : public ModuleNode {
/*!
* \brief Copy index-th output to data_out.
* \param index The output index.
* \param data_out the output data.
* \param data_out The output data.
*/
void CopyOutputTo(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), outputs_.size());
......@@ -172,8 +178,8 @@ class GraphRuntime : public ModuleNode {
* from the beginning up to the index-th node and return the output of the index-th node.
* This is a costly operation and is suggested to be used only for debugging purposes.
*
* \param index: The index of the node.
* \param data_out the node data.
* \param index The index of the node.
* \param data_out The node data.
*/
void DebugGetNodeOutput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), nodes_.size());
......@@ -188,7 +194,7 @@ class GraphRuntime : public ModuleNode {
}
#endif
/*!
* \brief Load parameters from binary stream
* \brief Load parameters from binary stream.
* \param strm The input stream.
*/
void LoadParams(dmlc::Stream* strm);
......@@ -202,6 +208,12 @@ class GraphRuntime : public ModuleNode {
}
private:
// Memory pool entry.
struct PoolEntry {
size_t size;
int device_type;
PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {}
};
// Node entry
struct NodeEntry {
uint32_t node_id;
......@@ -260,7 +272,6 @@ class GraphRuntime : public ModuleNode {
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
std::unordered_map<std::string, std::string> dict;
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
......@@ -287,6 +298,7 @@ class GraphRuntime : public ModuleNode {
struct GraphAttr {
size_t storage_num_not_alloctaed{0};
std::vector<int> storage_id;
std::vector<int> device_index;
std::vector<std::string> dltype;
std::vector<std::vector<int64_t> > shape;
// The graph attribute fields.
......@@ -322,6 +334,14 @@ class GraphRuntime : public ModuleNode {
reader->Read(&shape);
CHECK(!reader->NextArrayItem());
bitmask |= 4;
} else if (key == "device_index") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_int");
CHECK(reader->NextArrayItem());
reader->Read(&device_index);
CHECK(!reader->NextArrayItem());
} else {
reader->BeginArray();
CHECK(reader->NextArrayItem());
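For reference, the device_index attribute parsed above is a plain list_int section of the graph JSON with one device type per node entry; the illustrative snippet below mirrors (a subset of) the attrs that the simplex test graph at the end of this change constructs, with 1 = cpu and 2 = gpu.

attrs = {
    "storage_id": ["list_int", [3, 4, 0, 1, 5, 2]],
    "shape": ["list_shape", [(4,)] * 6],
    # A, B, and elemwise_add run on the device; the copy output, C, and
    # elemwise_sub live on the host.
    "device_index": ["list_int", [2, 2, 2, 1, 1, 1]],
    "dltype": ["list_str", ["float32"] * 6],
}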
......@@ -372,13 +392,14 @@ class GraphRuntime : public ModuleNode {
}
/*! \brief Setup the temporal storage */
void SetupStorage();
/*! \brief Setup the executors */
/*! \brief Setup the executors. */
void SetupOpExecs();
/*!
* \brief Create an execution function given input.
* \param attrs The node attributes
* \param attrs The node attributes.
* \param args The arguments to the functor, including inputs and outputs.
* \param num_inputs Number of inputs
* \param num_inputs Number of inputs.
* \param dev_type The device type of the tvm_op.
* \return The created executor.
*/
std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
......@@ -392,7 +413,7 @@ class GraphRuntime : public ModuleNode {
uint32_t entry_id(const NodeEntry& e) const {
return entry_id(e.node_id, e.index);
}
// Number of node entries
// Number of node entries.
uint32_t num_node_entries() const {
return node_row_ptr_.back();
}
......@@ -400,25 +421,25 @@ class GraphRuntime : public ModuleNode {
uint32_t num_nodes() const {
return static_cast<uint32_t>(nodes_.size());
}
// The graph nodes.
/*! \brief The graph nodes. */
std::vector<Node> nodes_;
// The argument nodes.
/*! \brief The argument nodes. */
std::vector<uint32_t> input_nodes_;
// used or quick entry indexing
/*! \brief Used for quick entry indexing. */
std::vector<uint32_t> node_row_ptr_;
// output entries
/*! \brief Output entries. */
std::vector<NodeEntry> outputs_;
// Additional graph attributes
/*! \brief Additional graph attributes. */
GraphAttr attrs_;
/*! \brief The code module */
/*! \brief The code module that contains both host and device code. */
tvm::runtime::Module module_;
/*! \brief execution context */
TVMContext ctx_;
/*! \brief common storage pool */
/*! \brief Execution context of all devices including the host. */
std::vector<TVMContext> ctxs_;
/*! \brief Common storage pool for all devices. */
std::vector<NDArray> storage_pool_;
/*! \brief data entry of each node */
/*! \brief Data entry of each node. */
std::vector<NDArray> data_entry_;
/*! \brief operator on each node */
/*! \brief Operator on each node. */
std::vector<std::function<void()> > op_execs_;
};
......@@ -458,12 +479,17 @@ void GraphRuntime::SetupStorage() {
for (const std::string& s_type : attrs_.dltype) {
vtype.push_back(tvm::runtime::String2TVMType(s_type));
}
data_entry_.resize(num_node_entries());
// size of each storage pool entry
std::vector<size_t> pool_entry_bytes;
// Size and device type of each storage pool entry.
std::vector<PoolEntry> pool_entry;
// Find the maximum space size.
for (size_t i = 0; i < attrs_.shape.size(); ++i) {
int storage_id = attrs_.storage_id[i];
// Use the fallback device if no device index is available.
int device_type = static_cast<int>(ctxs_[0].device_type);
if (!attrs_.device_index.empty()) {
device_type = attrs_.device_index[i];
}
size_t size = 1;
for (int64_t sz : attrs_.shape[i]) {
size *= static_cast<size_t>(sz);
......@@ -474,23 +500,42 @@ void GraphRuntime::SetupStorage() {
CHECK_EQ(bits % 8U, 0U);
size_t bytes = (bits / 8U) * size;
size_t sid = static_cast<size_t>(storage_id);
if (sid >= pool_entry_bytes.size()) {
pool_entry_bytes.resize(sid + 1, 0);
uint32_t sid = static_cast<uint32_t>(storage_id);
if (sid >= pool_entry.size()) {
pool_entry.resize(sid + 1, {0, -1});
} else {
CHECK(pool_entry[sid].device_type == -1 ||
pool_entry[sid].device_type == device_type)
<< "The same pool entry cannot be assigned to multiple devices";
}
pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
pool_entry[sid].size = std::max(pool_entry[sid].size, bytes);
pool_entry[sid].device_type = device_type;
}
// Allocate the space.
for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {
for (const auto& pit : pool_entry) {
std::vector<int64_t> shape;
shape.push_back(static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4);
storage_pool_.push_back(NDArray::Empty(shape, DLDataType {kDLFloat, 32, 1}, ctx_));
// This for loop is very fast since there are usually only a couple of
// devices available on the same hardware.
const auto& cit =
std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) {
return pit.device_type == static_cast<int>(c.device_type);
});
TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit;
shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
storage_pool_.push_back(
NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
}
// Assign the pooled entries.
// Assign the pooled entries. A unified memory pool is used to simplify
// memory assignment for each node entry. The allocated memory on each device
// is mapped to this pool.
data_entry_.resize(num_node_entries());
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
data_entry_[i] =
storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
}
}
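To make the storage planning above easier to follow, here is a rough Python mirror of the pool-entry bookkeeping. It is a sketch of the algorithm under the same assumptions (one entry per node entry, sizes taken from shape and dltype), not the runtime's actual code.

import numpy as np

def plan_storage_pool(storage_id, device_index, shapes, dltypes, fallback_dev=1):
    """Compute the (bytes, device_type) of each pool entry, as SetupStorage does."""
    pool = {}  # sid -> (max_bytes, device_type)
    for i, sid in enumerate(storage_id):
        # Fall back to the primary device when no device index is available.
        dev = device_index[i] if device_index else fallback_dev
        nbytes = int(np.prod(shapes[i])) * np.dtype(dltypes[i]).itemsize
        size, prev_dev = pool.get(sid, (0, -1))
        assert prev_dev in (-1, dev), \
            "The same pool entry cannot be assigned to multiple devices"
        pool[sid] = (max(size, nbytes), dev)
    return pool

# Example using the attrs of the simplex test graph below (2 = gpu, 1 = cpu).
print(plan_storage_pool([3, 4, 0, 1, 5, 2], [2, 2, 2, 1, 1, 1],
                        [(4,)] * 6, ["float32"] * 6))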
......@@ -508,8 +553,8 @@ void GraphRuntime::SetupOpExecs() {
uint32_t eid = this->entry_id(nid, index);
args.push_back(*(data_entry_[eid].operator->()));
}
CHECK_EQ(inode.op_type, "tvm_op")
<< "Can only take tvm_op as op";
CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
}
}
......@@ -543,13 +588,26 @@ std::function<void()> GraphRuntime::CreateTVMOp(
t->shape = &(arg_ptr->shape_data[i]);
}
}
if (param.func_name == "__nop") {
return [](){};
} else if (param.func_name == "__copy") {
// Perform cross device data copy.
// Directly copy data from the input to the output.
auto fexec = [arg_ptr]() {
DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
};
return fexec;
}
// get compiled function from module.
// Get compiled function from the module that contains both host and device
// code.
tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false);
CHECK(pf != nullptr) << "no such function in module: " << param.func_name;
auto fexec = [arg_ptr, pf] () {
auto fexec = [arg_ptr, pf]() {
TVMRetValue rv;
TVMArgs targs(arg_ptr->arg_values.data(),
arg_ptr->arg_tcodes.data(),
......@@ -562,7 +620,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
PackedFunc GraphRuntime::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
// return member functions during query.
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
......@@ -618,29 +676,53 @@ PackedFunc GraphRuntime::GetFunction(
}
}
Module GraphRuntimeCreate(std::string sym_json,
tvm::runtime::Module m,
int device_type,
int device_id) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
Module GraphRuntimeCreate(const std::string& sym_json,
const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
exec->Init(sym_json, m, ctx);
exec->Init(sym_json, m, ctxs);
return Module(exec);
}
// Get all contexts for the host and other runtime devices.
std::vector<TVMContext> GetAllContext(const TVMArgs& args) {
// Reserve the first item as the fallback device.
std::vector<TVMContext> ret;
TVMContext ctx;
for (int i = 2; i < args.num_args; i += 2) {
int dev_type = args[i];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[i + 1];
ret.push_back(ctx);
}
return ret;
}
// 4-argument version is currently reserved to keep support of calling
// from tvm4j and javascript, since they don't have heterogeneous
// execution support yet. For heterogeneous execution, at least 5 arguments will
// be passed in, where arguments from the third one onward are
// (device_type, device_id) pairs.
// Eventually, we will probably only pass TVMContext for all the languages.
TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]);
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4)
<< "The expected number of arguments for graph_runtime.create is "
"at least 4, but it has "
<< args.num_args;
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(args[0], args[1], contexts);
});
TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
*rv = GraphRuntimeCreate(args[0],
*static_cast<tvm::runtime::Module*>(mhandle),
args[2], args[3]);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
});
} // namespace runtime
} // namespace tvm
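To make the argument layout that GetAllContext expects concrete, the call below sketches roughly what the Python frontend in graph_runtime.py ends up passing to the registered function; graph_json_str and libmod are assumed inputs from a prior build, and the contexts are illustrative.

import tvm

fcreate = tvm.get_global_func("tvm.graph_runtime.create")
cpu, gpu = tvm.cpu(0), tvm.gpu(0)

# graph_json_str and libmod are assumed to come from a prior build step.
# (dev_type, dev_id) pairs start at argument index 2; the first pair is the
# primary/fallback context, and the function requires at least four arguments
# in total.
module = fcreate(graph_json_str, libmod,
                 cpu.device_type, cpu.device_id,   # fallback/host context
                 gpu.device_type, gpu.device_id)   # additional device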
# pylint: disable=too-many-locals
"""Unit tests for heterogeneous runtime"""
import json
import numpy as np
import tvm
from tvm.contrib import graph_runtime, util
import topi
def get_simplex_graph(host_dev_type, device_dev_type):
r""" Return the hand-crafted json object where only one copy node is
inserted. This node copies data from the target device to cpu.
The network is constructed as follows:
A B
\ /
elemwise_add (gpu)
\
copy C
\ /
elemwise_sub (cpu)
Parameters
----------
host_dev_type : int
The device type of the host processor, e.g. cpu.
device_dev_type : int
The device type of the device processor, e.g. gpu, opencl, etc.
Returns
-------
json : json
A json encoded object.
"""
# Construct each node in the graph.
var_a = {"op": "null", "name": "A", "inputs": []}
var_b = {"op": "null", "name": "B", "inputs": []}
elemwise_add = {
"op": "tvm_op", "name": "elemwise_add",
"attrs": {
"flatten_data": "1",
"func_name": "elemwise_add",
"num_inputs": "2",
"num_outputs": "1"
},
"inputs": [[0, 0, 0], [1, 0, 0]]
}
copy = {
"op": "tvm_op",
"name": "__copy_add_to_sub",
"attrs": {
"flatten_data": "0",
"func_name": "__copy",
"num_inputs": "1",
"num_outputs": "1"
},
"inputs": [[2, 0, 0]]
}
var_c = {"op": "null", "name": "C", "inputs": []}
elemwise_sub = {
"op": "tvm_op", "name": "elemwise_sub",
"attrs": {
"flatten_data": "0",
"func_name": "elemwise_sub",
"num_inputs": "2",
"num_outputs": "1"
},
"inputs": [[3, 0, 0], [4, 0, 0]]
}
# Group the nodes.
nodes = [var_a, var_b, elemwise_add, copy, var_c, elemwise_sub]
arg_nodes = [0, 1, 4]
node_row_ptr = [0, 1, 2, 3, 4, 5, 6]
heads = [[5, 0, 0]]
shape = (4,)
attrs = {
"storage_id": ["list_int", [3, 4, 0, 1, 5, 2]],
"shape": ["list_shape", [shape, shape, shape, shape, shape, shape]],
"device_index": ["list_int", [device_dev_type, device_dev_type,
device_dev_type, host_dev_type,
host_dev_type, host_dev_type]],
"dtype": ["list_int", [0, 0, 0, 0, 0, 0]],
"dltype": ["list_str", ["float32", "float32", "float32",
"float32", "float32", "float32"]]
}
# Construct the graph.
graph = {"nodes": nodes,
"arg_nodes": arg_nodes,
"node_row_ptr": node_row_ptr,
"heads": heads,
"attrs": attrs}
return json.dumps(graph)
def test_simplex_data_transferring():
r"""
Test the heterogeneous execution of a simple network where data is
transferred from the target device to the host processor at runtime.
The host processor is always assumed to be cpu, and the device varies.
"""
host = "cpu"
target_host = "llvm"
host_ctx = tvm.context(host)
if not tvm.module.enabled(target_host):
print("Skip test because llvm is not enabled.")
return
def check_device(device, target_device):
if not tvm.module.enabled(target_device):
print("Skip test because {} is not enabled.".format(target_device))
return
device_ctx = tvm.context(device)
graph = get_simplex_graph(host_ctx.device_type, device_ctx.device_type)
shape = (4,)
# Create module for add whose target is the device.
tensor_a = tvm.placeholder(shape, name="A")
tensor_b = tvm.placeholder(shape, name="B")
elemwise_add = tvm.compute(shape, lambda *i: tensor_a(*i)
+ tensor_b(*i), name="elemwise_add")
target = topi.cpp.TEST_create_target(device)
schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add])
lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add],
name="elemwise_add")
host_funcs_add, lib_add = tvm.build(lower_add, target=target_device,
name="elemwise_add",
postpone_host_codegen=True)
# Insert copy. Neither compute nor schedule is required for the copy
# node. The computation is performed at runtime and is simply a data
# copy from the input to the output.
tensor_copy = tvm.placeholder(shape, name="__copy")
# Create module for sub whose target is the host.
tensor_c = tvm.placeholder(shape, name="C")
elemwise_sub = tvm.compute(shape, lambda *i: tensor_copy(*i)
- tensor_c(*i), name="elemwise_sub")
schedule_sub = tvm.create_schedule(elemwise_sub.op)
lower_sub = tvm.lower(schedule_sub, [tensor_copy, tensor_c,
elemwise_sub],
name="elemwise_sub")
host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
name="elemwise_sub",
postpone_host_codegen=True)
host_funcs = host_funcs_add + host_funcs_sub
mhost = tvm.codegen.build_module(host_funcs, target_host)
if lib_add:
mhost.import_module(lib_add)
if lib_sub:
mhost.import_module(lib_sub)
ctx = [host_ctx, device_ctx]
mod = graph_runtime.create(graph, mhost, ctx)
params = {}
params["A"] = tensor_a = np.random.uniform(
size=shape).astype(tensor_a.dtype)
params["B"] = tensor_b = np.random.uniform(
size=shape).astype(tensor_b.dtype)
params["C"] = tensor_c = np.random.uniform(
size=shape).astype(tensor_c.dtype)
mod.set_input(**params)
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape))
np.testing.assert_equal(
out.asnumpy(), (tensor_a + tensor_b) - tensor_c)
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
check_device(device, target)
def get_duplex_graph(host_dev_type, device_dev_type):
r""" Return the hand-crafted json object where two copy nodes are inserted.
Data is transferred back and forth between the target device and the CPU.
The network is constructed as follows:
A B
\ /
elemwise_add (gpu)
\
copy C
\ /
elemwise_sub (cpu)
\
copy D
\ /
elemwise_add (gpu)
Parameters
----------
host_dev_type : int
The device type of the host processor, e.g. cpu.
device_dev_type : int
The device type of the device processor, e.g. gpu, opencl, etc.
Returns
-------
json : json
A json encoded object.
"""
# Construct each node in the graph.
var_a = {"op": "null", "name": "A", "inputs": []}
var_b = {"op": "null", "name": "B", "inputs": []}
elemwise_add0 = {
"op": "tvm_op", "name": "elemwise_add0",
"attrs": {
"flatten_data": "1",
"func_name": "elemwise_add0",
"num_inputs": "2",
"num_outputs": "1"
},
"inputs": [[0, 0, 0], [1, 0, 0]]
}
copy_add_sub = {
"op": "tvm_op",
"name": "__copy_add_to_sub",
"attrs": {
"flatten_data": "0",
"func_name": "__copy",
"num_inputs": "1",
"num_outputs": "1"
},
"inputs": [[2, 0, 0]]
}
var_c = {"op": "null", "name": "C", "inputs": []}
elemwise_sub = {
"op": "tvm_op", "name": "elemwise_sub",
"attrs": {
"flatten_data": "0",
"func_name": "elemwise_sub",
"num_inputs": "2",
"num_outputs": "1"
},
"inputs": [[3, 0, 0], [4, 0, 0]]
}
copy_sub_add = {
"op": "tvm_op",
"name": "__copy_sub_to_add",
"attrs": {
"flatten_data": "0",
"func_name": "__copy",
"num_inputs": "1",
"num_outputs": "1"
},
"inputs": [[5, 0, 0]]
}
var_d = {"op": "null", "name": "D", "inputs": []}
elemwise_add1 = {
"op": "tvm_op", "name": "elemwise_add1",
"attrs": {
"flatten_data": "0",
"func_name": "elemwise_add1",
"num_inputs": "2",
"num_outputs": "1"
},
"inputs": [[6, 0, 0], [7, 0, 0]]
}
# Group the nodes.
nodes = [var_a, var_b, elemwise_add0, copy_add_sub, var_c, elemwise_sub,
copy_sub_add, var_d, elemwise_add1]
arg_nodes = [0, 1, 4, 7]
node_row_ptr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
heads = [[8, 0, 0]]
shape = (4,)
attrs = {
"storage_id": ["list_int", [4, 5, 0, 1, 6, 2, 0, 7, 3]],
"shape": ["list_shape", [shape, shape, shape, shape, shape, shape,
shape, shape, shape]],
"device_index": ["list_int", [device_dev_type, device_dev_type,
device_dev_type,
host_dev_type, host_dev_type, host_dev_type,
device_dev_type, device_dev_type,
device_dev_type]],
"dtype": ["list_int", [0, 0, 0, 0, 0, 0, 0, 0, 0]],
"dltype": ["list_str", ["float32", "float32", "float32",
"float32", "float32", "float32",
"float32", "float32", "float32"]]
}
# Construct the graph.
graph = {"nodes": nodes,
"arg_nodes": arg_nodes,
"node_row_ptr": node_row_ptr,
"heads": heads,
"attrs": attrs}
return json.dumps(graph)
def test_duplex_data_transferring():
r"""
Test the heterogeneous execution of a simple network where data is
transferred back and forth between the target device and the host
processor.
The host processor is always assumed to be cpu, and the target device
varies.
"""
host = "cpu"
target_host = "llvm"
host_ctx = tvm.context(host)
if not tvm.module.enabled(target_host):
print("Skip test because llvm is not enabled.")
return
def check_device(device, target_device):
if not tvm.module.enabled(target_device):
print("Skip test because {} is not enabled.".format(target_device))
return
device_ctx = tvm.context(device)
graph = get_duplex_graph(host_ctx.device_type, device_ctx.device_type)
shape = (4,)
# Insert copy nodes for data transferring between add and sub nodes.
# Transfers data from gpu to cpu.
copy_add_sub = tvm.placeholder(shape, name="__copy0")
# Transfers data from cpu to gpu.
copy_sub_add = tvm.placeholder(shape, name="__copy1")
# Create a module containing adds on the device.
tensor_a = tvm.placeholder(shape, name="A")
tensor_b = tvm.placeholder(shape, name="B")
tensor_d = tvm.placeholder(shape, name="D")
elemwise_add0 = tvm.compute(shape, lambda *i: tensor_a(*i)
+ tensor_b(*i), name="elemwise_add0")
elemwise_add1 = tvm.compute(shape, lambda *i: copy_sub_add(*i)
+ tensor_d(*i), name="elemwise_add1")
target = topi.cpp.TEST_create_target(device)
add_schedule0 = topi.cpp.cuda.schedule_injective(
target, [elemwise_add0])
lower_add0 = tvm.lower(
add_schedule0, [tensor_a, tensor_b, elemwise_add0],
name="elemwise_add0")
add_schedule1 = topi.cpp.cuda.schedule_injective(
target, [elemwise_add1])
lower_add1 = tvm.lower(
add_schedule1, [tensor_d, copy_sub_add, elemwise_add1],
name="elemwise_add1")
host_funcs_add, lib_add = tvm.build([lower_add0, lower_add1],
target=target_device,
postpone_host_codegen=True)
# Create module for sub whose target is the host.
tensor_c = tvm.placeholder(shape, name="C")
elemwise_sub = tvm.compute(shape, lambda *i: copy_add_sub(*i)
- tensor_c(*i), name="elemwise_sub")
sub_schedule = tvm.create_schedule(elemwise_sub.op)
lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c,
elemwise_sub],
name="elemwise_sub")
host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
postpone_host_codegen=True)
host_funcs = host_funcs_add + host_funcs_sub
mhost = tvm.codegen.build_module(host_funcs, target_host)
if lib_add:
mhost.import_module(lib_add)
if lib_sub:
mhost.import_module(lib_sub)
ctx = [host_ctx, device_ctx]
params = {}
params["A"] = tensor_a = np.random.uniform(
size=shape).astype(tensor_a.dtype)
params["B"] = tensor_b = np.random.uniform(
size=shape).astype(tensor_b.dtype)
params["C"] = tensor_c = np.random.uniform(
size=shape).astype(tensor_c.dtype)
params["D"] = tensor_d = np.random.uniform(
size=shape).astype(tensor_d.dtype)
def check_verify():
mod = graph_runtime.create(graph, mhost, ctx)
mod.set_input(**params)
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape))
np.testing.assert_equal(
out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
def check_load_module():
temp = util.tempdir()
path_lib = temp.relpath("deploy.so")
mhost.export_library(path_lib)
with open(temp.relpath("deploy.json"), "w") as out_file:
out_file.write(graph)
loaded_lib = tvm.module.load(path_lib)
loaded_graph = open(temp.relpath("deploy.json")).read()
mod = graph_runtime.create(loaded_graph, loaded_lib, ctx)
mod.set_input(**params)
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape))
np.testing.assert_equal(
out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
check_verify()
check_load_module()
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
check_device(device, target)
if __name__ == "__main__":
test_simplex_data_transferring()
test_duplex_data_transferring()