Commit 68ea2c3e by Chris Nuernberger Committed by Tianqi Chen

Small refactor for clarity in arraycopyfromto (#960)

parent 0d6cf0c8
...@@ -413,19 +413,22 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -413,19 +413,22 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
size_t from_size = GetDataSize(from); size_t from_size = GetDataSize(from);
size_t to_size = GetDataSize(to); size_t to_size = GetDataSize(to);
CHECK_EQ(from_size, to_size) CHECK_EQ(from_size, to_size)
<< "TVMArrayCopyFromTo: The size must exactly match"; << "TVMArrayCopyFromTo: The size must exactly match";
TVMContext ctx = from->ctx;
if (ctx.device_type == kDLCPU) { CHECK(from->ctx.device_type == to->ctx.device_type
ctx = to->ctx; || from->ctx.device_type == kDLCPU
} else { || to->ctx.device_type == kDLCPU)
CHECK(to->ctx.device_type == kDLCPU || << "Can not copy across different ctx types directly";
to->ctx.device_type == from->ctx.device_type)
<< "Can not copy across different ctx types directly"; // Use the context that is *not* a cpu context to get the correct device
} // api manager.
TVMContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx;
DeviceAPIManager::Get(ctx)->CopyDataFromTo( DeviceAPIManager::Get(ctx)->CopyDataFromTo(
from->data, static_cast<size_t>(from->byte_offset), from->data, static_cast<size_t>(from->byte_offset),
to->data, static_cast<size_t>(to->byte_offset), to->data, static_cast<size_t>(to->byte_offset),
from_size, from->ctx, to->ctx, stream); from_size, from->ctx, to->ctx, stream);
API_END(); API_END();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment