Commit 89d5e552 by Tianqi Chen Committed by GitHub

[RUNTIME] Update graph runtime to rely on smarter planner, add get_input (#990)

parent bcb3bef5
......@@ -72,6 +72,7 @@ class GraphModule(object):
self._set_input = module["set_input"]
self._run = module["run"]
self._get_output = module["get_output"]
self._get_input = module["get_input"]
try:
self._debug_get_output = module["debug_get_output"]
except AttributeError:
......@@ -111,6 +112,20 @@ class GraphModule(object):
self.set_input(**input_dict)
self._run()
def get_input(self, index, out):
"""Get index-th input to out
Parameters
----------
index : int
The input index
out : NDArray
The output array container
"""
self._get_input(index, out)
return out
def get_output(self, index, out):
"""Get index-th output to out
......
......@@ -102,6 +102,16 @@ class GraphRuntime : public ModuleNode {
TVM_CCALL(TVMArrayCopyFromTo(data_in, &data_entry_[eid], nullptr));
}
/*!
* \brief Copy index-th input to data_out
* \param index The input index.
* \param data_out The output
*/
void GetInput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
uint32_t eid = this->entry_id(input_nodes_[index], 0);
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
}
/*!
* \brief Copy index-th output to data_out.
* \param index The output index.
* \param data_out the output data.
......@@ -463,14 +473,6 @@ void GraphRuntime::SetupStorage() {
vtype.push_back(tvm::runtime::String2TVMType(s_type));
}
data_entry_.resize(num_node_entries());
// Find the maximum space size.
int max_id = 0;
for (size_t i = 0; i < attrs_.shape.size(); ++i) {
max_id = std::max(attrs_.storage_id[i] + 1, max_id);
}
for (uint32_t nid : input_nodes_) {
attrs_.storage_id[this->entry_id(nid, 0)] = max_id++;
}
// size of each storage pool entry
std::vector<size_t> pool_entry_bytes;
// Find the maximum space size.
......@@ -592,6 +594,14 @@ PackedFunc GraphRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->GetOutput(args[0], args[1]);
});
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->GetInput(this->GetInputIndex(args[0]), args[1]);
} else {
this->GetInput(args[0], args[1]);
}
});
#ifdef TVM_GRAPH_RUNTIME_DEBUG
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......
......@@ -204,7 +204,7 @@ class SpscTaskQueue {
cv_.notify_all();
}
private:
protected:
/*!
* \brief Lock-free enqueue.
* \param input The task to be enqueued.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment