Unverified Commit 474c70d7 by jmorrill Committed by GitHub

Added CopyFromBytes and CopyToBytes convenience methods to NDArray. Fixed typos. (#4970)

* Added CopyFromBytes and CopyToBytes convenience methods.  Fixed typos.

* Removed unneed argument check

* Use TVMArrayCopyFrom/ToBytes methods

* Moved CopyFrom/ToBytes to ndarray.cc

* CopyToBytes impl was using CopyFromBytes.  Fixed

* changed inline to TVM_DLL

* Used impl from TVMArrayCopyTo/FromBytes into NDArray CopyTo/FromBytes

* Move implementation of all CopyFrom/ToBytes into a common impls

* make arg const

* simplify method impl
parent 2355caa8
...@@ -68,20 +68,38 @@ class NDArray : public ObjectRef { ...@@ -68,20 +68,38 @@ class NDArray : public ObjectRef {
/*! /*!
* \brief Copy data content from another array. * \brief Copy data content from another array.
* \param other The source array to be copied from. * \param other The source array to be copied from.
* \note The copy may happen asynchrously if it involves a GPU context. * \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary. * TVMSynchronize is necessary.
*/ */
inline void CopyFrom(const DLTensor* other); inline void CopyFrom(const DLTensor* other);
inline void CopyFrom(const NDArray& other); inline void CopyFrom(const NDArray& other);
/*! /*!
* \brief Copy data content from a byte buffer.
* \param data The source bytes to be copied from.
* \param nbytes The size of the buffer in bytes
* Must be equal to the size of the NDArray.
* \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
TVM_DLL void CopyFromBytes(const void* data, size_t nbytes);
/*!
* \brief Copy data content into another array. * \brief Copy data content into another array.
* \param other The source array to be copied from. * \param other The source array to be copied from.
* \note The copy may happen asynchrously if it involves a GPU context. * \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary. * TVMSynchronize is necessary.
*/ */
inline void CopyTo(DLTensor* other) const; inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const; inline void CopyTo(const NDArray& other) const;
/*! /*!
* \brief Copy data content into another array.
* \param data The source bytes to be copied from.
* \param nbytes The size of the data buffer.
* Must be equal to the size of the NDArray.
* \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
TVM_DLL void CopyToBytes(void* data, size_t nbytes) const;
/*!
* \brief Copy the data to another context. * \brief Copy the data to another context.
* \param ctx The target context. * \param ctx The target context.
* \return The array under another context. * \return The array under another context.
...@@ -182,7 +200,7 @@ class NDArray : public ObjectRef { ...@@ -182,7 +200,7 @@ class NDArray : public ObjectRef {
/*! /*!
* \brief Save a DLTensor to stream * \brief Save a DLTensor to stream
* \param strm The outpu stream * \param strm The output stream
* \param tensor The tensor to be saved. * \param tensor The tensor to be saved.
*/ */
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
...@@ -205,7 +223,7 @@ class NDArray::ContainerBase { ...@@ -205,7 +223,7 @@ class NDArray::ContainerBase {
DLTensor dl_tensor; DLTensor dl_tensor;
/*! /*!
* \brief addtional context, reserved for recycling * \brief additional context, reserved for recycling
* \note We can attach additional content here * \note We can attach additional content here
* which the current container depend on * which the current container depend on
* (e.g. reference to original memory when creating views). * (e.g. reference to original memory when creating views).
......
...@@ -60,6 +60,32 @@ inline size_t GetDataAlignment(const DLTensor& arr) { ...@@ -60,6 +60,32 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
return align; return align;
} }
void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "ArrayCopyFromBytes: size mismatch";
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
data, 0,
handle->data, static_cast<size_t>(handle->byte_offset),
nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr);
}
void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "ArrayCopyToBytes: size mismatch";
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset),
data, 0,
nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
}
struct NDArray::Internal { struct NDArray::Internal {
// Default deleter for the container // Default deleter for the container
static void DefaultDeleter(Object* ptr_obj) { static void DefaultDeleter(Object* ptr_obj) {
...@@ -185,6 +211,18 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { ...@@ -185,6 +211,18 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
return NDArray(GetObjectPtr<Object>(data)); return NDArray(GetObjectPtr<Object>(data));
} }
void NDArray::CopyToBytes(void* data, size_t nbytes) const {
CHECK(data != nullptr);
CHECK(data_ != nullptr);
ArrayCopyToBytes(&get_mutable()->dl_tensor, data, nbytes);
}
void NDArray::CopyFromBytes(const void* data, size_t nbytes) {
CHECK(data != nullptr);
CHECK(data_ != nullptr);
ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes);
}
void NDArray::CopyFromTo(const DLTensor* from, void NDArray::CopyFromTo(const DLTensor* from,
DLTensor* to, DLTensor* to,
TVMStreamHandle stream) { TVMStreamHandle stream) {
...@@ -286,16 +324,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle, ...@@ -286,16 +324,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data, void* data,
size_t nbytes) { size_t nbytes) {
API_BEGIN(); API_BEGIN();
TVMContext cpu_ctx; ArrayCopyFromBytes(handle, data, nbytes);
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyFromBytes: size mismatch";
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
data, 0,
handle->data, static_cast<size_t>(handle->byte_offset),
nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr);
API_END(); API_END();
} }
...@@ -303,15 +332,6 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -303,15 +332,6 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data, void* data,
size_t nbytes) { size_t nbytes) {
API_BEGIN(); API_BEGIN();
TVMContext cpu_ctx; ArrayCopyToBytes(handle, data, nbytes);
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyToBytes: size mismatch";
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset),
data, 0,
nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
API_END(); API_END();
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment