Commit 72fa4c1d by Tianqi Chen Committed by GitHub

[NODE][REFLECTION] Support NDArray as field (#1452)

parent 6fbda22d
Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff
Subproject commit a5a80bdc8232c9dbfe508bb5c46e8f58cdf7ec20
...@@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include "./c_runtime_api.h"
#include "./serializer.h"
namespace tvm {
namespace runtime {
...@@ -103,8 +104,25 @@ class NDArray {
* \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other);
inline void CopyTo(const NDArray& other);
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return A copy of the array in the target context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
* \return Whether load is successful
*/
inline bool Load(dmlc::Stream* stream);
/*!
* \brief Save NDArray to stream
* \param stream The output data stream
*/
inline void Save(dmlc::Stream* stream) const;
/*!
* \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array.
...@@ -162,6 +180,13 @@ class NDArray {
};
/*!
* \brief Save a DLTensor to stream
* \param strm The output stream
* \param tensor The tensor to be saved.
* \return Whether save is successful
*/
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
*
* This object is DLTensor compatible:
...@@ -260,17 +285,26 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor* other) {
inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray& other) {
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
return ret;
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
...@@ -280,7 +314,106 @@ inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor);
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
// Always save data as CPU context
//
// Parameters that get serialized should be in CPU by default.
// So even if the array's context is GPU, it will be stored as a CPU array.
// This prevents the case where another user loads the parameters
// back on a machine that does not have a GPU or the related context.
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
strm->Write(cpu_ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(TVMArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< TVMGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
}
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
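For reference, the record layout produced by the Write calls above (and consumed by NDArray::Load below) can be summarized as follows; this is only a restatement of the existing code, written as a comment:

    // On-stream layout of a serialized NDArray:
    //   uint64_t   header              // kTVMNDArrayMagic
    //   uint64_t   reserved            // currently 0
    //   DLContext  ctx                 // always {kDLCPU, 0} on save
    //   int        ndim
    //   DLDataType dtype
    //   int64_t    shape[ndim]
    //   int64_t    data_byte_size
    //   uint8_t    data[data_byte_size]  // byte-swapped when the host endianness
    //                                    // differs from the stream's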
inline void NDArray::Save(dmlc::Stream* strm) const {
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_
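To make the new API concrete, here is a minimal usage sketch (not part of the commit; the function name RoundTripExample is illustrative) that saves an NDArray into an in-memory stream and loads it back:

    #include <dmlc/logging.h>
    #include <dmlc/memory_io.h>
    #include <string>
    #include <tvm/runtime/ndarray.h>

    void RoundTripExample() {
      using tvm::runtime::NDArray;
      DLContext cpu_ctx;
      cpu_ctx.device_type = kDLCPU;
      cpu_ctx.device_id = 0;
      DLDataType f32;
      f32.code = kDLFloat;
      f32.bits = 32;
      f32.lanes = 1;
      NDArray a = NDArray::Empty({10, 2}, f32, cpu_ctx);
      // Serialize into an in-memory buffer.
      std::string blob;
      dmlc::MemoryStringStream wstrm(&blob);
      a.Save(&wstrm);
      // Deserialize from the same buffer.
      dmlc::MemoryStringStream rstrm(&blob);
      NDArray b;
      CHECK(b.Load(&rstrm));
      // CopyTo(ctx) can then move the loaded CPU array to another context.
    }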
...@@ -10,6 +10,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
#include "./ndarray.h"
namespace dmlc {
namespace serializer {
......
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import ctypes
import tvm
from tvm._ffi.runtime_ctypes import TVMArrayHandle
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
...@@ -59,11 +57,5 @@ def load_param_dict(param_bytes):
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_mod = _load_param_dict(param_bytes)
size = load_mod(0)
param_dict = {}
for i in range(size):
key = load_mod(1, i)
dltensor_handle = ctypes.cast(load_mod(2, i), TVMArrayHandle)
param_dict[key] = tvm.nd.NDArray(dltensor_handle, False)
return param_dict
load_arr = _load_param_dict(param_bytes)
return {v.name : v.array for v in load_arr}
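For context, a typical round trip through this module looks like the sketch below (not part of the diff; it assumes save_param_dict/load_param_dict are exposed via nnvm.compiler as in this file):

    import numpy as np
    import tvm
    import nnvm.compiler

    params = {"w": tvm.nd.array(np.random.uniform(size=(4, 4)).astype("float32"))}
    blob = nnvm.compiler.save_param_dict(params)   # serialized bytearray
    loaded = nnvm.compiler.load_param_dict(blob)   # dict of str -> tvm.nd.NDArray
    np.testing.assert_equal(loaded["w"].asnumpy(), params["w"].asnumpy())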
...@@ -4,10 +4,6 @@
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h" #include "./graph_runtime.h"
namespace nnvm { namespace nnvm {
...@@ -37,81 +33,6 @@ NNVM_REGISTER_OP(tvm_op) ...@@ -37,81 +33,6 @@ NNVM_REGISTER_OP(tvm_op)
return param.num_outputs; return param.num_outputs;
}); });
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
tensor.dtype.code,
tensor.dtype.bits,
tensor.dtype.lanes,
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret;
}
TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
...@@ -136,7 +57,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") ...@@ -136,7 +57,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
uint64_t sz = static_cast<uint64_t>(arrays.size()); uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(sz); fo->Write(sz);
for (size_t i = 0; i < sz; ++i) { for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]); tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
...@@ -149,11 +70,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string bytes = args[0];
std::vector<DLTensor*> data;
std::vector<std::string> names;
dmlc::MemoryStringStream memstrm(&bytes);
dmlc::Stream* strm = &memstrm;
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
...@@ -168,23 +87,19 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
tvm::Array<NDArrayWrapper> ret;
for (size_t i = 0; i < size; ++i) {
data.push_back(LoadDLTensor(strm));
}
auto packed = [data, names](TVMArgs args, TVMRetValue* rv) {
int code = args[0];
if (code == 0) {
*rv = static_cast<int64_t>(data.size());
} else if (code == 1) {
int index = args[1];
*rv = names[index];
} else {
CHECK_EQ(code, 2);
int index = args[1];
*rv = static_cast<void*>(data[index]);
}
};
*rv = PackedFunc(packed);
tvm::runtime::NDArray temp;
temp.Load(strm);
std::shared_ptr<NDArrayWrapperNode> n
= std::make_shared<NDArrayWrapperNode>();
n->name = std::move(names[i]);
n->array = temp;
ret.push_back(NDArrayWrapper(n));
}
*rv = ret;
});
TVM_REGISTER_NODE_TYPE(NDArrayWrapperNode);
} // namespace compiler
} // namespace nnvm
...@@ -7,14 +7,16 @@
#define NNVM_COMPILER_GRAPH_RUNTIME_H_
#include <nnvm/graph.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
...@@ -32,6 +34,26 @@ struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
}
};
/*!
* \brief wrapper node container for exchange.
*/
struct NDArrayWrapperNode : public ::tvm::Node {
std::string name;
tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("array", &array);
}
static constexpr const char* _type_key = "NDArrayWrapper";
TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node);
};
TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
...@@ -2,6 +2,9 @@ import os
import numpy as np
import nnvm.compiler
import tvm
import json
import base64
from tvm._ffi.base import py_str
from tvm import rpc
from tvm.contrib import util, graph_runtime
...@@ -20,6 +23,22 @@ def test_save_load():
np.testing.assert_equal(param2["y"].asnumpy(), y)
def test_ndarray_reflection():
x = np.random.uniform(size=(10, 2)).astype("float32")
xx = tvm.nd.array(x)
xnode = tvm.make.node("NDArrayWrapper", name="xx", array=xx)
xnode2 = tvm.make.node("NDArrayWrapper", name="x2", array=xx)
assert xnode.array.same_as(xx)
json_str = tvm.save_json([xnode, xnode2])
json_dict = json.loads(json_str)
b64_str = json_dict["b64ndarrays"][0]
decoded = py_str(base64.b64encode(base64.b64decode(b64_str)))
assert b64_str == decoded
xlist = tvm.load_json(json_str)
np.testing.assert_equal(xlist[0].array.asnumpy(), xx.asnumpy())
assert xlist[1].array == xlist[0].array
def test_bigendian_rpc_param():
"""Test big endian rpc when there is a PowerPC RPC server available"""
host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
...@@ -60,5 +79,6 @@ def test_bigendian_rpc_param():
if __name__ == "__main__":
test_ndarray_reflection()
test_save_load()
test_bigendian_rpc_param()
...@@ -204,6 +204,7 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......
...@@ -23,6 +23,8 @@ cdef inline object make_ret_node(void* chandle):
obj = cls(None)
else:
obj = NodeBase(None)
else:
obj = NodeBase(None)
(<NodeBase>obj).chandle = chandle
return obj
......
...@@ -134,6 +134,32 @@ class NDArrayBase(_NDArrayBase):
"""context of this array"""
return self.ctx
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Check object identity equality
Parameters
----------
other : object
The other object to compare to
Returns
-------
same : bool
Whether other is same as self.
"""
if not isinstance(other, NDArrayBase):
return False
return self.__hash__() == other.__hash__()
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
......
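A short illustration of the identity semantics these methods give NDArray handles (hypothetical usage, not from the diff):

    import numpy as np
    import tvm

    a = tvm.nd.array(np.zeros((2, 2), dtype="float32"))
    b = a                                # same underlying handle
    c = tvm.nd.array(np.zeros((2, 2), dtype="float32"))
    assert a.same_as(b) and a == b       # identity, not value, equality
    assert not a.same_as(c) and a != c   # equal contents, different handles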
...@@ -32,7 +32,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
TVMRetValue* ret;
bool found_node_ref{false};
bool found_ref_object{false};
void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0];
...@@ -63,7 +63,13 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
if (skey == key) {
*ret = value[0];
found_node_ref = true;
found_ref_object = true;
}
}
void Visit(const char* key, runtime::NDArray* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};
...@@ -98,6 +104,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
};
class DSLAPIImpl : public DSLAPI {
...@@ -130,7 +139,7 @@ class DSLAPIImpl : public DSLAPI {
*ret_success = 1;
} else {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
*ret_success = getter.found_ref_object || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
......
/*!
* Copyright 2018 by Contributors
*
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
*/
#ifndef TVM_COMMON_BASE64_H_
#define TVM_COMMON_BASE64_H_
#include <dmlc/logging.h>
#include <cctype>
#include <cstdio>
#include <string>
namespace tvm {
namespace common {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
// decoding table
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
// encoding table
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*!
* \brief Buffer reader from stream to avoid
* virtual call overhead on each read.
*/
class StreamBufferReader {
public:
explicit StreamBufferReader(size_t buffer_size) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
* \param stream The stream to be set
*/
void set_stream(dmlc::Stream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \return allows quick read using get char
*/
char GetChar() {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
/*! \return whether we are reaching the end of file */
bool AtEnd() const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
dmlc::Stream *stream_{nullptr};
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_{1};
/*! \brief pointer in the buffer */
size_t read_ptr_{1};
};
/*!
* \brief Input stream from base64 encoding
*/
class Base64InStream: public dmlc::Stream {
public:
explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
reader_.set_stream(fs);
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* \note call this function before actually start read
*/
void InitPosition(void) {
// get a character
do {
temp_ch_ = reader_.GetChar();
} while (isspace(temp_ch_));
}
/*! \brief whether current position is end of a base64 stream */
bool IsEOF(void) const {
return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_));
}
// override read function.
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev_ != 0) {
if (num_prev_ == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev_ = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev_ = 1;
}
} else {
// assert num_prev_ == 1
*cptr++ = buf_prev[0]; --tlen; num_prev_ = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) {
// first byte
nvalue = DecodeTable[temp_ch_] << 18;
{
// second byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
nvalue |= DecodeTable[temp_ch_] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
// handle termination
if (temp_ch_ == '=') {
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == '=') << "invalid base64 format";
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == EOF || isspace(temp_ch_))
<< "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_))
<< "invalid base64 format";
if (temp_ch_ == '=') {
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == EOF || isspace(temp_ch_))
<< "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev_ ++] = nvalue & 0xFF;
}
}
// get next char
temp_ch_ = reader_.GetChar();
}
if (kStrictCheck) {
CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
LOG(FATAL) << "Base64InStream do not support write";
}
private:
// internal reader
StreamBufferReader reader_;
int temp_ch_{0};
int num_prev_{0};
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*!
* \brief Stream to write to base64 format.
*/
class Base64OutStream: public dmlc::Stream {
public:
explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) {
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf__top_ < 3 && tlen != 0) {
buf_[++buf__top_] = *cptr++; --tlen;
}
if (buf__top_ == 3) {
// flush 4 bytes out
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]);
PutChar(EncodeTable[buf_[3] & 0x3F]);
buf__top_ = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
LOG(FATAL) << "Base64OutStream do not support read";
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch character to put to end of stream, if it is EOF, then nothing will be appended.
*/
void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf__top_ == 1) {
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]);
PutChar('=');
PutChar('=');
}
if (buf__top_ == 2) {
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]);
PutChar('=');
}
buf__top_ = 0;
if (endch != EOF) PutChar(endch);
this->Flush();
}
private:
static constexpr size_t kBufferSize = 256;
dmlc::Stream *fp_{nullptr};
int buf__top_{0};
unsigned char buf_[4];
std::string out_buf_;
void PutChar(char ch) {
out_buf_ += ch;
if (out_buf_.length() >= kBufferSize) Flush();
}
void Flush(void) {
if (out_buf_.length() != 0) {
fp_->Write(&out_buf_[0], out_buf_.length());
out_buf_.clear();
}
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_BASE64_H_
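A minimal round-trip sketch for the two streams above (illustrative only; the include path and function name are assumptions, not part of the commit):

    #include <dmlc/logging.h>
    #include <dmlc/memory_io.h>
    #include <string>
    #include "../common/base64.h"  // path as used by reflection.cc below

    void Base64RoundTrip() {
      std::string encoded;
      {
        dmlc::MemoryStringStream raw(&encoded);
        tvm::common::Base64OutStream b64out(&raw);
        int value = 42;
        b64out.Write(&value, sizeof(value));
        b64out.Finish();       // flush the last group with '=' padding
      }
      dmlc::MemoryStringStream raw(&encoded);
      tvm::common::Base64InStream b64in(&raw);
      b64in.InitPosition();    // must be called before the first Read
      int decoded = 0;
      b64in.Read(&decoded, sizeof(decoded));
      CHECK_EQ(decoded, 42);
    }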
...@@ -7,8 +7,11 @@
#include <tvm/expr.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
#include "../common/base64.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
...@@ -23,6 +26,7 @@ inline std::string Type2String(const Type& t) {
return os.str();
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halideir_type_code_t code = Type::Int;
...@@ -52,6 +56,8 @@ class NodeIndexer : public AttrVisitor {
public:
std::unordered_map<Node*, size_t> node_index{{nullptr, 0}};
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
...@@ -64,7 +70,13 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
CHECK_EQ(tensor_index.size(), tensor_list.size());
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
...@@ -140,6 +152,7 @@ struct JSONNode {
class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
...@@ -170,6 +183,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
node_index_->at(value->node_.get()));
}
void Visit(const char* key, runtime::NDArray* value) final {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
...@@ -209,6 +226,7 @@ class JSONAttrGetter : public AttrVisitor {
class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<std::shared_ptr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
...@@ -254,10 +272,16 @@ class JSONAttrSetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, node_list_->size());
value->node_ = node_list_->at(index);
}
void Visit(const char* key, runtime::NDArray* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
// Get the node
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
if (node->is_type<ArrayNode>()) {
...@@ -292,6 +316,8 @@ struct JSONGraph {
size_t root;
// the nodes of the graph
std::vector<JSONNode> nodes;
// base64 b64ndarrays of arrays
std::vector<std::string> b64ndarrays;
// global attributes
AttrMap attrs;
...@@ -299,6 +325,7 @@ struct JSONGraph {
writer->BeginObject();
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
...@@ -310,6 +337,7 @@ struct JSONGraph {
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("root", &root);
helper.DeclareField("nodes", &nodes);
helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
} }
...@@ -320,6 +348,7 @@ struct JSONGraph { ...@@ -320,6 +348,7 @@ struct JSONGraph {
indexer.MakeIndex(root.node_.get()); indexer.MakeIndex(root.node_.get());
JSONAttrGetter getter; JSONAttrGetter getter;
getter.node_index_ = &indexer.node_index; getter.node_index_ = &indexer.node_index;
getter.tensor_index_ = &indexer.tensor_index;
for (Node* n : indexer.node_list) {
JSONNode jnode;
getter.node_ = &jnode;
...@@ -328,6 +357,15 @@ struct JSONGraph {
}
g.attrs["tvm_version"] = TVM_VERSION;
g.root = indexer.node_index.at(root.node_.get());
// serialize tensor
for (DLTensor* tensor : indexer.tensor_list) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
common::Base64OutStream b64strm(&mstrm);
runtime::SaveDLTensor(&b64strm, tensor);
b64strm.Finish();
g.b64ndarrays.emplace_back(std::move(blob));
}
return g;
}
};
...@@ -347,6 +385,16 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
// load in json graph.
jgraph.Load(&reader);
std::vector<std::shared_ptr<Node> > nodes;
std::vector<runtime::NDArray> tensors;
// load in tensors
for (const std::string& blob : jgraph.b64ndarrays) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
common::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
runtime::NDArray temp;
CHECK(temp.Load(&b64strm));
tensors.emplace_back(temp);
}
// node 0 is always null
nodes.reserve(jgraph.nodes.size());
for (const JSONNode& jnode : jgraph.nodes) {
...@@ -362,6 +410,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
CHECK_EQ(nodes.size(), jgraph.nodes.size());
JSONAttrSetter setter;
setter.node_list_ = &nodes;
setter.tensor_list_ = &tensors;
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
...@@ -402,6 +451,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
*value = GetAttr(key).operator NodeRef();
}
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
......
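The effect of the b64ndarrays handling above can be observed from Python, along the lines of the test shown earlier (a sketch; it assumes the NDArrayWrapper node type registered by nnvm is available):

    import base64, json
    import numpy as np
    import tvm
    import nnvm.compiler  # registers the NDArrayWrapper node type

    arr = tvm.nd.array(np.arange(4, dtype="float32"))
    node = tvm.make.node("NDArrayWrapper", name="w", array=arr)
    graph = json.loads(tvm.save_json(node))
    # graph contains "root", "nodes", "attrs" and the new "b64ndarrays" list;
    # each entry decodes to a SaveDLTensor record.
    raw = base64.b64decode(graph["b64ndarrays"][0])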
...@@ -4,7 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
...@@ -399,52 +399,9 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
CHECK(tensor.dtype.bits == dst->dtype.bits &&
tensor.dtype.code == dst->dtype.code &&
tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch";
for (int i = 0; i < tensor.ndim; ++i) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) {
num_elems *= dst->shape[i];
}
uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
NDArray temp;
temp.Load(strm);
temp.CopyTo(dst);
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
......
...@@ -13,8 +13,6 @@
namespace tvm {
namespace runtime {
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
......