Commit 7bcb3f53 by Tianqi Chen Committed by ziheng

[REFACTOR] collections->container, RPC returns func, time_evaluator r… (#244)

* [REFACTOR] collections->container, RPC returns func, time_evaluator returns struct

* fix executor
parent c324494f
...@@ -37,7 +37,7 @@ def bind(g, ctx): ...@@ -37,7 +37,7 @@ def bind(g, ctx):
def _lower(sch, inputs, func_name): def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name) f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance( return f if isinstance(
f, (tvm.collections.Array, tuple, list)) else [f] f, (tvm.container.Array, tuple, list)) else [f]
@tvm.register_func("tvm_graph.build_target") @tvm.register_func("tvm_graph.build_target")
......
tvm.collections
---------------
.. automodule:: tvm.collections
:members:
tvm.container
-------------
.. automodule:: tvm.container
:members:
...@@ -9,7 +9,7 @@ from . import stmt ...@@ -9,7 +9,7 @@ from . import stmt
from . import make from . import make
from . import ir_pass from . import ir_pass
from . import codegen from . import codegen
from . import collections from . import container
from . import schedule from . import schedule
from . import module from . import module
from . import node from . import node
......
...@@ -15,7 +15,7 @@ from . import make as _make ...@@ -15,7 +15,7 @@ from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from . import collections as _collections from . import container as _container
from . import tag as _tag from . import tag as _tag
int32 = "int32" int32 = "int32"
...@@ -493,7 +493,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''): ...@@ -493,7 +493,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges") raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1]) dom = Range(dom[0], dom[1])
if not isinstance(dom, _collections.Range): if not isinstance(dom, _container.Range):
raise TypeError("dom need to be Range") raise TypeError("dom need to be Range")
name = name if name else 'iter' name = name if name else 'iter'
v = var(name) v = var(name)
...@@ -628,7 +628,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ...@@ -628,7 +628,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
code = fcombine.__code__ code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2 assert fcombine.__code__.co_argcount == 2
expr = convert(expr) expr = convert(expr)
if isinstance(expr, _collections.Array): if isinstance(expr, _container.Array):
size = len(expr) size = len(expr)
larr = [] larr = []
rarr = [] rarr = []
......
...@@ -9,7 +9,7 @@ from . import tensor ...@@ -9,7 +9,7 @@ from . import tensor
from . import schedule from . import schedule
from . import expr from . import expr
from . import ir_pass from . import ir_pass
from . import collections from . import container
from . import module from . import module
from . import codegen from . import codegen
from . import ndarray from . import ndarray
...@@ -276,19 +276,19 @@ def build(sch, ...@@ -276,19 +276,19 @@ def build(sch,
flist = lower(sch, args, flist = lower(sch, args,
name=name, name=name,
binds=binds) binds=binds)
if isinstance(flist, collections.LoweredFunc): if isinstance(flist, container.LoweredFunc):
flist = [flist] flist = [flist]
elif isinstance(sch, collections.LoweredFunc): elif isinstance(sch, container.LoweredFunc):
if args: if args:
raise ValueError("args must be done when build from LoweredFunc") raise ValueError("args must be done when build from LoweredFunc")
flist = [sch] flist = [sch]
elif isinstance(sch, (list, tuple, collections.Array)): elif isinstance(sch, (list, tuple, container.Array)):
flist = sch flist = sch
else: else:
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc") raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
fname_set = set() fname_set = set()
for x in flist: for x in flist:
if not isinstance(x, collections.LoweredFunc): if not isinstance(x, container.LoweredFunc):
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc") raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
if x.name in fname_set: if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name) raise ValueError("Duplicate function name %s" % x.name)
...@@ -296,7 +296,7 @@ def build(sch, ...@@ -296,7 +296,7 @@ def build(sch,
fhost = [] fhost = []
fdevice = [] fdevice = []
for func in flist: for func in flist:
if func.func_type == collections.LoweredFunc.MixedFunc: if func.func_type == container.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier: if BuildConfig.current.detect_global_barrier:
func = ir_pass.StorageSync(func, "global") func = ir_pass.StorageSync(func, "global")
func = ir_pass.StorageSync(func, "shared") func = ir_pass.StorageSync(func, "shared")
...@@ -306,9 +306,9 @@ def build(sch, ...@@ -306,9 +306,9 @@ def build(sch,
fhost.append(fsplits[0]) fhost.append(fsplits[0])
for x in fsplits[1:]: for x in fsplits[1:]:
fdevice.append(x) fdevice.append(x)
elif func.func_type == collections.LoweredFunc.HostFunc: elif func.func_type == container.LoweredFunc.HostFunc:
fhost.append(func) fhost.append(func)
elif func.func_type == collections.LoweredFunc.DeviceFunc: elif func.func_type == container.LoweredFunc.DeviceFunc:
fdevice.append(func) fdevice.append(func)
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
......
"""Collections contains data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
from . import _api_internal from . import _api_internal
......
...@@ -6,7 +6,7 @@ from . import stmt as _stmt ...@@ -6,7 +6,7 @@ from . import stmt as _stmt
from . import expr as _expr from . import expr as _expr
from . import make as _make from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import collections as _collections from . import container as _container
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeGeneric from ._ffi.node import NodeGeneric
from .expr import Call as _Call from .expr import Call as _Call
...@@ -280,7 +280,7 @@ class IRBuilder(object): ...@@ -280,7 +280,7 @@ class IRBuilder(object):
The buffer var representing the buffer. The buffer var representing the buffer.
""" """
buffer_var = _api.var(name, dtype="handle") buffer_var = _api.var(name, dtype="handle")
if not isinstance(shape, (list, tuple, _collections.Array)): if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape] shape = [shape]
if scope: if scope:
self.scope_attr(buffer_var, "storage_scope", scope) self.scope_attr(buffer_var, "storage_scope", scope)
......
"""Container of compiled functions of TVM.""" """Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .contrib import cc_compiler as _cc, util as _util from .contrib import cc_compiler as _cc, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"])
class Module(ModuleBase): class Module(ModuleBase):
"""Module container of all TVM generated functions""" """Module container of all TVM generated functions"""
def __repr__(self): def __repr__(self):
...@@ -120,8 +124,14 @@ class Module(ModuleBase): ...@@ -120,8 +124,14 @@ class Module(ModuleBase):
and return a float representing seconds per function call. and return a float representing seconds per function call.
""" """
try: try:
return _RPCTimeEvaluator( feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number) self, func_name, ctx.device_type, ctx.device_id, number)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
mean = feval(*args)
return ProfileResult(mean=mean)
return evaluator
except NameError: except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled") raise NameError("time_evaluate is only supported when RPC is enabled")
......
...@@ -4,7 +4,7 @@ from ._ffi.node import NodeBase, register_node ...@@ -4,7 +4,7 @@ from ._ffi.node import NodeBase, register_node
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import expr as _expr from . import expr as _expr
from . import collections as _collections from . import container as _container
from ._ffi.function import _init_api from ._ffi.function import _init_api
@register_node @register_node
...@@ -74,7 +74,7 @@ def create_schedule(ops): ...@@ -74,7 +74,7 @@ def create_schedule(ops):
sch : schedule.Schedule sch : schedule.Schedule
The created schedule. The created schedule.
""" """
if not isinstance(ops, (list, _collections.Array)): if not isinstance(ops, (list, _container.Array)):
ops = [ops] ops = [ops]
return _api_internal._CreateSchedule(ops) return _api_internal._CreateSchedule(ops)
......
...@@ -13,16 +13,28 @@ namespace runtime { ...@@ -13,16 +13,28 @@ namespace runtime {
// Wrapped remote function to packed func. // Wrapped remote function to packed func.
struct RPCWrappedFunc { struct RPCWrappedFunc {
public: public:
RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess) RPCWrappedFunc(void* handle,
: handle_(handle), sess_(sess) {} std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {
fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv);
});
}
void operator()(TVMArgs args, TVMRetValue *rv) const { void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv); sess_->CallFunc(handle_, args, rv, &fwrap_);
} }
~RPCWrappedFunc() { ~RPCWrappedFunc() {
sess_->CallRemote(RPCCode::kFreeFunc, handle_); sess_->CallRemote(RPCCode::kFreeFunc, handle_);
} }
static void WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle,
int tcode,
TVMRetValue* rv);
private: private:
PackedFunc fwrap_;
void* handle_{nullptr}; void* handle_{nullptr};
std::shared_ptr<RPCSession> sess_; std::shared_ptr<RPCSession> sess_;
}; };
...@@ -94,8 +106,28 @@ class RPCModuleNode final : public ModuleNode { ...@@ -94,8 +106,28 @@ class RPCModuleNode final : public ModuleNode {
void* module_handle_{nullptr}; void* module_handle_{nullptr};
// The local channel // The local channel
std::shared_ptr<RPCSession> sess_; std::shared_ptr<RPCSession> sess_;
// Wrap function to wrap remote module/function.
PackedFunc fwrap_;
}; };
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle,
int tcode,
TVMRetValue *rv) {
if (handle == nullptr) return;
if (tcode == kFuncHandle) {
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
*rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
} else {
CHECK_EQ(tcode, kModuleHandle);
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(handle, sess);
*rv = Module(n);
}
}
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) { Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n = std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, sess); std::make_shared<RPCModuleNode>(nullptr, sess);
......
...@@ -69,7 +69,10 @@ class RPCSession::EventHandler { ...@@ -69,7 +69,10 @@ class RPCSession::EventHandler {
void FinishCopyAck() { void FinishCopyAck() {
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
} }
RPCCode HandleNextEvent(TVMRetValue* rv) { RPCCode HandleNextEvent(TVMRetValue* rv,
bool client_mode,
const PackedFunc* fwrap) {
std::swap(client_mode_, client_mode);
while (this->Ready()) { while (this->Ready()) {
switch (state_) { switch (state_) {
case kRecvCode: HandleRecvCode(); break; case kRecvCode: HandleRecvCode(); break;
...@@ -110,19 +113,29 @@ class RPCSession::EventHandler { ...@@ -110,19 +113,29 @@ class RPCSession::EventHandler {
case kReturnReceived: { case kReturnReceived: {
CHECK_EQ(arg_buf_->value.size(), 1U); CHECK_EQ(arg_buf_->value.size(), 1U);
TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
if (argv.type_code() == kFuncHandle ||
argv.type_code() == kModuleHandle) {
CHECK(fwrap != nullptr) << "function/module wrapper not available";
fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
} else {
*rv = argv; *rv = argv;
}
arg_buf_.reset(); arg_buf_.reset();
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
std::swap(client_mode_, client_mode);
return RPCCode::kReturn; return RPCCode::kReturn;
} }
case kCopyAckReceived: { case kCopyAckReceived: {
std::swap(client_mode_, client_mode);
return RPCCode::kCopyAck; return RPCCode::kCopyAck;
} }
case kShutdownReceived: { case kShutdownReceived: {
std::swap(client_mode_, client_mode);
return RPCCode::kShutdown; return RPCCode::kShutdown;
} }
} }
} }
std::swap(client_mode_, client_mode);
return RPCCode::kNone; return RPCCode::kNone;
} }
// Reset and clear all states. // Reset and clear all states.
...@@ -161,6 +174,8 @@ class RPCSession::EventHandler { ...@@ -161,6 +174,8 @@ class RPCSession::EventHandler {
writer_->Write(&value, sizeof(TVMValue)); writer_->Write(&value, sizeof(TVMValue));
break; break;
} }
case kFuncHandle:
case kModuleHandle:
case kHandle: { case kHandle: {
// always send handle in 64 bit. // always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
...@@ -231,6 +246,8 @@ class RPCSession::EventHandler { ...@@ -231,6 +246,8 @@ class RPCSession::EventHandler {
int arg_index_; int arg_index_;
// The stage of each argument receiver. // The stage of each argument receiver.
int arg_recv_stage_; int arg_recv_stage_;
// Whether current handler is client or server mode.
bool client_mode_{false};
// Argument buffer // Argument buffer
std::unique_ptr<RPCArgBuffer> arg_buf_; std::unique_ptr<RPCArgBuffer> arg_buf_;
// Temp byte buffer. // Temp byte buffer.
...@@ -305,7 +322,15 @@ class RPCSession::EventHandler { ...@@ -305,7 +322,15 @@ class RPCSession::EventHandler {
case kHandle: case kHandle:
case kStr: case kStr:
case kBytes: case kBytes:
case kTVMContext: this->RequestBytes(sizeof(TVMValue)); break; case kTVMContext: {
this->RequestBytes(sizeof(TVMValue)); break;
}
case kFuncHandle:
case kModuleHandle: {
CHECK(client_mode_)
<< "Only client can receive remote functions";
this->RequestBytes(sizeof(TVMValue)); break;
}
case kNull: break; case kNull: break;
case kArrayHandle: { case kArrayHandle: {
this->RequestBytes(sizeof(uint64_t)); this->RequestBytes(sizeof(uint64_t));
...@@ -337,6 +362,8 @@ class RPCSession::EventHandler { ...@@ -337,6 +362,8 @@ class RPCSession::EventHandler {
this->SwitchToState(kRecvPackedSeqArg); this->SwitchToState(kRecvPackedSeqArg);
break; break;
} }
case kFuncHandle:
case kModuleHandle:
case kHandle: { case kHandle: {
// always send handle in 64 bit. // always send handle in 64 bit.
uint64_t handle; uint64_t handle;
...@@ -558,6 +585,13 @@ class RPCSession::EventHandler { ...@@ -558,6 +585,13 @@ class RPCSession::EventHandler {
ret_value.v_handle = &arr; ret_value.v_handle = &arr;
ret_tcode = kBytes; ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1);
} else if (rv.type_code() == kFuncHandle ||
rv.type_code() == kModuleHandle) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else { } else {
ret_value = rv.value(); ret_value = rv.value();
ret_tcode = rv.type_code(); ret_tcode = rv.type_code();
...@@ -634,7 +668,8 @@ struct RPCSessTable { ...@@ -634,7 +668,8 @@ struct RPCSessTable {
std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_; std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
}; };
RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) { RPCCode RPCSession::HandleUntilReturnEvent(
TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) {
RPCCode code = RPCCode::kCallFunc; RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn && while (code != RPCCode::kReturn &&
code != RPCCode::kShutdown && code != RPCCode::kShutdown &&
...@@ -657,7 +692,7 @@ RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) { ...@@ -657,7 +692,7 @@ RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) {
} }
} }
} }
code = handler_->HandleNextEvent(rv); code = handler_->HandleNextEvent(rv, client_mode, fwrap);
} }
return code; return code;
} }
...@@ -668,7 +703,7 @@ void RPCSession::Init() { ...@@ -668,7 +703,7 @@ void RPCSession::Init() {
// Quick function to call remote. // Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
RPCCode code = HandleUntilReturnEvent(rv); RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}); });
} }
...@@ -712,7 +747,7 @@ void RPCSession::Shutdown() { ...@@ -712,7 +747,7 @@ void RPCSession::Shutdown() {
void RPCSession::ServerLoop() { void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kShutdown); CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
LOG(INFO) << "Shutdown..."; LOG(INFO) << "Shutdown...";
channel_.reset(nullptr); channel_.reset(nullptr);
} }
...@@ -722,7 +757,7 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) { ...@@ -722,7 +757,7 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
reader_.Write(bytes.c_str(), bytes.length()); reader_.Write(bytes.c_str(), bytes.length());
TVMRetValue rv; TVMRetValue rv;
RPCCode code = handler_->HandleNextEvent(&rv); RPCCode code = handler_->HandleNextEvent(&rv, false, nullptr);
while (writer_.bytes_available() != 0) { while (writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) { writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size); return channel_->Send(data, size);
...@@ -733,13 +768,18 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) { ...@@ -733,13 +768,18 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
} }
// Get remote function with name // Get remote function with name
void RPCSession::CallFunc(void* h, TVMArgs args, TVMRetValue* rv) { void RPCSession::CallFunc(void* h,
TVMArgs args,
TVMRetValue* rv,
const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc; RPCCode code = RPCCode::kCallFunc;
writer_.Write(&code, sizeof(code)); writer_.Write(&code, sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(h); uint64_t handle = reinterpret_cast<uint64_t>(h);
writer_.Write(&handle, sizeof(handle)); writer_.Write(&handle, sizeof(handle));
call_remote_.CallPacked(args, rv); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
} }
void RPCSession::CopyToRemote(void* from, void RPCSession::CopyToRemote(void* from,
...@@ -761,7 +801,7 @@ void RPCSession::CopyToRemote(void* from, ...@@ -761,7 +801,7 @@ void RPCSession::CopyToRemote(void* from,
writer_.Write(&ctx_to, sizeof(ctx_to)); writer_.Write(&ctx_to, sizeof(ctx_to));
writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size); writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kReturn); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
} }
void RPCSession::CopyFromRemote(void* from, void RPCSession::CopyFromRemote(void* from,
...@@ -782,7 +822,7 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -782,7 +822,7 @@ void RPCSession::CopyFromRemote(void* from,
writer_.Write(&size, sizeof(size)); writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_from, sizeof(ctx_from)); writer_.Write(&ctx_from, sizeof(ctx_from));
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kCopyAck); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
reader_.Reserve(data_size); reader_.Reserve(data_size);
while (reader_.bytes_available() < data_size) { while (reader_.bytes_available() < data_size) {
size_t bytes_needed = data_size - reader_.bytes_available(); size_t bytes_needed = data_size - reader_.bytes_available();
......
...@@ -95,13 +95,16 @@ class RPCSession { ...@@ -95,13 +95,16 @@ class RPCSession {
bool ServerOnMessageHandler(const std::string& bytes); bool ServerOnMessageHandler(const std::string& bytes);
/*! /*!
* \brief Call into remote function * \brief Call into remote function
* \param sptr_to_self shared_ptr to self.
* \param handle The function handle * \param handle The function handle
* \param args The arguments * \param args The arguments
* \param rv The return value. * \param rv The return value.
* \param fwrapper Wrapper function to turn Function/Module handle into real return.
*/ */
void CallFunc(RPCFuncHandle handle, void CallFunc(RPCFuncHandle handle,
TVMArgs args, TVMArgs args,
TVMRetValue* rv); TVMRetValue* rv,
const PackedFunc* fwrap);
/*! /*!
* \brief Copy bytes into remote array content. * \brief Copy bytes into remote array content.
* \param from The source host data. * \param from The source host data.
...@@ -146,6 +149,7 @@ class RPCSession { ...@@ -146,6 +149,7 @@ class RPCSession {
int nstep); int nstep);
/*! /*!
* \brief Call a remote defined system function with arguments. * \brief Call a remote defined system function with arguments.
* \param sptr_to_self shared_ptr to self.
* \param fcode The function code. * \param fcode The function code.
* \param args The arguments * \param args The arguments
* \return The returned remote value. * \return The returned remote value.
...@@ -178,7 +182,8 @@ class RPCSession { ...@@ -178,7 +182,8 @@ class RPCSession {
class EventHandler; class EventHandler;
// Handle events until receives a return // Handle events until receives a return
// Also flushes channels so that the function advances. // Also flushes channels so that the function advances.
RPCCode HandleUntilReturnEvent(TVMRetValue* rv); RPCCode HandleUntilReturnEvent(
TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
// Initalization // Initalization
void Init(); void Init();
// Shutdown // Shutdown
...@@ -191,7 +196,7 @@ class RPCSession { ...@@ -191,7 +196,7 @@ class RPCSession {
common::RingBuffer reader_, writer_; common::RingBuffer reader_, writer_;
// Event handler. // Event handler.
std::shared_ptr<EventHandler> handler_; std::shared_ptr<EventHandler> handler_;
// call remote with the specified function coede. // call remote with specified function code.
PackedFunc call_remote_; PackedFunc call_remote_;
// The index of this session in RPC session table. // The index of this session in RPC session table.
int table_index_{0}; int table_index_{0};
......
...@@ -95,8 +95,8 @@ def test_add(): ...@@ -95,8 +95,8 @@ def test_add():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
vbias = np.random.uniform() vbias = np.random.uniform()
vscale = np.random.uniform() vscale = np.random.uniform()
ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000) ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=10)
tcost = ftimer(a, b, c, vbias, vscale) tcost = ftimer(a, b, c, vbias, vscale).mean
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6) c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
......
...@@ -78,8 +78,8 @@ def test_gemm(): ...@@ -78,8 +78,8 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
ftimer = f.time_evaluator(f.entry_name, ctx, number=20) ftimer = f.time_evaluator(f.entry_name, ctx, number=1)
tcost = ftimer(a, b, c) tcost = ftimer(a, b, c).mean
print("%s: exec=%g sec/op" % (ctx, tcost)) print("%s: exec=%g sec/op" % (ctx, tcost))
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
......
...@@ -11,7 +11,7 @@ def test_flatten2(): ...@@ -11,7 +11,7 @@ def test_flatten2():
xo, xi = s[A2].split(A2.op.axis[0], 8) xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt) print(stmt)
......
...@@ -11,7 +11,7 @@ def test_storage_share(): ...@@ -11,7 +11,7 @@ def test_storage_share():
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
...@@ -47,7 +47,7 @@ def test_storage_share_gpu(): ...@@ -47,7 +47,7 @@ def test_storage_share_gpu():
s[A[2*t+1]].set_scope("shared") s[A[2*t+1]].set_scope("shared")
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A') Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B') Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
......
...@@ -15,7 +15,7 @@ def test_storage_sync(): ...@@ -15,7 +15,7 @@ def test_storage_sync():
s[A1].set_scope("shared") s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2') A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
......
...@@ -14,7 +14,7 @@ def test_virtual_thread(): ...@@ -14,7 +14,7 @@ def test_virtual_thread():
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
......
...@@ -11,7 +11,7 @@ def test_rpc_simple(): ...@@ -11,7 +11,7 @@ def test_rpc_simple():
def addone(x): def addone(x):
return x + 1 return x + 1
@tvm.register_func("rpc.test.strcat") @tvm.register_func("rpc.test.strcat")
def addone(name, x): def strcat(name, x):
return "%s:%d" % (name, x) return "%s:%d" % (name, x)
@tvm.register_func("rpc.test.except") @tvm.register_func("rpc.test.except")
...@@ -83,16 +83,26 @@ def test_rpc_remote_module(): ...@@ -83,16 +83,26 @@ def test_rpc_remote_module():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b) cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote() check_remote()
def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func")
def addone(x):
return lambda y: x+y
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func")
fadd = f1(10)
assert fadd(12) == 22
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_rpc_return_func()
test_rpc_file_exchange() test_rpc_file_exchange()
exit(0)
test_rpc_array() test_rpc_array()
test_rpc_remote_module() test_rpc_remote_module()
test_rpc_simple() test_rpc_simple()
...@@ -11,7 +11,7 @@ def test_bound1(): ...@@ -11,7 +11,7 @@ def test_bound1():
xo, xi = s[A2].split(s[A2].op.axis[0], 8) xo, xi = s[A2].split(s[A2].op.axis[0], 8)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
def test_bound2(): def test_bound2():
...@@ -26,7 +26,7 @@ def test_bound2(): ...@@ -26,7 +26,7 @@ def test_bound2():
_ = s.normalize() _ = s.normalize()
s[A1].compute_at(s[A2], yo) s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
assert(bounds[A1.op.axis[1]].extent.value == 8) assert(bounds[A1.op.axis[1]].extent.value == 8)
...@@ -49,7 +49,7 @@ def test_bound3(): ...@@ -49,7 +49,7 @@ def test_bound3():
s[A1].compute_at(s[A2], yo) s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16) assert(bounds[A1.op.axis[1]].extent.value==16)
......
...@@ -9,7 +9,7 @@ def test_schedule0(): ...@@ -9,7 +9,7 @@ def test_schedule0():
s = tvm.create_schedule(A1.op) s = tvm.create_schedule(A1.op)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule1(): def test_schedule1():
...@@ -21,7 +21,7 @@ def test_schedule1(): ...@@ -21,7 +21,7 @@ def test_schedule1():
s = tvm.create_schedule(A1.op) s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8) xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
...@@ -36,7 +36,7 @@ def test_schedule2(): ...@@ -36,7 +36,7 @@ def test_schedule2():
xo, xi = s[A2].split(A2.op.axis[0], 8) xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
...@@ -79,7 +79,7 @@ def test_schedule_const_bound(): ...@@ -79,7 +79,7 @@ def test_schedule_const_bound():
s = tvm.create_schedule(A1.op) s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8) xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
export PYTHONPATH=python export PYTHONPATH=python
rm -rf python/tvm/*.pyc
TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1
make cython || exit -1 make cython || exit -1
......
...@@ -40,7 +40,7 @@ def test_rpc_array(): ...@@ -40,7 +40,7 @@ def test_rpc_array():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b) cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote() check_remote()
......
...@@ -115,13 +115,13 @@ def test_depthwise_conv2d_map(): ...@@ -115,13 +115,13 @@ def test_depthwise_conv2d_map():
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# Measure time cost of kernel 1 (depthwise_conv2d) # Measure time cost of kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000) timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm) tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# Measure time cost of kernel 2 (depthwise_conv2d + scale_shift) # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000) timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm) tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu) # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000) timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm) tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
print("Input shape = " + str(get_const_tuple(Input.shape))) print("Input shape = " + str(get_const_tuple(Input.shape)))
print("Filter shape = " + str(get_const_tuple(Filter.shape))) print("Filter shape = " + str(get_const_tuple(Filter.shape)))
print("Stride = (%d, %d)" % (stride_h, stride_w)) print("Stride = (%d, %d)" % (stride_h, stride_w))
......
...@@ -79,13 +79,13 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_mul ...@@ -79,13 +79,13 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_mul
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# launch kernel 1 (depthwise_conv2d) # launch kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm) tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# launch kernel 2 (depthwise_conv2d + scale_shift) # launch kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm) tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# launch kernel 3 (depthwise_conv2d + scale_shift + relu) # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm) tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy # correctness with scipy
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np) depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
......
...@@ -151,7 +151,7 @@ np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) ...@@ -151,7 +151,7 @@ np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
# device and returns the measured cost. # device and returns the measured cost.
# #
time_f = f.time_evaluator(f.entry_name, ctx, number=10) time_f = f.time_evaluator(f.entry_name, ctx, number=10)
cost = time_f(a, b) cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
# terminate the server after experiment # terminate the server after experiment
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment