Unverified Commit 279a8eba by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Update RPC runtime to allow remote module as arg (#4462)

parent 77bdd5f7
...@@ -23,7 +23,6 @@ from tvm._ffi.base import string_types ...@@ -23,7 +23,6 @@ from tvm._ffi.base import string_types
from tvm._ffi.function import get_global_func from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.ndarray import array from tvm.ndarray import array
from tvm.rpc import base as rpc_base
from . import debug_result from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_ROOT_PREFIX = "tvmdbg_"
...@@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None): ...@@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
except AttributeError: except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str)) raise ValueError("Type %s is not supported" % type(graph_json_str))
try: try:
fcreate = get_global_func("tvm.graph_runtime_debug.create") ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create")
else:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
except ValueError: except ValueError:
raise ValueError( raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode" "config.cmake and rebuild TVM to enable debug mode"
) )
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
libmod = rpc_base._ModuleHandle(libmod)
try:
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.remote_create"
)
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj = fcreate(graph_json_str, libmod, *device_type_id) func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root) return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
......
...@@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx): ...@@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx): if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod) fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create") else:
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id)) fcreate = get_global_func("tvm.graph_runtime.create")
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
def get_device_ctx(libmod, ctx): def get_device_ctx(libmod, ctx):
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <chrono> #include <chrono>
#include <sstream> #include <sstream>
#include "../graph_runtime.h" #include "../graph_runtime.h"
#include "../../object_internal.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") ...@@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<< args.num_args; << args.num_args;
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
}); });
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate(
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,6 @@ ...@@ -36,7 +36,6 @@
#include <vector> #include <vector>
#include "graph_runtime.h" #include "graph_runtime.h"
#include "../object_internal.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") ...@@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
const auto& contexts = GetAllContext(args); const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(args[0], args[1], contexts); *rv = GraphRuntimeCreate(args[0], args[1], contexts);
}); });
TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -41,7 +41,7 @@ class RPCWrappedFunc { ...@@ -41,7 +41,7 @@ class RPCWrappedFunc {
} }
void operator()(TVMArgs args, TVMRetValue *rv) const { void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv, &fwrap_); sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_);
} }
~RPCWrappedFunc() { ~RPCWrappedFunc() {
try { try {
...@@ -55,6 +55,9 @@ class RPCWrappedFunc { ...@@ -55,6 +55,9 @@ class RPCWrappedFunc {
TVMArgs args, TVMArgs args,
TVMRetValue* rv); TVMRetValue* rv);
static void* UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg);
// deleter of RPC remote array // deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) { static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data); RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
...@@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode { ...@@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc fwrap_; PackedFunc fwrap_;
}; };
void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg) {
if (arg.type_code() == kModuleHandle) {
Module mod = arg;
std::string tkey = mod->type_key();
CHECK_EQ(tkey, "rpc")
<< "ValueError: Cannot pass a non-RPC module to remote";
auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index)
<< "ValueError: Cannot pass in module into a different remote session";
return rmod->module_handle();
} else {
LOG(FATAL) << "ValueError: Cannot pass type "
<< runtime::TypeCode2Str(arg.type_code())
<< " as an argument to the remote";
return nullptr;
}
}
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
TVMArgs args, TVMArgs args,
TVMRetValue *rv) { TVMRetValue *rv) {
......
...@@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
return ctx; return ctx;
} }
// Send Packed sequence to writer. // Send Packed sequence to writer.
//
// client_mode: whether we are in client mode.
//
// funwrap: auxiliary function to unwrap remote Object
// when it is provided, we need to unwrap objects.
//
// return_ndarray is a special flag to handle returning of ndarray // return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array, // In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of // as well as a customized PackedFunc that handles deletion of
// the array in the remote. // the array in the remote.
void SendPackedSeq(const TVMValue* arg_values, void SendPackedSeq(const TVMValue* arg_values,
const int* type_codes, const int* type_codes,
int n, int num_args,
bool client_mode,
FUnwrapRemoteObject funwrap = nullptr,
bool return_ndarray = false) { bool return_ndarray = false) {
this->Write(n); std::swap(client_mode_, client_mode);
for (int i = 0; i < n; ++i) {
this->Write(num_args);
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i]; int tcode = type_codes[i];
if (tcode == kNDArrayContainer) tcode = kArrayHandle; if (tcode == kNDArrayContainer) tcode = kArrayHandle;
this->Write(tcode); this->Write(tcode);
} }
// Argument packing. // Argument packing.
for (int i = 0; i < n; ++i) { for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i]; int tcode = type_codes[i];
TVMValue value = arg_values[i]; TVMValue value = arg_values[i];
switch (tcode) { switch (tcode) {
...@@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
break; break;
} }
case kFuncHandle: case kFuncHandle:
case kModuleHandle: case kModuleHandle: {
// always send handle in 64 bit.
uint64_t handle;
// allow pass module as argument to remote.
if (funwrap != nullptr) {
void* remote_handle = (*funwrap)(
rpc_sess_table_index_,
runtime::TVMArgValue(value, tcode));
handle = reinterpret_cast<uint64_t>(remote_handle);
} else {
CHECK(!client_mode_)
<< "Cannot directly pass remote object as argument";
handle = reinterpret_cast<uint64_t>(value.v_handle);
}
this->Write(handle);
break;
}
case kHandle: { case kHandle: {
// always send handle in 64 bit. // always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
...@@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
} }
} }
} }
std::swap(client_mode_, client_mode);
} }
// Endian aware IO handling // Endian aware IO handling
...@@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kHandle: case kHandle:
case kStr: case kStr:
case kBytes: case kBytes:
case kModuleHandle:
case kTVMContext: { case kTVMContext: {
this->RequestBytes(sizeof(TVMValue)); break; this->RequestBytes(sizeof(TVMValue)); break;
} }
case kFuncHandle: case kFuncHandle: {
case kModuleHandle: {
CHECK(client_mode_) CHECK(client_mode_)
<< "Only client can receive remote functions"; << "Only client can receive remote functions";
this->RequestBytes(sizeof(TVMValue)); break; this->RequestBytes(sizeof(TVMValue)); break;
...@@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value; TVMValue ret_value;
ret_value.v_str = e.what(); ret_value.v_str = e.what();
int ret_tcode = kStr; int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} }
} }
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
...@@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
} }
} }
this->Write(code); this->Write(code);
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
arg_recv_stage_ = 0; arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
} }
...@@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
if (rv.type_code() == kStr) { if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str(); ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr; ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kBytes) { } else if (rv.type_code() == kBytes) {
std::string* bytes = rv.ptr<std::string>(); std::string* bytes = rv.ptr<std::string>();
TVMByteArray arr; TVMByteArray arr;
...@@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream {
arr.size = bytes->length(); arr.size = bytes->length();
ret_value.v_handle = &arr; ret_value.v_handle = &arr;
ret_tcode = kBytes; ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kFuncHandle || } else if (rv.type_code() == kFuncHandle ||
rv.type_code() == kModuleHandle) { rv.type_code() == kModuleHandle) {
// always send handle in 64 bit. // always send handle in 64 bit.
CHECK(!client_mode_) CHECK(!client_mode_)
<< "Only server can send function and module handle back."; << "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode); rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kNDArrayContainer) { } else if (rv.type_code() == kNDArrayContainer) {
// always send handle in 64 bit. // always send handle in 64 bit.
CHECK(!client_mode_) CHECK(!client_mode_)
...@@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle); NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd; ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle; ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true); SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else { } else {
ret_value = rv.value(); ret_value = rv.value();
ret_tcode = rv.type_code(); ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} }
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException; RPCCode code = RPCCode::kException;
this->Write(code); this->Write(code);
ret_value.v_str = e.what(); ret_value.v_str = e.what();
ret_tcode = kStr; ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} }
} }
...@@ -873,7 +900,7 @@ void RPCSession::Init() { ...@@ -873,7 +900,7 @@ void RPCSession::Init() {
&reader_, &writer_, table_index_, name_, &remote_key_); &reader_, &writer_, table_index_, name_, &remote_key_);
// Quick function to call remote. // Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}); });
...@@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { ...@@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
void RPCSession::CallFunc(void* h, void RPCSession::CallFunc(void* h,
TVMArgs args, TVMArgs args,
TVMRetValue* rv, TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap) { const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc; RPCCode code = RPCCode::kCallFunc;
handler_->Write(code); handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(h); uint64_t handle = reinterpret_cast<uint64_t>(h);
handler_->Write(handle); handler_->Write(handle);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); handler_->SendPackedSeq(
args.values, args.type_codes, args.num_args, true, funwrap);
code = HandleUntilReturnEvent(rv, true, fwrap); code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
} }
......
...@@ -91,6 +91,16 @@ enum class RPCCode : int { ...@@ -91,6 +91,16 @@ enum class RPCCode : int {
}; };
/*! /*!
* \brief Function that unwraps a remote object to its handle.
* \param rpc_sess_table_index RPC session table index for validation.
* \param obj Handle to the object argument.
* \return The corresponding handle.
*/
typedef void* (*FUnwrapRemoteObject)(
int rpc_sess_table_index,
const TVMArgValue& obj);
/*!
* \brief Abstract channel interface used to create RPCSession. * \brief Abstract channel interface used to create RPCSession.
*/ */
class RPCChannel { class RPCChannel {
...@@ -144,11 +154,13 @@ class RPCSession { ...@@ -144,11 +154,13 @@ class RPCSession {
* \param handle The function handle * \param handle The function handle
* \param args The arguments * \param args The arguments
* \param rv The return value. * \param rv The return value.
* \param funpwrap Function that takes a remote object and returns the raw handle.
* \param fwrap Wrapper function to turn Function/Module handle into real return. * \param fwrap Wrapper function to turn Function/Module handle into real return.
*/ */
void CallFunc(RPCFuncHandle handle, void CallFunc(RPCFuncHandle handle,
TVMArgs args, TVMArgs args,
TVMRetValue* rv, TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap); const PackedFunc* fwrap);
/*! /*!
* \brief Copy bytes into remote array content. * \brief Copy bytes into remote array content.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment