Commit fd864c51 by Tianqi Chen Committed by GitHub

[RUNTIME] Fix graph runtime for gpu (#491)

parent c468558e
......@@ -72,6 +72,7 @@ class GraphModule(object):
self._set_input = module["set_input"]
self._run = module["run"]
self._get_output = module["get_output"]
self._load_params = module["load_params"]
self.ctx = ctx
def set_input(self, key=None, value=None, **params):
......@@ -120,6 +121,16 @@ class GraphModule(object):
self._get_output(index, out)
return out
def load_params(self, params_bytes):
    """Load parameters from serialized byte array of parameter dict.

    Parameters
    ----------
    params_bytes : bytearray
        The serialized parameter dict.
    """
    # Coerce to a bytearray before handing the blob to the module's
    # "load_params" function stored in self._load_params.
    self._load_params(bytearray(params_bytes))
def __getitem__(self, key):
"""Get internal module function
......
......@@ -299,7 +299,7 @@ class GraphRuntime : public ModuleNode {
}
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
}
bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
void LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
/*! \brief Setup the temporal storage */
void SetupStorage();
/*! \brief Setup the executors */
......@@ -353,7 +353,7 @@ class GraphRuntime : public ModuleNode {
};
bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
<< "Invalid DLTensor file format";
......@@ -362,30 +362,37 @@ bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx)))
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim)))
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype)))
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
<< "Invalid DLTensor file format";
int ndim = tensor->ndim;
CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim))
std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
<< "Invalid DLTensor file format";
int64_t size = 1;
int type_size = tensor->dtype.bits / 8;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
CHECK(tensor.dtype.bits == dst->dtype.bits &&
tensor.dtype.code == dst->dtype.code &&
tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch";
for (int i = 0; i < tensor.ndim; ++i) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t size = (bits + 7) / 8;
for (int i = 0; i < dst->ndim; ++i) {
size *= dst->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
CHECK(data_byte_size == size)
<< "Invalid DLTensor file format";
CHECK(strm->Read(tensor->data, type_size * size))
std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
return true;
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
......@@ -406,11 +413,11 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
uint32_t in_idx = GetInputIndex(names[i]);
CHECK(LoadDLTensor(strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)]))
<< "Invalid parameters file format";
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
CHECK_LT(eid, data_entry_.size());
LoadDLTensor(strm, &data_entry_[eid]);
}
}
......@@ -461,6 +468,7 @@ void GraphRuntime::SetupStorage() {
// Assign the pooled entries.
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
data_entry_[i] = *storage_pool_[storage_id];
data_entry_[i].shape = const_cast<int64_t*>(attrs_.shape[i].data());
data_entry_[i].ndim = static_cast<int>(attrs_.shape[i].size());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment