Commit eb8077ff by Siva Committed by Tianqi Chen

[DEBUG] get_node_output : To retrieve out put of any node - for debug purpose. (#820)

parent ba8d00c2
......@@ -151,6 +151,10 @@ ifeq ($(USE_GRAPH_RUNTIME), 1)
RUNTIME_DEP += $(GRAPH_OBJ)
endif
ifeq ($(USE_GRAPH_RUNTIME_DEBUG), 1)
CFLAGS += -DTVM_GRAPH_RUNTIME_DEBUG
endif
include make/contrib/cblas.mk
include make/contrib/random.mk
include make/contrib/nnpack.mk
......
......@@ -50,6 +50,9 @@ USE_RPC = 1
# Whether enable tiny embedded graph runtime.
USE_GRAPH_RUNTIME = 1
# Whether enable additional graph debug functions
USE_GRAPH_RUNTIME_DEBUG = 0
# whether build with LLVM support
# Requires LLVM version >= 4.0
# Set LLVM_CONFIG to your version, uncomment to build with llvm support
......
......@@ -72,6 +72,10 @@ class GraphModule(object):
self._set_input = module["set_input"]
self._run = module["run"]
self._get_output = module["get_output"]
try:
self._debug_get_output = module["debug_get_output"]
except AttributeError:
pass
self._load_params = module["load_params"]
self.ctx = ctx
......@@ -121,6 +125,23 @@ class GraphModule(object):
self._get_output(index, out)
return out
def debug_get_output(self, node, out):
"""Run graph upto node and get the output to out
Parameters
----------
node : int / str
The node index or name
out : NDArray
The output array container
"""
if hasattr(self, '_debug_get_output'):
self._debug_get_output(node, out)
else:
raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
return out
def load_params(self, params_bytes):
"""Load parameters from serialized byte array of parameter dict.
......
......@@ -107,8 +107,45 @@ class GraphRuntime : public ModuleNode {
uint32_t eid = this->entry_id(outputs_[index]);
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
}
#ifdef TVM_GRAPH_RUNTIME_DEBUG
/*!
* \brief Get the node index given the name of node.
* \param name The name of the node.
* \return The index of node.
*/
int GetNodeIndex(const std::string& name) {
for (uint32_t nid = 0; nid< nodes_.size(); ++nid) {
if (nodes_[nid].name == name) {
return static_cast<int>(nid);
}
}
LOG(FATAL) << "cannot find " << name << " among nodex";
return -1;
}
/*!
* \brief Copy index-th node to data_out.
*
* This method will do a partial run of the the graph
* from begining upto the index-th node and return output of index-th node.
* This is costly operation and suggest to use only for debug porpose.
*
* \param index: The index of the node.
* \param data_out the node data.
*/
void DebugGetNodeOutput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), nodes_.size());
uint32_t eid = index;
for (size_t i = 0; i < op_execs_.size(); ++i) {
if (static_cast<int>(i) == index) break;
if (op_execs_[i]) op_execs_[i]();
}
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
}
#endif
/*!
* \brief Load parameters from binary stream
* \param strm The input stream.
*/
......@@ -556,6 +593,16 @@ PackedFunc GraphRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->GetOutput(args[0], args[1]);
});
#ifdef TVM_GRAPH_RUNTIME_DEBUG
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else {
this->DebugGetNodeOutput(args[0], args[1]);
}
});
#endif
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Run();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment