Unverified Commit 11dd933f by Tianqi Chen Committed by GitHub

[RUNTIME] Enable return NDArray in RPC (#1610)

parent 9bcc3173
......@@ -246,6 +246,7 @@ struct NDArray::Container {
private:
friend class NDArray;
friend class RPCWrappedFunc;
/*!
* \brief The shape container,
 * can be used for shape data.
......
......@@ -37,6 +37,14 @@ TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
// internal function used for debug and testing purposes:
// returns the reference count of the NDArray passed as args[0]
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// subtract the reference held by the local copy `nd` itself,
// so the caller observes the count as it was before this call
*ret = (nd.use_count() - 1);
});
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
......
......@@ -12,13 +12,13 @@ namespace tvm {
namespace runtime {
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
class RPCWrappedFunc {
public:
RPCWrappedFunc(void* handle,
std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {
fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv);
WrapRemote(sess, args, rv);
});
}
......@@ -34,10 +34,47 @@ struct RPCWrappedFunc {
}
static void WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle,
int tcode,
TVMArgs args,
TVMRetValue* rv);
// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space;
delete ptr;
}
// wrap return value as remote NDArray.
static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
DLTensor* tensor,
void* nd_handle) {
NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter;
RemoteSpace* space = new RemoteSpace();
space->sess = sess;
space->data = tensor->data;
data->dl_tensor.data = space;
NDArray ret(data);
// RAII now in effect
data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup dtype
data->dl_tensor.dtype = tensor->dtype;
// setup ctx, encode as remote session
data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
static_cast<int>(tensor->ctx.device_type) +
kRPCSessMask * (sess->table_index() + 1));
// check strides.
CHECK(tensor->strides == nullptr);
// setup byteoffset
data->dl_tensor.byte_offset = tensor->byte_offset;
return ret;
}
private:
PackedFunc fwrap_;
void* handle_{nullptr};
......@@ -126,20 +163,28 @@ class RPCModuleNode final : public ModuleNode {
};
// Wrap a remote return value into the corresponding local proxy object.
// NOTE(review): the diff extraction left both the old parameter list
// (void* handle, int tcode) and the new one (TVMArgs args) in the text,
// as well as both the old `else` and new `else if (kModuleHandle)`
// branches; this body reconstructs the coherent post-change version.
// \param sess the RPC session the value came from.
// \param args packed return values: args[0] is the handle/tensor;
//        for NDArray returns args[1] carries the remote container handle.
// \param rv output: wrapped PackedFunc, Module, or NDArray.
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
                                TVMArgs args,
                                TVMRetValue *rv) {
  void* handle = args.values[0].v_handle;
  int tcode = args.type_codes[0];
  if (handle == nullptr) return;
  if (tcode == kFuncHandle) {
    // wrap the remote function handle as a locally callable PackedFunc;
    // the closure keeps the wrapper (and thus the session) alive
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  } else if (tcode == kModuleHandle) {
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(handle, sess);
    *rv = Module(n);
  } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
    // NDArray returns follow a two-value protocol: (DLTensor, nd handle)
    CHECK_EQ(args.size(), 2);
    DLTensor* tensor = args[0];
    void* nd_handle = args[1];
    *rv = WrapRemoteNDArray(sess, tensor, nd_handle);
  } else {
    LOG(FATAL) << "Cannot wrap tcode=" << tcode;
  }
}
......
......@@ -130,19 +130,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
break;
}
case kReturnReceived: {
CHECK_EQ(arg_buf_->value.size(), 1U);
CHECK_GE(arg_buf_->value.size(), 1U);
TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
if (argv.type_code() == kFuncHandle ||
argv.type_code() == kModuleHandle) {
argv.type_code() == kModuleHandle ||
argv.type_code() == kArrayHandle) {
CHECK(fwrap != nullptr) << "function/module wrapper not available";
fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
} else {
CHECK_EQ(arg_buf_->value.size(), 1U);
*rv = argv;
}
arg_buf_.reset();
this->SwitchToState(kRecvCode);
std::swap(client_mode_, client_mode);
return RPCCode::kReturn;
return RPCCode::kReturn;
}
case kCopyAckReceived: {
std::swap(client_mode_, client_mode);
......@@ -172,15 +175,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
return ctx;
}
// send Packed sequence to writer.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) {
// Send Packed sequence to writer.
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void SendPackedSeq(const TVMValue* arg_values,
const int* type_codes,
int n,
bool return_ndarray = false) {
this->Write(n);
// only handles .
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
if (tcode == kNDArrayContainer) tcode = kArrayHandle;
this->Write(tcode);
}
// Argument packing.
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
......@@ -215,9 +225,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kNDArrayContainer:
case kArrayHandle: {
DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
TVMContext ctx;
uint64_t data;
if (!return_ndarray) {
// in the client mode
// ctx contains the remote table index
// the space is wrapped by an RemoteSpace
// that holds reference to the session.
ctx = StripSessMask(arr->ctx);
data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
} else {
// When we return NDArray, we directly return
// the space and the context
// The client will be further wrapping
ctx = arr->ctx;
data = reinterpret_cast<uint64_t>(arr->data);
}
this->Write(data);
this->Write(ctx);
this->Write(arr->ndim);
......@@ -701,6 +725,21 @@ class RPCSession::EventHandler : public dmlc::Stream {
<< "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else if (rv.type_code() == kNDArrayContainer) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send NDArray back";
// We follow a special protocol to return NDArray to client side
// The first pack value is the NDArray handle as DLTensor
// The second pack value is a customized deleter that deletes the NDArray.
TVMValue ret_value_pack[2];
int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
......@@ -1090,6 +1129,11 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
}
// Server-side handler for RPCCode::kNDArrayFree: drop one reference on
// the NDArray container whose handle the remote client sent in args[0].
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
  NDArray::Container* container =
      static_cast<NDArray::Container*>(args[0].operator void*());
  container->DecRef();
}
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3]));
......@@ -1138,6 +1182,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
}
CHECK_EQ(state_, kRecvCode);
......
......@@ -48,6 +48,7 @@ enum class RPCCode : int {
kModuleFree,
kModuleGetFunc,
kModuleGetSource,
kNDArrayFree
};
/*!
......
......@@ -175,6 +175,7 @@ def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func")
def addone(x):
return lambda y: x+y
server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func")
......@@ -182,6 +183,46 @@ def test_rpc_return_func():
assert fadd(12) == 22
# NOTE(review): the leading indentation of this test was stripped by the
# diff extraction; the bodies below nest under the defs in the original.
# Restore indentation before running. Nesting of the two trailing
# `assert ref_count() == 1` lines could not be determined from here.
def test_rpc_return_ndarray():
# Use closure to check the ref counter correctness:
# `nd` is captured server-side, so its use count reflects how many
# remote handles the client currently holds (plus the closure itself).
nd = tvm.nd.array(np.zeros(10).astype("float32"))
@tvm.register_func("rpc.test.remote_return_nd")
def my_module(name):
if name == "get_arr":
return lambda : nd
elif name == "ref_count":
return lambda : tvm._api_internal._ndarray_use_count(nd)
elif name == "get_elem":
return lambda idx: nd.asnumpy()[idx]
elif name == "get_arr_elem":
return lambda arr, idx: arr.asnumpy()[idx]
# start server
server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
m = client.get_function("rpc.test.remote_return_nd")
get_arr = m("get_arr")
ref_count = m("ref_count")
get_elem = m("get_elem")
get_arr_elem = m("get_arr_elem")
# array test: each remote handle fetched should bump the server-side
# use count; writes through one handle must be visible through others
def run_arr_test():
arr = get_arr()
assert ref_count() == 2
arr2 = get_arr()
assert ref_count() == 3
assert arr.context == client.cpu(0)
arr.copyfrom(np.ones(10).astype(arr.dtype))
assert arr2.asnumpy()[0] == 1.0
assert get_elem(0) == 1.0
assert get_arr_elem(arr2, 0) == 1.0
assert ref_count() == 1
run_arr_test()
# check recycle correctness: after the local handles are gone, only the
# server-side closure reference should remain
assert ref_count() == 1
def test_local_func():
@tvm.register_func("rpc.test.remote_func2")
def addone(x):
......@@ -199,9 +240,10 @@ def test_local_func():
# Script entry point: run the RPC test suite.
# NOTE(review): the scraped diff left `test_rpc_return_func()` at both its
# old (removed) and new (added) positions, so the text called it twice;
# the duplicate call is dropped here. Indentation (stripped by the diff
# extraction) is restored.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    test_rpc_return_ndarray()
    test_rpc_return_func()
    test_bigendian_rpc()
    test_rpc_remote_module()
    test_rpc_file_exchange()
    test_rpc_array()
    test_rpc_simple()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment