Unverified Commit 11dd933f by Tianqi Chen Committed by GitHub

[RUNTIME] Enable return NDArray in RPC (#1610)

parent 9bcc3173
...@@ -246,6 +246,7 @@ struct NDArray::Container { ...@@ -246,6 +246,7 @@ struct NDArray::Container {
private: private:
friend class NDArray; friend class NDArray;
friend class RPCWrappedFunc;
/*! /*!
* \brief The shape container, * \brief The shape container,
* can be used used for shape data. * can be used used for shape data.
......
...@@ -37,6 +37,14 @@ TVM_REGISTER_API("_nop") ...@@ -37,6 +37,14 @@ TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
}); });
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
*ret = (nd.use_count() - 1);
});
TVM_REGISTER_API("_TVMSetStream") TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]); TVMSetStream(args[0], args[1], args[2]);
......
...@@ -12,13 +12,13 @@ namespace tvm { ...@@ -12,13 +12,13 @@ namespace tvm {
namespace runtime { namespace runtime {
// Wrapped remote function to packed func. // Wrapped remote function to packed func.
struct RPCWrappedFunc { class RPCWrappedFunc {
public: public:
RPCWrappedFunc(void* handle, RPCWrappedFunc(void* handle,
std::shared_ptr<RPCSession> sess) std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) { : handle_(handle), sess_(sess) {
fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv); WrapRemote(sess, args, rv);
}); });
} }
...@@ -34,10 +34,47 @@ struct RPCWrappedFunc { ...@@ -34,10 +34,47 @@ struct RPCWrappedFunc {
} }
static void WrapRemote(std::shared_ptr<RPCSession> sess, static void WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle, TVMArgs args,
int tcode,
TVMRetValue* rv); TVMRetValue* rv);
// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space;
delete ptr;
}
// wrap return value as remote NDArray.
static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
DLTensor* tensor,
void* nd_handle) {
NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter;
RemoteSpace* space = new RemoteSpace();
space->sess = sess;
space->data = tensor->data;
data->dl_tensor.data = space;
NDArray ret(data);
// RAII now in effect
data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup dtype
data->dl_tensor.dtype = tensor->dtype;
// setup ctx, encode as remote session
data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
static_cast<int>(tensor->ctx.device_type) +
kRPCSessMask * (sess->table_index() + 1));
// check strides.
CHECK(tensor->strides == nullptr);
// setup byteoffset
data->dl_tensor.byte_offset = tensor->byte_offset;
return ret;
}
private: private:
PackedFunc fwrap_; PackedFunc fwrap_;
void* handle_{nullptr}; void* handle_{nullptr};
...@@ -126,20 +163,28 @@ class RPCModuleNode final : public ModuleNode { ...@@ -126,20 +163,28 @@ class RPCModuleNode final : public ModuleNode {
}; };
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle, TVMArgs args,
int tcode,
TVMRetValue *rv) { TVMRetValue *rv) {
void* handle = args.values[0].v_handle;
int tcode = args.type_codes[0];
if (handle == nullptr) return; if (handle == nullptr) return;
if (tcode == kFuncHandle) { if (tcode == kFuncHandle) {
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess); auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
*rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv); return wf->operator()(args, rv);
}); });
} else { } else if (tcode == kModuleHandle) {
CHECK_EQ(tcode, kModuleHandle);
std::shared_ptr<RPCModuleNode> n = std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(handle, sess); std::make_shared<RPCModuleNode>(handle, sess);
*rv = Module(n); *rv = Module(n);
} else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
CHECK_EQ(args.size(), 2);
DLTensor* tensor = args[0];
void* nd_handle = args[1];
*rv = WrapRemoteNDArray(sess, tensor, nd_handle);
} else {
LOG(FATAL) << "Cannot wrap tcode=" << tcode;
} }
} }
......
...@@ -130,13 +130,16 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -130,13 +130,16 @@ class RPCSession::EventHandler : public dmlc::Stream {
break; break;
} }
case kReturnReceived: { case kReturnReceived: {
CHECK_EQ(arg_buf_->value.size(), 1U); CHECK_GE(arg_buf_->value.size(), 1U);
TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
if (argv.type_code() == kFuncHandle || if (argv.type_code() == kFuncHandle ||
argv.type_code() == kModuleHandle) { argv.type_code() == kModuleHandle ||
argv.type_code() == kArrayHandle) {
CHECK(fwrap != nullptr) << "function/module wrapper not available"; CHECK(fwrap != nullptr) << "function/module wrapper not available";
fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
} else { } else {
CHECK_EQ(arg_buf_->value.size(), 1U);
*rv = argv; *rv = argv;
} }
arg_buf_.reset(); arg_buf_.reset();
...@@ -172,15 +175,22 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -172,15 +175,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask); ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
return ctx; return ctx;
} }
// send Packed sequence to writer. // Send Packed sequence to writer.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) { // return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void SendPackedSeq(const TVMValue* arg_values,
const int* type_codes,
int n,
bool return_ndarray = false) {
this->Write(n); this->Write(n);
// only handles .
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
int tcode = type_codes[i]; int tcode = type_codes[i];
if (tcode == kNDArrayContainer) tcode = kArrayHandle; if (tcode == kNDArrayContainer) tcode = kArrayHandle;
this->Write(tcode); this->Write(tcode);
} }
// Argument packing. // Argument packing.
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
int tcode = type_codes[i]; int tcode = type_codes[i];
...@@ -215,9 +225,23 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -215,9 +225,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kNDArrayContainer: case kNDArrayContainer:
case kArrayHandle: { case kArrayHandle: {
DLTensor* arr = static_cast<DLTensor*>(value.v_handle); DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
TVMContext ctx = StripSessMask(arr->ctx); TVMContext ctx;
uint64_t data = reinterpret_cast<uint64_t>( uint64_t data;
if (!return_ndarray) {
// in the client mode
// ctx contains the remote table index
// the space is wrapped by an RemoteSpace
// that holds reference to the session.
ctx = StripSessMask(arr->ctx);
data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data); static_cast<RemoteSpace*>(arr->data)->data);
} else {
// When we return NDArray, we directly return
// the space and the context
// The client will be further wrapping
ctx = arr->ctx;
data = reinterpret_cast<uint64_t>(arr->data);
}
this->Write(data); this->Write(data);
this->Write(ctx); this->Write(ctx);
this->Write(arr->ndim); this->Write(arr->ndim);
...@@ -701,6 +725,21 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -701,6 +725,21 @@ class RPCSession::EventHandler : public dmlc::Stream {
<< "Only server can send function and module handle back."; << "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode); rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1);
} else if (rv.type_code() == kNDArrayContainer) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send NDArray back";
// We follow a special protocol to return NDArray to client side
// The first pack value is the NDArray handle as DLTensor
// The second pack value is a customized deleter that deletes the NDArray.
TVMValue ret_value_pack[2];
int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
} else { } else {
ret_value = rv.value(); ret_value = rv.value();
ret_tcode = rv.type_code(); ret_tcode = rv.type_code();
...@@ -1090,6 +1129,11 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { ...@@ -1090,6 +1129,11 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt); *rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
} }
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0];
static_cast<NDArray::Container*>(handle)->DecRef();
}
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*()); PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3])); void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3]));
...@@ -1138,6 +1182,7 @@ void RPCSession::EventHandler::HandlePackedCall() { ...@@ -1138,6 +1182,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_); default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
} }
CHECK_EQ(state_, kRecvCode); CHECK_EQ(state_, kRecvCode);
......
...@@ -48,6 +48,7 @@ enum class RPCCode : int { ...@@ -48,6 +48,7 @@ enum class RPCCode : int {
kModuleFree, kModuleFree,
kModuleGetFunc, kModuleGetFunc,
kModuleGetSource, kModuleGetSource,
kNDArrayFree
}; };
/*! /*!
......
...@@ -175,6 +175,7 @@ def test_rpc_return_func(): ...@@ -175,6 +175,7 @@ def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func") @tvm.register_func("rpc.test.remote_func")
def addone(x): def addone(x):
return lambda y: x+y return lambda y: x+y
server = rpc.Server("localhost", key="x1") server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1") client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func") f1 = client.get_function("rpc.test.remote_func")
...@@ -182,6 +183,46 @@ def test_rpc_return_func(): ...@@ -182,6 +183,46 @@ def test_rpc_return_func():
assert fadd(12) == 22 assert fadd(12) == 22
def test_rpc_return_ndarray():
# Use closure to check the ref counter correctness
nd = tvm.nd.array(np.zeros(10).astype("float32"))
@tvm.register_func("rpc.test.remote_return_nd")
def my_module(name):
if name == "get_arr":
return lambda : nd
elif name == "ref_count":
return lambda : tvm._api_internal._ndarray_use_count(nd)
elif name == "get_elem":
return lambda idx: nd.asnumpy()[idx]
elif name == "get_arr_elem":
return lambda arr, idx: arr.asnumpy()[idx]
# start server
server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
m = client.get_function("rpc.test.remote_return_nd")
get_arr = m("get_arr")
ref_count = m("ref_count")
get_elem = m("get_elem")
get_arr_elem = m("get_arr_elem")
# array test
def run_arr_test():
arr = get_arr()
assert ref_count() == 2
arr2 = get_arr()
assert ref_count() == 3
assert arr.context == client.cpu(0)
arr.copyfrom(np.ones(10).astype(arr.dtype))
assert arr2.asnumpy()[0] == 1.0
assert get_elem(0) == 1.0
assert get_arr_elem(arr2, 0) == 1.0
assert ref_count() == 1
run_arr_test()
# check recycle correctness
assert ref_count() == 1
def test_local_func(): def test_local_func():
@tvm.register_func("rpc.test.remote_func2") @tvm.register_func("rpc.test.remote_func2")
def addone(x): def addone(x):
...@@ -199,9 +240,10 @@ def test_local_func(): ...@@ -199,9 +240,10 @@ def test_local_func():
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_rpc_return_ndarray()
test_rpc_return_func()
test_bigendian_rpc() test_bigendian_rpc()
test_rpc_remote_module() test_rpc_remote_module()
test_rpc_return_func()
test_rpc_file_exchange() test_rpc_file_exchange()
test_rpc_array() test_rpc_array()
test_rpc_simple() test_rpc_simple()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment