Commit 72fa4c1d by Tianqi Chen Committed by GitHub

[NODE][REFLECTION] Support NDArray as field (#1452)

parent 6fbda22d
Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff
Subproject commit a5a80bdc8232c9dbfe508bb5c46e8f58cdf7ec20
......@@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include "./c_runtime_api.h"
#include "./serializer.h"
namespace tvm {
namespace runtime {
......@@ -103,8 +104,25 @@ class NDArray {
* \note The copy may happen asynchrously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other);
inline void CopyTo(const NDArray& other);
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
* \return Whether load is successful
*/
inline bool Load(dmlc::Stream* stream);
/*!
* \brief Save NDArray to stream
* \param stream The output data stream
*/
inline void Save(dmlc::Stream* stream) const;
/*!
* \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array.
......@@ -162,6 +180,13 @@ class NDArray {
};
/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
* \param tensor The tensor to be saved.
*/
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
*
* This object is DLTensor compatible:
......@@ -260,17 +285,26 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor* other) {
inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray& other) {
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
return ret;
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
......@@ -280,7 +314,106 @@ inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor);
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
// Always save data as CPU context
//
// Parameters that get serialized should be in CPU by default.
// So even the array's context is GPU, it will be stored as CPU array.
// This is used to prevent case when another user loads the parameters
// back on machine that do not have GPU or related context.
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
strm->Write(cpu_ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(TVMArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< TVMGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
}
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
inline void NDArray::Save(dmlc::Stream* strm) const {
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_
......@@ -10,6 +10,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
#include "./ndarray.h"
namespace dmlc {
namespace serializer {
......
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import ctypes
import tvm
from tvm._ffi.runtime_ctypes import TVMArrayHandle
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
......@@ -59,11 +57,5 @@ def load_param_dict(param_bytes):
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_mod = _load_param_dict(param_bytes)
size = load_mod(0)
param_dict = {}
for i in range(size):
key = load_mod(1, i)
dltensor_handle = ctypes.cast(load_mod(2, i), TVMArrayHandle)
param_dict[key] = tvm.nd.NDArray(dltensor_handle, False)
return param_dict
load_arr = _load_param_dict(param_bytes)
return {v.name : v.array for v in load_arr}
......@@ -4,10 +4,6 @@
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace nnvm {
......@@ -37,81 +33,6 @@ NNVM_REGISTER_OP(tvm_op)
return param.num_outputs;
});
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
tensor.dtype.code,
tensor.dtype.bits,
tensor.dtype.lanes,
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret;
}
TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -136,7 +57,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]);
tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
......@@ -149,11 +70,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string bytes = args[0];
std::vector<DLTensor*> data;
std::vector<std::string> names;
dmlc::MemoryStringStream memstrm(&bytes);
dmlc::Stream* strm = &memstrm;
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
......@@ -168,23 +87,19 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
tvm::Array<NDArrayWrapper> ret;
for (size_t i = 0; i < size; ++i) {
data.push_back(LoadDLTensor(strm));
tvm::runtime::NDArray temp;
temp.Load(strm);
std::shared_ptr<NDArrayWrapperNode> n
= std::make_shared<NDArrayWrapperNode>();
n->name = std::move(names[i]);
n->array = temp;
ret.push_back(NDArrayWrapper(n));
}
auto packed = [data, names](TVMArgs args, TVMRetValue* rv) {
int code = args[0];
if (code == 0) {
*rv = static_cast<int64_t>(data.size());
} else if (code == 1) {
int index = args[1];
*rv = names[index];
} else {
CHECK_EQ(code, 2);
int index = args[1];
*rv = static_cast<void*>(data[index]);
}
};
*rv = PackedFunc(packed);
*rv = ret;
});
TVM_REGISTER_NODE_TYPE(NDArrayWrapperNode);
} // namespace compiler
} // namespace nnvm
......@@ -7,14 +7,16 @@
#define NNVM_COMPILER_GRAPH_RUNTIME_H_
#include <nnvm/graph.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
......@@ -32,6 +34,26 @@ struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
}
};
/*!
* \brief wrapper node container for exchange.
*/
struct NDArrayWrapperNode : public ::tvm::Node {
std::string name;
tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("array", &array);
}
static constexpr const char* _type_key = "NDArrayWrapper";
TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node);
};
TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
......@@ -2,6 +2,9 @@ import os
import numpy as np
import nnvm.compiler
import tvm
import json
import base64
from tvm._ffi.base import py_str
from tvm import rpc
from tvm.contrib import util, graph_runtime
......@@ -20,6 +23,22 @@ def test_save_load():
np.testing.assert_equal(param2["y"].asnumpy(), y)
def test_ndarray_reflection():
x = np.random.uniform(size=(10, 2)).astype("float32")
xx = tvm.nd.array(x)
xnode = tvm.make.node("NDArrayWrapper", name="xx", array=xx)
xnode2 = tvm.make.node("NDArrayWrapper", name="x2", array=xx)
assert xnode.array.same_as(xx)
json_str = tvm.save_json([xnode, xnode2])
json_dict = json.loads(json_str)
b64_str = json_dict["b64ndarrays"][0]
decoded = py_str(base64.b64encode(base64.b64decode(b64_str)))
assert b64_str == decoded
xlist = tvm.load_json(json_str)
np.testing.assert_equal(xlist[0].array.asnumpy(), xx.asnumpy())
assert xlist[1].array == xlist[0].array
def test_bigendian_rpc_param():
"""Test big endian rpc when there is a PowerPC RPC server available"""
host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
......@@ -60,5 +79,6 @@ def test_bigendian_rpc_param():
if __name__ == "__main__":
test_ndarray_reflection()
test_save_load()
test_bigendian_rpc_param()
......@@ -204,6 +204,7 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......
......@@ -23,6 +23,8 @@ cdef inline object make_ret_node(void* chandle):
obj = cls(None)
else:
obj = NodeBase(None)
else:
obj = NodeBase(None)
(<NodeBase>obj).chandle = chandle
return obj
......
......@@ -134,6 +134,32 @@ class NDArrayBase(_NDArrayBase):
"""context of this array"""
return self.ctx
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Check object identity equality
Parameters
----------
other : object
The other object to compare to
Returns
-------
same : bool
Whether other is same as self.
"""
if not isinstance(other, NDArrayBase):
return False
return self.__hash__() == other.__hash__()
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
......
......@@ -32,7 +32,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
TVMRetValue* ret;
bool found_node_ref{false};
bool found_ref_object{false};
void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0];
......@@ -63,7 +63,13 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
if (skey == key) {
*ret = value[0];
found_node_ref = true;
found_ref_object = true;
}
}
void Visit(const char* key, runtime::NDArray* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};
......@@ -98,6 +104,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
};
class DSLAPIImpl : public DSLAPI {
......@@ -130,7 +139,7 @@ class DSLAPIImpl : public DSLAPI {
*ret_success = 1;
} else {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
*ret_success = getter.found_ref_object || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
......
/*!
* Copyright 2018 by Contributors
*
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
*/
#ifndef TVM_COMMON_BASE64_H_
#define TVM_COMMON_BASE64_H_
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <cctype>
#include <cstdio>
#include <string>
namespace tvm {
namespace common {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
// decoding table
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
// encoding table
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*!
* \brief Buffer reader from stream to avoid
* virtual call overhead on each read.
*/
class StreamBufferReader {
public:
explicit StreamBufferReader(size_t buffer_size) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
* \param stream The stream to be set
*/
void set_stream(dmlc::Stream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \return allows quick read using get char
*/
char GetChar() {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
/*! \return whether we are reaching the end of file */
bool AtEnd() const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
dmlc::Stream *stream_{nullptr};
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_{1};
/*! \brief pointer in the buffer */
size_t read_ptr_{1};
};
/*!
* \brief Input stream from base64 encoding
*/
class Base64InStream: public dmlc::Stream {
public:
explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
reader_.set_stream(fs);
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* \note call this function before actually start read
*/
void InitPosition(void) {
// get a character
do {
temp_ch_ = reader_.GetChar();
} while (isspace(temp_ch_));
}
/*! \brief whether current position is end of a base64 stream */
bool IsEOF(void) const {
return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_));
}
// override read function.
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev_ != 0) {
if (num_prev_ == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev_ = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev_ = 1;
}
} else {
// assert num_prev_ == 1
*cptr++ = buf_prev[0]; --tlen; num_prev_ = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) {
// first byte
nvalue = DecodeTable[temp_ch_] << 18;
{
// second byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
nvalue |= DecodeTable[temp_ch_] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
// handle termination
if (temp_ch_ == '=') {
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == '=') << "invalid base64 format";
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == EOF || isspace(temp_ch_))
<< "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_))
<< "invalid base64 format";
if (temp_ch_ == '=') {
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == EOF || isspace(temp_ch_))
<< "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev_ ++] = nvalue & 0xFF;
}
}
// get next char
temp_ch_ = reader_.GetChar();
}
if (kStrictCheck) {
CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
LOG(FATAL) << "Base64InStream do not support write";
}
private:
// internal reader
StreamBufferReader reader_;
int temp_ch_{0};
int num_prev_{0};
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*!
* \brief Stream to write to base64 format.
*/
class Base64OutStream: public dmlc::Stream {
public:
explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) {
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf__top_ < 3 && tlen != 0) {
buf_[++buf__top_] = *cptr++; --tlen;
}
if (buf__top_ == 3) {
// flush 4 bytes out
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]);
PutChar(EncodeTable[buf_[3] & 0x3F]);
buf__top_ = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
LOG(FATAL) << "Base64OutStream do not support read";
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch character to put to end of stream, if it is EOF, then nothing will be appended.
*/
void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf__top_ == 1) {
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]);
PutChar('=');
PutChar('=');
}
if (buf__top_ == 2) {
PutChar(EncodeTable[buf_[1] >> 2]);
PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]);
PutChar('=');
}
buf__top_ = 0;
if (endch != EOF) PutChar(endch);
this->Flush();
}
private:
static constexpr size_t kBufferSize = 256;
dmlc::Stream *fp_{nullptr};
int buf__top_{0};
unsigned char buf_[4];
std::string out_buf_;
void PutChar(char ch) {
out_buf_ += ch;
if (out_buf_.length() >= kBufferSize) Flush();
}
void Flush(void) {
if (out_buf_.length() != 0) {
fp_->Write(&out_buf_[0], out_buf_.length());
out_buf_.clear();
}
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_BASE64_H_
......@@ -7,8 +7,11 @@
#include <tvm/expr.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
#include "../common/base64.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
......@@ -23,6 +26,7 @@ inline std::string Type2String(const Type& t) {
return os.str();
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halideir_type_code_t code = Type::Int;
......@@ -52,6 +56,8 @@ class NodeIndexer : public AttrVisitor {
public:
std::unordered_map<Node*, size_t> node_index{{nullptr, 0}};
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
......@@ -64,7 +70,13 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
CHECK_EQ(tensor_index.size(), tensor_list.size());
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
......@@ -140,6 +152,7 @@ struct JSONNode {
class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
......@@ -170,6 +183,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
node_index_->at(value->node_.get()));
}
void Visit(const char* key, runtime::NDArray* value) final {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
......@@ -209,6 +226,7 @@ class JSONAttrGetter : public AttrVisitor {
class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<std::shared_ptr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
......@@ -254,10 +272,16 @@ class JSONAttrSetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, node_list_->size());
value->node_ = node_list_->at(index);
}
// Get the node
void Visit(const char* key, runtime::NDArray* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
if (node->is_type<ArrayNode>()) {
......@@ -292,6 +316,8 @@ struct JSONGraph {
size_t root;
// the nodes of the graph
std::vector<JSONNode> nodes;
// base64 b64ndarrays of arrays
std::vector<std::string> b64ndarrays;
// global attributes
AttrMap attrs;
......@@ -299,6 +325,7 @@ struct JSONGraph {
writer->BeginObject();
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
......@@ -310,6 +337,7 @@ struct JSONGraph {
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("root", &root);
helper.DeclareField("nodes", &nodes);
helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader);
}
......@@ -320,6 +348,7 @@ struct JSONGraph {
indexer.MakeIndex(root.node_.get());
JSONAttrGetter getter;
getter.node_index_ = &indexer.node_index;
getter.tensor_index_ = &indexer.tensor_index;
for (Node* n : indexer.node_list) {
JSONNode jnode;
getter.node_ = &jnode;
......@@ -328,6 +357,15 @@ struct JSONGraph {
}
g.attrs["tvm_version"] = TVM_VERSION;
g.root = indexer.node_index.at(root.node_.get());
// serialize tensor
for (DLTensor* tensor : indexer.tensor_list) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
common::Base64OutStream b64strm(&mstrm);
runtime::SaveDLTensor(&b64strm, tensor);
b64strm.Finish();
g.b64ndarrays.emplace_back(std::move(blob));
}
return g;
}
};
......@@ -347,6 +385,16 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
// load in json graph.
jgraph.Load(&reader);
std::vector<std::shared_ptr<Node> > nodes;
std::vector<runtime::NDArray> tensors;
// load in tensors
for (const std::string& blob : jgraph.b64ndarrays) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
common::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
runtime::NDArray temp;
CHECK(temp.Load(&b64strm));
tensors.emplace_back(temp);
}
// node 0 is always null
nodes.reserve(jgraph.nodes.size());
for (const JSONNode& jnode : jgraph.nodes) {
......@@ -362,6 +410,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
CHECK_EQ(nodes.size(), jgraph.nodes.size());
JSONAttrSetter setter;
setter.node_list_ = &nodes;
setter.tensor_list_ = &tensors;
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
......@@ -402,6 +451,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
*value = GetAttr(key).operator NodeRef();
}
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
......
......@@ -4,7 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
......@@ -399,52 +399,9 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
CHECK(tensor.dtype.bits == dst->dtype.bits &&
tensor.dtype.code == dst->dtype.code &&
tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch";
for (int i = 0; i < tensor.ndim; ++i) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) {
num_elems *= dst->shape[i];
}
uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
NDArray temp;
temp.Load(strm);
temp.CopyTo(dst);
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
......
......@@ -13,8 +13,6 @@
namespace tvm {
namespace runtime {
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment