Commit 52f04517 by tqchen Committed by Tianqi Chen

[RUNTIME] update to make runtime copy type aware

parent 42608dda
...@@ -81,18 +81,21 @@ class DeviceAPI { ...@@ -81,18 +81,21 @@ class DeviceAPI {
* \param from_offset The byte offset in the from. * \param from_offset The byte offset in the from.
* \param to The target array. * \param to The target array.
* \param to_offset The byte offset in the to. * \param to_offset The byte offset in the to.
* \param size The size of the memory * \param num_bytes The size of the memory in bytes
* \param ctx_from The source context * \param ctx_from The source context
* \param ctx_to The target context * \param ctx_to The target context
* \param type_hint The type of elements, only needed by certain backends.
* can be useful for cross device endian conversion.
* \param stream Optional stream object. * \param stream Optional stream object.
*/ */
virtual void CopyDataFromTo(const void* from, virtual void CopyDataFromTo(const void* from,
size_t from_offset, size_t from_offset,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t size, size_t num_bytes,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) = 0; TVMStreamHandle stream) = 0;
/*! /*!
* \brief Create a new stream of execution. * \brief Create a new stream of execution.
......
import os
import numpy as np import numpy as np
import nnvm.compiler import nnvm.compiler
import tvm
from tvm.contrib import rpc, util, graph_runtime
def test_save_load(): def test_save_load():
x = np.random.uniform(size=(10, 2)).astype("float32") x = np.random.uniform(size=(10, 2)).astype("float32")
...@@ -15,5 +19,45 @@ def test_save_load(): ...@@ -15,5 +19,45 @@ def test_save_load():
np.testing.assert_equal(param2["y"].asnumpy(), y) np.testing.assert_equal(param2["y"].asnumpy(), y)
def test_bigendian_rpc_param():
    """Test big endian rpc when there is a PowerPC RPC server available.

    The server address is read from the environment variables
    ``TVM_POWERPC_TEST_HOST`` and ``TVM_POWERPC_TEST_PORT`` (default 9090).
    When no host is configured the test returns immediately without
    doing anything.
    """
    host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
    if host is None:
        return
    # Environment values are always strings while the fallback is an int;
    # normalize so rpc.connect receives an integer port in both cases.
    port = int(os.environ.get("TVM_POWERPC_TEST_PORT", 9090))

    def verify_nnvm(remote, target, shape, dtype):
        """Build ``x + 1`` for ``target``, run it remotely and compare with numpy."""
        x = nnvm.sym.Variable("x")
        y = x + 1
        graph, lib, _ = nnvm.compiler.build(
            y, target,
            shape={"x": shape},
            dtype={"x": dtype})

        # Save the compiled library locally, upload it and load it on the remote.
        temp = util.tempdir()
        path_dso = temp.relpath("dev_lib.o")
        lib.save(path_dso)
        remote.upload(path_dso)
        lib = remote.load_module("dev_lib.o")
        a = np.random.randint(0, 256, size=shape).astype(dtype)
        # NOTE(review): the random contents are overwritten with ones here,
        # so the input is effectively constant -- confirm this is intended.
        a[:] = 1
        params = {"x" : a}
        ctx = remote.cpu(0)
        m = graph_runtime.create(graph, lib, ctx)
        # uses save param_dict
        m.load_params(nnvm.compiler.save_param_dict(params))
        m.run()
        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx))
        np.testing.assert_allclose(a + 1, out.asnumpy())

    print("Test RPC connection to PowerPC...")
    remote = rpc.connect(host, port)
    target = "llvm -mtriple=powerpc-linux-gnu"
    for dtype in ["float32", "float64", "int32", "int8"]:
        verify_nnvm(remote, target, (10,), dtype)
if __name__ == "__main__": if __name__ == "__main__":
test_save_load() test_save_load()
test_bigendian_rpc_param()
...@@ -93,6 +93,7 @@ class VPIDeviceAPI final : public runtime::DeviceAPI { ...@@ -93,6 +93,7 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
if (static_cast<int>(ctx_from.device_type) == kDLVPI) { if (static_cast<int>(ctx_from.device_type) == kDLVPI) {
from = RealAddr(static_cast<const char*>(from) + from_offset, size); from = RealAddr(static_cast<const char*>(from) + from_offset, size);
......
...@@ -434,7 +434,7 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -434,7 +434,7 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
DeviceAPIManager::Get(ctx)->CopyDataFromTo( DeviceAPIManager::Get(ctx)->CopyDataFromTo(
from->data, static_cast<size_t>(from->byte_offset), from->data, static_cast<size_t>(from->byte_offset),
to->data, static_cast<size_t>(to->byte_offset), to->data, static_cast<size_t>(to->byte_offset),
from_size, from->ctx, to->ctx, stream); from_size, from->ctx, to->ctx, from->dtype, stream);
API_END(); API_END();
} }
...@@ -452,7 +452,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle, ...@@ -452,7 +452,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo( DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
data, 0, data, 0,
handle->data, static_cast<size_t>(handle->byte_offset), handle->data, static_cast<size_t>(handle->byte_offset),
nbytes, cpu_ctx, handle->ctx, nullptr); nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr);
API_END(); API_END();
} }
...@@ -469,7 +469,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -469,7 +469,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo( DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset), handle->data, static_cast<size_t>(handle->byte_offset),
data, 0, data, 0,
nbytes, handle->ctx, cpu_ctx, nullptr); nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
API_END(); API_END();
} }
......
...@@ -53,6 +53,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -53,6 +53,7 @@ class CPUDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
memcpy(static_cast<char*>(to) + to_offset, memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, static_cast<const char*>(from) + from_offset,
......
...@@ -99,6 +99,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -99,6 +99,7 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream); cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
......
...@@ -75,6 +75,7 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -75,6 +75,7 @@ class MetalWorkspace final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
......
...@@ -158,6 +158,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, ...@@ -158,6 +158,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
this->Init(); this->Init();
CHECK(stream == nullptr); CHECK(stream == nullptr);
......
...@@ -154,6 +154,7 @@ class OpenCLWorkspace final : public DeviceAPI { ...@@ -154,6 +154,7 @@ class OpenCLWorkspace final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
......
...@@ -110,6 +110,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from, ...@@ -110,6 +110,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
this->Init(); this->Init();
CHECK(stream == nullptr); CHECK(stream == nullptr);
......
...@@ -119,6 +119,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from, ...@@ -119,6 +119,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
CHECK(stream == nullptr); CHECK(stream == nullptr);
......
...@@ -81,6 +81,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -81,6 +81,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream); hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
......
...@@ -49,6 +49,7 @@ class RPCDeviceAPI final : public DeviceAPI { ...@@ -49,6 +49,7 @@ class RPCDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type; int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type; int to_dev_type = ctx_to.device_type;
...@@ -60,19 +61,18 @@ class RPCDeviceAPI final : public DeviceAPI { ...@@ -60,19 +61,18 @@ class RPCDeviceAPI final : public DeviceAPI {
RPCCode::kCopyAmongRemote, RPCCode::kCopyAmongRemote,
static_cast<const RemoteSpace*>(from)->data, from_offset, static_cast<const RemoteSpace*>(from)->data, from_offset,
static_cast<const RemoteSpace*>(to)->data, to_offset, static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_from, ctx_to, stream); size, ctx_from, ctx_to, type_hint, stream);
} else if (from_dev_type > kRPCSessMask && } else if (from_dev_type > kRPCSessMask &&
to_dev_type == kDLCPU) { to_dev_type == kDLCPU) {
GetSess(ctx_from)->CopyFromRemote( GetSess(ctx_from)->CopyFromRemote(
static_cast<const RemoteSpace*>(from)->data, from_offset, static_cast<const RemoteSpace*>(from)->data, from_offset,
to, to_offset, size, to, to_offset, size, ctx_from, type_hint);
ctx_from);
} else if (from_dev_type == kDLCPU && } else if (from_dev_type == kDLCPU &&
to_dev_type > kRPCSessMask) { to_dev_type > kRPCSessMask) {
GetSess(ctx_to)->CopyToRemote( GetSess(ctx_to)->CopyToRemote(
(void*)from, from_offset, // NOLINT(*) (void*)from, from_offset, // NOLINT(*)
static_cast<const RemoteSpace*>(to)->data, to_offset, static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_to); size, ctx_to, type_hint);
} else { } else {
LOG(FATAL) << "expect copy from/to remote or between remote"; LOG(FATAL) << "expect copy from/to remote or between remote";
} }
......
...@@ -303,6 +303,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -303,6 +303,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
std::string temp_data_; std::string temp_data_;
// Temp variables for copy request state. // Temp variables for copy request state.
TVMContext copy_ctx_; TVMContext copy_ctx_;
TVMType copy_dtype_;
uint64_t copy_handle_, copy_offset_, copy_size_; uint64_t copy_handle_, copy_offset_, copy_size_;
// State switcher // State switcher
void SwitchToState(State state) { void SwitchToState(State state) {
...@@ -344,11 +345,13 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -344,11 +345,13 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kDoCopyFromRemote: { case kDoCopyFromRemote: {
this->RequestBytes(sizeof(uint64_t) * 3); this->RequestBytes(sizeof(uint64_t) * 3);
this->RequestBytes(sizeof(TVMContext)); this->RequestBytes(sizeof(TVMContext));
this->RequestBytes(sizeof(TVMType));
break; break;
} }
case kDoCopyToRemote: { case kDoCopyToRemote: {
this->RequestBytes(sizeof(uint64_t) * 3); this->RequestBytes(sizeof(uint64_t) * 3);
this->RequestBytes(sizeof(TVMContext)); this->RequestBytes(sizeof(TVMContext));
this->RequestBytes(sizeof(TVMType));
break; break;
} }
case kCopyAckReceived: case kCopyAckReceived:
...@@ -554,18 +557,30 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -554,18 +557,30 @@ class RPCSession::EventHandler : public dmlc::Stream {
} }
void HandleCopyFromRemote() { void HandleCopyFromRemote() {
uint64_t handle, offset, size; uint64_t handle, offset, num_bytes;
TVMContext ctx; TVMContext ctx;
TVMType type_hint;
this->Read(&handle); this->Read(&handle);
this->Read(&offset); this->Read(&offset);
this->Read(&size); this->Read(&num_bytes);
this->Read(&ctx); this->Read(&ctx);
this->Read(&type_hint);
size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
if (ctx.device_type == kDLCPU) { if (ctx.device_type == kDLCPU) {
RPCCode code = RPCCode::kCopyAck; RPCCode code = RPCCode::kCopyAck;
this->Write(code); this->Write(code);
this->WriteArray(reinterpret_cast<char*>(handle) + offset, size); char* dptr = reinterpret_cast<char*>(handle) + offset;
if (!DMLC_IO_NO_ENDIAN_SWAP) {
temp_data_.resize(0);
temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes);
dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
this->WriteArray(temp_data_.data(), num_bytes);
} else {
this->WriteArray(dptr, num_bytes);
}
} else { } else {
temp_data_.resize(size + 1); temp_data_.resize(num_bytes + 1);
try { try {
TVMContext cpu_ctx; TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
...@@ -573,10 +588,13 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -573,10 +588,13 @@ class RPCSession::EventHandler : public dmlc::Stream {
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
reinterpret_cast<void*>(handle), offset, reinterpret_cast<void*>(handle), offset,
dmlc::BeginPtr(temp_data_), 0, dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr); num_bytes, ctx, cpu_ctx, type_hint, nullptr);
RPCCode code = RPCCode::kCopyAck; RPCCode code = RPCCode::kCopyAck;
this->Write(code); this->Write(code);
this->WriteArray(&temp_data_[0], size); if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
}
this->WriteArray(&temp_data_[0], num_bytes);
} catch (const std::runtime_error &e) { } catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException; RPCCode code = RPCCode::kException;
this->Write(code); this->Write(code);
...@@ -597,6 +615,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -597,6 +615,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
CHECK(this->Read(&copy_offset_)); CHECK(this->Read(&copy_offset_));
CHECK(this->Read(&copy_size_)); CHECK(this->Read(&copy_size_));
CHECK(this->Read(&copy_ctx_)); CHECK(this->Read(&copy_ctx_));
CHECK(this->Read(&copy_dtype_));
arg_recv_stage_ = 1; arg_recv_stage_ = 1;
CHECK_EQ(pending_request_bytes_, 0U); CHECK_EQ(pending_request_bytes_, 0U);
this->RequestBytes(copy_size_); this->RequestBytes(copy_size_);
...@@ -607,12 +626,20 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -607,12 +626,20 @@ class RPCSession::EventHandler : public dmlc::Stream {
int ret_tcode = kNull; int ret_tcode = kNull;
RPCCode code = RPCCode::kReturn; RPCCode code = RPCCode::kReturn;
std::string errmsg; std::string errmsg;
size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8;
if (copy_ctx_.device_type == kDLCPU) { if (copy_ctx_.device_type == kDLCPU) {
this->ReadArray( char* dptr = reinterpret_cast<char*>(copy_handle_) + copy_offset_;
reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_); this->ReadArray(dptr, copy_size_);
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes);
}
} else { } else {
temp_data_.resize(copy_size_ + 1); temp_data_.resize(copy_size_ + 1);
this->ReadArray(&temp_data_[0], copy_size_); this->ReadArray(&temp_data_[0], copy_size_);
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes);
}
try { try {
TVMContext cpu_ctx; TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
...@@ -620,7 +647,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -620,7 +647,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
temp_data_.data(), 0, temp_data_.data(), 0,
reinterpret_cast<void*>(copy_handle_), copy_offset_, reinterpret_cast<void*>(copy_handle_), copy_offset_,
copy_size_, cpu_ctx, copy_ctx_, nullptr); copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr);
} catch (const std::runtime_error &e) { } catch (const std::runtime_error &e) {
code = RPCCode::kException; code = RPCCode::kException;
errmsg = e.what(); errmsg = e.what();
...@@ -873,7 +900,8 @@ void RPCSession::CopyToRemote(void* from, ...@@ -873,7 +900,8 @@ void RPCSession::CopyToRemote(void* from,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t data_size, size_t data_size,
TVMContext ctx_to) { TVMContext ctx_to,
TVMType type_hint) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_to = handler_->StripSessMask(ctx_to); ctx_to = handler_->StripSessMask(ctx_to);
RPCCode code = RPCCode::kCopyToRemote; RPCCode code = RPCCode::kCopyToRemote;
...@@ -885,6 +913,7 @@ void RPCSession::CopyToRemote(void* from, ...@@ -885,6 +913,7 @@ void RPCSession::CopyToRemote(void* from,
uint64_t size = static_cast<uint64_t>(data_size); uint64_t size = static_cast<uint64_t>(data_size);
handler_->Write(size); handler_->Write(size);
handler_->Write(ctx_to); handler_->Write(ctx_to);
handler_->Write(type_hint);
handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size); handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
...@@ -895,7 +924,8 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -895,7 +924,8 @@ void RPCSession::CopyFromRemote(void* from,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t data_size, size_t data_size,
TVMContext ctx_from) { TVMContext ctx_from,
TVMType type_hint) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_from = handler_->StripSessMask(ctx_from); ctx_from = handler_->StripSessMask(ctx_from);
RPCCode code = RPCCode::kCopyFromRemote; RPCCode code = RPCCode::kCopyFromRemote;
...@@ -907,6 +937,7 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -907,6 +937,7 @@ void RPCSession::CopyFromRemote(void* from,
uint64_t size = static_cast<uint64_t>(data_size); uint64_t size = static_cast<uint64_t>(data_size);
handler_->Write(size); handler_->Write(size);
handler_->Write(ctx_from); handler_->Write(ctx_from);
handler_->Write(type_hint);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
reader_.Reserve(data_size); reader_.Reserve(data_size);
...@@ -996,7 +1027,8 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { ...@@ -996,7 +1027,8 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
uint64_t size = args[4]; uint64_t size = args[4];
TVMContext ctx_from = args[5]; TVMContext ctx_from = args[5];
TVMContext ctx_to = args[6]; TVMContext ctx_to = args[6];
TVMStreamHandle stream = args[7]; TVMType type_hint = args[7];
TVMStreamHandle stream = args[8];
TVMContext ctx = ctx_from; TVMContext ctx = ctx_from;
if (ctx.device_type == kDLCPU) { if (ctx.device_type == kDLCPU) {
ctx = ctx_to; ctx = ctx_to;
...@@ -1008,7 +1040,7 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { ...@@ -1008,7 +1040,7 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
from, from_offset, from, from_offset,
to, to_offset, to, to_offset,
size, ctx_from, ctx_to, stream); size, ctx_from, ctx_to, type_hint, stream);
} }
void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
......
...@@ -116,30 +116,34 @@ class RPCSession { ...@@ -116,30 +116,34 @@ class RPCSession {
* \param from_offset The byte offset in the from. * \param from_offset The byte offset in the from.
* \param to The target array. * \param to The target array.
* \param to_offset The byte offset in the to. * \param to_offset The byte offset in the to.
* \param size The size of the memory. * \param nbytes The size of the memory in bytes.
* \param ctx_to The target context. * \param ctx_to The target context.
* \param type_hint Hint of content data type.
*/ */
void CopyToRemote(void* from, void CopyToRemote(void* from,
size_t from_offset, size_t from_offset,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t size, size_t nbytes,
TVMContext ctx_to); TVMContext ctx_to,
TVMType type_hint);
/*! /*!
* \brief Copy bytes from remote array content. * \brief Copy bytes from remote array content.
* \param from The source host data. * \param from The source host data.
* \param from_offset The byte offset in the from. * \param from_offset The byte offset in the from.
* \param to The target array. * \param to The target array.
* \param to_offset The byte offset in the to. * \param to_offset The byte offset in the to.
* \param size The size of the memory. * \param nbytes The size of the memory in bytes.
* \param ctx_from The source context. * \param ctx_from The source context.
* \param type_hint Hint of content data type.
*/ */
void CopyFromRemote(void* from, void CopyFromRemote(void* from,
size_t from_offset, size_t from_offset,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t size, size_t nbytes,
TVMContext ctx_from); TVMContext ctx_from,
TVMType type_hint);
/*! /*!
* \brief Get a remote timer function on ctx. * \brief Get a remote timer function on ctx.
* This function consumes fhandle, caller should not call Free on fhandle. * This function consumes fhandle, caller should not call Free on fhandle.
......
...@@ -141,6 +141,7 @@ class VulkanWorkspace final : public DeviceAPI { ...@@ -141,6 +141,7 @@ class VulkanWorkspace final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
......
...@@ -131,6 +131,7 @@ void VulkanWorkspace::CopyDataFromTo(const void* from, ...@@ -131,6 +131,7 @@ void VulkanWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
this->Init(); this->Init();
CHECK(stream == nullptr); CHECK(stream == nullptr);
......
import tvm import tvm
import os
import logging import logging
import numpy as np import numpy as np
import time import time
from tvm.contrib import rpc, util from tvm.contrib import rpc, util
def test_bigendian_rpc():
    """Test big endian rpc when there is a PowerPC RPC server available.

    The server address is read from the environment variables
    ``TVM_POWERPC_TEST_HOST`` and ``TVM_POWERPC_TEST_PORT`` (default 9090).
    When no host is configured the test returns immediately without
    doing anything.
    """
    host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
    if host is None:
        return
    # Environment values are always strings while the fallback is an int;
    # normalize so rpc.connect receives an integer port in both cases.
    port = int(os.environ.get("TVM_POWERPC_TEST_PORT", 9090))

    def verify_rpc(remote, target, shape, dtype):
        """Build an element-wise ``A + 1`` kernel, run it remotely and compare with numpy."""
        A = tvm.placeholder(shape, dtype=dtype)
        B = tvm.compute(A.shape, lambda i: A[i]+tvm.const(1, A.dtype))
        s = tvm.create_schedule(B.op)
        f = tvm.build(s, [A, B], target, name="myadd")

        # Input and output arrays are created on the remote CPU context.
        ctx = remote.cpu(0)
        a = tvm.nd.array(np.random.randint(0, 256, size=shape).astype(A.dtype), ctx=ctx)
        b = tvm.nd.array(np.zeros(shape).astype(A.dtype), ctx=ctx)
        # Save the compiled module locally, upload it and load it on the remote.
        temp = util.tempdir()
        path_dso = temp.relpath("dev_lib.o")
        f.save(path_dso)
        remote.upload(path_dso)
        f = remote.load_module("dev_lib.o")
        f(a, b)
        np.testing.assert_allclose(a.asnumpy() + 1, b.asnumpy())

    print("Test RPC connection to PowerPC...")
    remote = rpc.connect(host, port)
    target = "llvm -mtriple=powerpc-linux-gnu"
    for dtype in ["float32", "float64", "int32", "int8"]:
        verify_rpc(remote, target, (10,), dtype)
def test_rpc_simple(): def test_rpc_simple():
if not tvm.module.enabled("rpc"): if not tvm.module.enabled("rpc"):
return return
...@@ -166,6 +198,7 @@ def test_local_func(): ...@@ -166,6 +198,7 @@ def test_local_func():
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_bigendian_rpc()
test_rpc_remote_module() test_rpc_remote_module()
test_rpc_return_func() test_rpc_return_func()
test_rpc_file_exchange() test_rpc_file_exchange()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment