Commit d713d63d by Siju Committed by Tianqi Chen

[DEBUG]Support a debug framework for TVM Runtime (#1378)

parent 74ea8e5f
......@@ -159,6 +159,9 @@ if(USE_GRAPH_RUNTIME)
list(APPEND RUNTIME_SRCS ${RUNTIME_GRAPH_SRCS})
if(USE_GRAPH_RUNTIME_DEBUG)
message(STATUS "Build with Graph runtime debug support...")
file(GLOB RUNTIME_GRAPH_DEBUG_SRCS src/runtime/graph/debug/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_GRAPH_DEBUG_SRCS})
set_source_files_properties(${RUNTIME_GRAPH_SRCS}
PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG")
endif(USE_GRAPH_RUNTIME_DEBUG)
......
......@@ -97,6 +97,7 @@ stage('Build') {
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
......@@ -111,6 +112,7 @@ stage('Build') {
echo set\\(USE_OPENCL ON\\) >> config.cmake
echo set\\(USE_ROCM ON\\) >> config.cmake
echo set\\(USE_VULKAN ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER clang-6.0\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
"""
......@@ -127,6 +129,7 @@ stage('Build') {
cd build
cp ../cmake/config.cmake .
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
......@@ -150,6 +153,7 @@ stage('Build') {
cp ../cmake/config.cmake .
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_RPC ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_LLVM llvm-config-5.0\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
......
=================
**Debugger**
=================
TVM Debugger is an interface for debugging TVM's computation graph execution. It helps to provide access to graph structures and tensor values at the TVM runtime.
*******************************************
**Debug Exchange Format**
*******************************************
**1. Computational Graph**
==========================
The optimized graph build by nnvm in json
serialized format is dumped as it is. This contains the whole
information about the graph. The UX can either use this graph directly
or transform this graph to the format UX can understand.
The Graph JSON format is explained below
1. ``nodes``
Nodes are either placeholders or computational nodes in NNVM graph. The nodes are stored
as a list. A node contains the below information
- ``op`` - operation type, ``null`` means it is a placeholder/variable/input node and``tvm_op`` means this node can be executed
- ``name`` - Name of the node
- ``inputs`` - Position of the inputs for this operation, Inputs is a list of tuples with (nodeid, index, version). (Optional)
- ``attrs`` - Attributes of the node which contains the following information
- ``flatten_data`` - Whether this data need to be flattened before execution
- ``func_name`` - Fused function name, corresponds to the symbol in the lib generated by NNVM compilation process.
- ``num_inputs`` - Number of inputs for this node
- ``num_outputs`` - Number of outputs this node produces
2. ``arg_nodes``
arg_nodes is a list of indices of nodes which is placeholder/variable/input or constant/param to the graph.
3. ``heads``
heads is a list of entries as the output of the graph.
4. ``node_row_ptr``
node\_row\_ptr stores the history of forward path, so you can skip constructing the entire graph in inference tasks.
5. ``attrs``
attrs can contain version numbers or similar helpful information.
- ``storage_id`` - Memory slot id for each node in the storage layout.
- ``dtype`` - Datatype of each node (enum value).
- ``dltype`` - Datatype of each node in order.
- ``shape`` - Shape of each node k order.
- ``device_index`` - Device assignment for each entry in the graph.
Example of dumped graph:
::
{
"nodes": [ # List of nodes
{
"op": "null", # operation type = null, this is a placeholder/variable/input or constant/param node
"name": "x", # Name of the argument node
"inputs": [] # inputs for this node, its none since this is an argument node
},
{
"op": "tvm_op", # operation type = tvm_op, this node can be executed
"name": "relu0", # Name of the node
"attrs": { # Attributes of the node
"flatten_data": "0", # Whether this data need to be flattened
"func_name": "fuse_l2_normalize_relu", # Fused function name, corresponds to the symbol in the lib generated by NNVM compilation process
"num_inputs": "1", # Number of inputs for this node
"num_outputs": "1" # Number of outputs this node produces
},
"inputs": [[0, 0, 0]] # Position of the inputs for this operation
}
],
"arg_nodes": [0], # Which all nodes in this are argument nodes
"node_row_ptr": [0, 1, 2], # Row indices for faster depth first search
"heads": [[1, 0, 0]], # Position of the output nodes for this operation
"attrs": { # Attributes for the graph
"storage_id": ["list_int", [1, 0]], # memory slot id for each node in the storage layout
"dtype": ["list_int", [0, 0]], # Datatype of each node (enum value)
"dltype": ["list_str", [ # Datatype of each node in order
"float32",
"float32"]],
"shape": ["list_shape", [ # Shape of each node k order
[1, 3, 20, 20],
[1, 3, 20, 20]]],
"device_index": ["list_int", [1, 1]], # Device assignment for each node in order
}
}
**2. Tensor dumping**
=====================
The tensor received after execution is in ``tvm.ndarray`` type. All the tensors will
be saved as binary bytes in serialized format. The result binary bytes can be loaded by the
API "load_params".
Example of loading the parameters
::
with open(path_params, "rb") as fi:
loaded_params = bytearray(fi.read())
module.load_params(loaded_params)
***************************************
How to use Debugger?
***************************************
1. In ``config.cmake`` set the ``USE_GRAPH_RUNTIME_DEBUG`` flag to ``ON``
::
# Whether enable additional graph debug functions
set(USE_GRAPH_RUNTIME_DEBUG ON)
2. Do 'make' tvm, so that it will make the ``libtvm_runtime.so``
3. In frontend script file instead of
``from tvm.contrib import graph_runtime`` import the
``debug_runtime``
``from tvm.contrib.debugger import debug_runtime as graph_runtime``
::
from tvm.contrib.debugger import debug_runtime as graph_runtime
m = graph_runtime.create(graph, lib, ctx, dump_root="/tmp/tvmdbg")
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
# execute
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
The outputs are dumped to a temporary folder in ``/tmp`` folder or the
folder specified while creating the runtime.
***************************************
Sample Output
***************************************
The below is the output of running ``tvm/nnvm/tutorials/from_onnnx.py`` with debugger.
::
Node Name Ops Time(us) Time(%) Start Time End Time Shape Inputs Outputs
--------- --- -------- ------- ---------- -------- ----- ------ -------
1_NCHW1c fuse___layout_transform___4 56.52 0.02 15:24:44.177475 15:24:44.177534 (1, 1, 224, 224) 1 1
_contrib_conv2d_nchwc0 fuse__contrib_conv2d_NCHWc 12436.11 3.4 15:24:44.177549 15:24:44.189993 (1, 1, 224, 224, 1) 2 1
relu0_NCHW8c fuse___layout_transform___broadcast_add_relu___layout_transform__ 4375.43 1.2 15:24:44.190027 15:24:44.194410 (8, 1, 5, 5, 1, 8) 2 1
_contrib_conv2d_nchwc1 fuse__contrib_conv2d_NCHWc_1 213108.6 58.28 15:24:44.194440 15:24:44.407558 (1, 8, 224, 224, 8) 2 1
relu1_NCHW8c fuse___layout_transform___broadcast_add_relu___layout_transform__ 2265.57 0.62 15:24:44.407600 15:24:44.409874 (64, 1, 1) 2 1
_contrib_conv2d_nchwc2 fuse__contrib_conv2d_NCHWc_2 104623.15 28.61 15:24:44.409905 15:24:44.514535 (1, 8, 224, 224, 8) 2 1
relu2_NCHW2c fuse___layout_transform___broadcast_add_relu___layout_transform___1 2004.77 0.55 15:24:44.514567 15:24:44.516582 (8, 8, 3, 3, 8, 8) 2 1
_contrib_conv2d_nchwc3 fuse__contrib_conv2d_NCHWc_3 25218.4 6.9 15:24:44.516628 15:24:44.541856 (1, 8, 224, 224, 8) 2 1
reshape1 fuse___layout_transform___broadcast_add_reshape_transpose_reshape 1554.25 0.43 15:24:44.541893 15:24:44.543452 (64, 1, 1) 2 1
"""Graph debug results dumping class."""
import os
import json
import tvm
GRAPH_DUMP_FILE_NAME = '_tvmdbg_graph_dump.json'
class DebugResult(object):
"""Graph debug data module.
Data dump module manage all the debug data formatting.
Output data and input graphs are formatted and dumped to file.
Frontend read these data and graph for visualization.
Parameters
----------
graph_json : str
The graph to be deployed in json format output by nnvm graph. Each operator (tvm_op)
in the graph will have a one to one mapping with the symbol in libmod which is used
to construct a "PackedFunc" .
dump_path : str
Output data path is read/provided from frontend
"""
def __init__(self, graph_json, dump_path):
self._dump_path = dump_path
self._output_tensor_list = []
self._time_list = []
self._parse_graph(graph_json)
# dump the json information
self.dump_graph_json(graph_json)
def _parse_graph(self, graph_json):
"""Parse and extract the NNVM graph and update the nodes, shapes and dltype.
Parameters
----------
graph_json : str or graph class
The graph to be deployed in json format output by nnvm graph.
"""
json_obj = json.loads(graph_json)
self._nodes_list = json_obj['nodes']
self._shapes_list = json_obj['attrs']['shape']
self._dtype_list = json_obj['attrs']['dltype']
self._update_graph_json()
def _update_graph_json(self):
"""update the nodes_list with name, shape and data type,
for temporarily storing the output.
"""
nodes_len = len(self._nodes_list)
for i in range(nodes_len):
node = self._nodes_list[i]
input_list = []
for input_node in node['inputs']:
input_list.append(self._nodes_list[input_node[0]]['name'])
node['inputs'] = input_list
dtype = str("type: " + self._dtype_list[1][i])
if 'attrs' not in node:
node['attrs'] = {}
node['op'] = "param"
else:
node['op'] = node['attrs']['func_name']
node['attrs'].update({"T": dtype})
node['shape'] = self._shapes_list[1][i]
def _cleanup_tensors(self):
"""Remove the tensor dump file (graph wont be removed)
"""
for filename in os.listdir(self._dump_path):
if os.path.isfile(filename) and not filename.endswith(".json"):
os.remove(filename)
def get_graph_nodes(self):
"""Return the nodes list
"""
return self._nodes_list
def get_graph_node_shapes(self):
"""Return the nodes shapes list
"""
return self._shapes_list
def get_graph_node_output_num(self, node):
"""Return the number of outputs of a node
"""
return 1 if node['op'] == 'param' else int(node['attrs']['num_outputs'])
def get_graph_node_dtypes(self):
"""Return the nodes dtype list
"""
return self._dtype_list
def dump_output_tensor(self):
"""Dump the outputs to a temporary folder, the tensors are in numpy format
"""
#cleanup existing tensors before dumping
self._cleanup_tensors()
eid = 0
order = 0
output_tensors = {}
for node, time in zip(self._nodes_list, self._time_list):
num_outputs = self.get_graph_node_output_num(node)
for j in range(num_outputs):
order += time[0]
key = node['name'] + "_" + str(j) + "__" + str(order)
output_tensors[key] = self._output_tensor_list[eid]
eid += 1
with open(os.path.join(self._dump_path, "output_tensors.params"), "wb") as param_f:
param_f.write(save_tensors(output_tensors))
def dump_graph_json(self, graph):
"""Dump json formatted graph.
Parameters
----------
graph : json format
json formatted NNVM graph contain list of each node's
name, shape and type.
"""
graph_dump_file_name = GRAPH_DUMP_FILE_NAME
with open(os.path.join(self._dump_path, graph_dump_file_name), 'w') as outfile:
json.dump(graph, outfile, indent=4, sort_keys=False)
def display_debug_result(self):
"""Displays the debugger result"
"""
header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Start Time", \
"End Time", "Shape", "Inputs", "Outputs"]
lines = ["---------", "---", "--------", "-------", "----------", \
"--------", "-----", "------", "-------"]
eid = 0
data = []
total_time = sum(time[0] for time in self._time_list)
for node, time in zip(self._nodes_list, self._time_list):
num_outputs = self.get_graph_node_output_num(node)
for j in range(num_outputs):
op = node['op']
if node['op'] == 'param':
continue
name = node['name']
shape = str(self._output_tensor_list[eid].shape)
time_us = round(time[0] * 1000000, 2)
time_percent = round(((time[0] / total_time) * 100), 2)
inputs = str(node['attrs']['num_inputs'])
outputs = str(node['attrs']['num_outputs'])
node_data = [name, op, time_us, time_percent, str(time[1]), str(time[2]), \
shape, inputs, outputs]
data.append(node_data)
eid += 1
fmt = ""
for i, _ in enumerate(header):
max_len = len(header[i])
for j, _ in enumerate(data):
item_len = len(str(data[j][i]))
if item_len > max_len:
max_len = item_len
fmt = fmt + "{:<" + str(max_len + 2) + "}"
print(fmt.format(*header))
print(fmt.format(*lines))
for row in data:
print(fmt.format(*row))
def save_tensors(params):
"""Save parameter dictionary to binary bytes.
The result binary bytes can be loaded by the
GraphModule with API "load_params".
Parameters
----------
params : dict of str to NDArray
The parameter dictionary.
Returns
-------
param_bytes: bytearray
Serialized parameters.
"""
_save_tensors = tvm.get_global_func("_save_param_dict")
args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_tensors(*args)
"""Graph debug runtime executes TVM debug packed functions."""
import os
import tempfile
import shutil
from datetime import datetime
from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm._ffi.function import get_global_func
from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_"
_DUMP_PATH_PREFIX = "_tvmdbg_"
def create(graph_json_str, libmod, ctx, dump_root=None):
"""Create a runtime executor module given a graph and module.
Parameters
----------
graph_json_str : str or graph class
The graph to be deployed in json format output by nnvm graph.
The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod.
libmod : tvm.Module
The module of the corresponding function.
ctx : TVMContext
The context to deploy the module, can be local or remote.
dump_root : str
To select which folder the outputs should be kept.
None will make a temp folder in /tmp/tvmdbg<rand_string> and does the dumping
Returns
-------
graph_module : GraphModuleDebug
Debug Runtime graph module that can be used to execute the graph.
"""
if not isinstance(graph_json_str, string_types):
try:
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
try:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
except ValueError:
raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \
"config.cmake and rebuild TVM to enable debug mode")
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
raise NotSupportedError("Remote graph debugging is not supported.")
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
class GraphModuleDebug(graph_runtime.GraphModule):
"""Graph debug runtime module.
This is a debug wrapper over the TVM runtime.
Runtime interfaces are wrapped with debug functionalities.
Manage the debug framework to format the debug data and
trigger the user interfaces.
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under.
graph_json_str : str or graph class
Content of graph json file in string format
dump_root : str
To select which folder the outputs should be kept.
None will make a temp folder in /tmp/tvmdbg<rand_string> and does the dumping
"""
def __init__(self, module, ctx, graph_json_str, dump_root):
self._dump_root = dump_root
self._dump_path = None
self._debug_run = module["debug_run"]
self._get_output_by_layer = module["get_output_by_layer"]
graph_runtime.GraphModule.__init__(self, module)
self._create_debug_env(graph_json_str, ctx)
def _format_context(self, ctx):
return str(ctx[0]).upper().replace("(", ":").replace(")", "")
def _ensure_dir(self, directory):
"""Create a directory if not exists
Parameters
----------
directory : str
File path to create
"""
if not os.path.exists(directory):
os.makedirs(directory, 0o700)
def _get_dump_path(self, ctx):
"""Make the graph and tensor dump folder and return the path.
Parameters
----------
ctx : TVMContext
The context this module is under.
Returns
-------
path : str
Directory path where the graph and node outputs will be stored.
"""
# save to file
folder_name = _DUMP_PATH_PREFIX + "ctx_"
folder_name = folder_name + ctx.replace(":", "_")
path = os.path.join(self._dump_root, folder_name)
self._ensure_dir(path)
return path
def _remove_dump_root(self):
if os.path.isdir(self._dump_root):
shutil.rmtree(self._dump_root)
def _create_debug_env(self, graph_json, ctx):
"""Create UI wrapper framework to handle multiple UI frontends for tvmdbg
Parameters
----------
graph_json : json format
json formatted NNVM graph contain list of each node's name, shape and type.
nodes_list : list
List of all the nodes presented in the graph
ctx : TVMContext
The context this module is under.
"""
# make the dump folder if not given
if not self._dump_root:
self._dump_root = tempfile.mktemp(prefix=_DUMP_ROOT_PREFIX)
# format the context
ctx = self._format_context(ctx)
# updates the dumping directories
self._dump_path = self._get_dump_path(ctx)
# init the debug dumping environment
self.debug_datum = debug_result.DebugResult(graph_json, self._dump_path)
def _run_debug(self):
"""Execute the node spcified with index will be executed.
Each debug output will be copied to the buffer
Time consumed for each execuion will be set as debug output.
"""
for i, node in enumerate(self.debug_datum.get_graph_nodes()):
start_time = datetime.now().time()
time_stamp = self._debug_run(i)
end_time = datetime.now().time()
self.debug_datum._time_list.append([time_stamp, start_time, end_time])
num_outputs = self.debug_datum.get_graph_node_output_num(node)
for j in range(num_outputs):
out_tensor = self._get_output_by_layer(i, j)
self.debug_datum._output_tensor_list.append(out_tensor)
def run(self, **input_dict):
"""Run forward execution of the graph with debug
Parameters
----------
input_dict : dict of str to NDArray
List of input values to be feed to
"""
if input_dict:
self.set_input(**input_dict)
# Step 1. Execute the graph
self._run_debug()
# Step 2. Dump the output tensors to the dump folder
self.debug_datum.dump_output_tensor()
# Step 3. Display the collected information
self.debug_datum.display_debug_result()
def exit(self):
"""Exits the dump folder and all its contents"""
self._remove_dump_root()
......@@ -31,6 +31,31 @@ def create(graph_json_str, libmod, ctx):
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
def get_device_ctx(libmod, ctx):
"""Parse and validate all the device context(s).
Parameters
----------
libmod : tvm.Module
The module of the corresponding function
ctx : TVMContext or list of TVMContext
Returns
-------
ctx : list of TVMContext
num_rpc_ctx : Number of rpc contexts
device_type_id : List of device type and device id
"""
if isinstance(ctx, TVMContext):
ctx = [ctx]
elif not isinstance(ctx, (list, tuple)):
......@@ -59,14 +84,7 @@ def create(graph_json_str, libmod, ctx):
if 0 < num_rpc_ctx < len(ctx):
raise ValueError("Either all or none of the contexts should be rpc.")
if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
return ctx, num_rpc_ctx, device_type_id
class GraphModule(object):
......
......@@ -3,6 +3,7 @@
* Implementation of basic API functions
* \file api_base.cc
*/
#include <dmlc/memory_io.h>
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
......@@ -33,4 +34,37 @@ TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
});
TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
constexpr uint64_t TVMNDArrayListMagic = 0xF7E58D4F05049CB7;
size_t num_params = args.size() / 2;
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (size_t i = 0; i < num_params * 2; i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i + 1].operator DLTensor*());
}
std::string bytes;
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = TVMNDArrayListMagic, reserved = 0;
fo->Write(header);
fo->Write(reserved);
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
arr.data = bytes.c_str();
arr.size = bytes.length();
*rv = arr;
});
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file graph_runtime_debug.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <chrono>
#include "../graph_runtime.h"
namespace tvm {
namespace runtime {
/*!
* \brief Graph runtime with debug .
*
* This is the extension of GraphRuntime class used for debugging
* TVM runtime PackedFunc API.
*/
class GraphRuntimeDebug : public GraphRuntime {
public:
/*!
* \brief Run each operation and get the output.
* \param index The index of op which needs to be run.
*/
double DebugRun(size_t index) {
CHECK(index < op_execs().size());
TVMContext ctx = data_entry()[GetEntryId(index, 0)].operator->()->ctx;
auto tbegin = std::chrono::high_resolution_clock::now();
if (op_execs()[index]) {
op_execs()[index]();
}
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto tend = std::chrono::high_resolution_clock::now();
double time = std::chrono::duration_cast<std::chrono::duration<double> >(
tend - tbegin).count();
return time;
}
/*!
* \brief Run each operation and get the output.
* \param index The index of op which needs to be returned.
* \param eid The Entry id of the op.
*/
NDArray GetOutputByLayer(int index, int eid) {
return data_entry()[GetEntryId(index, eid)];
}
/*!
* \brief GetFunction Get the function based on input.
* \param name The function which needs to be invoked.
* \param sptr_to_self Packed function pointer.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
/*!
* \brief Get the node index given the name of node.
* \param name The name of the node.
* \return The index of node.
*/
int GetNodeIndex(const std::string& name) const {
for (size_t nid = 0; nid < GetNumOfNodes(); ++nid) {
if (GetNodeName(nid) == name) {
return static_cast<int>(nid);
}
}
LOG(FATAL) << "cannot find " << name << " among nodex";
return -1;
}
/*!
* \brief Copy index-th node to data_out.
*
* This method will do a partial run of the the graph
* from begining upto the index-th node and return output of index-th node.
* This is costly operation and suggest to use only for debug porpose.
*
* \param index: The index of the node.
* \param data_out the node data.
*/
void DebugGetNodeOutput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), op_execs().size());
uint32_t eid = index;
for (size_t i = 0; i < op_execs().size(); ++i) {
if (op_execs()[i]) op_execs()[i]();
if (static_cast<int>(i) == index) break;
}
data_entry()[eid].CopyTo(data_out);
}
};
/*!
* \brief GetFunction Get the function based on input.
* \param name The function which needs to be invoked.
* \param sptr_to_self Packed function pointer.
*/
PackedFunc GraphRuntimeDebug::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
// return member functions during query.
if (name == "debug_run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->DebugRun(args[0]);
});
} else if (name == "get_output_by_layer") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetOutputByLayer(args[0], args[1]);
});
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else {
this->DebugGetNodeOutput(args[0], args[1]);
}
});
} else {
return GraphRuntime::GetFunction(name, sptr_to_self);
}
}
/*!
* \brief GraphRuntimeDebugCreate Get the function based on input.
* \param sym_json The graph symbol in json format.
* \param m Compiled module which will be loaded.
* \param ctxs All devices contexts.
*/
Module GraphRuntimeDebugCreate(const std::string& sym_json,
const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntimeDebug> exec = std::make_shared<GraphRuntimeDebug>();
exec->Init(sym_json, m, ctxs);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4)
<< "The expected number of arguments for graph_runtime.create is "
"at least 4, but it has "
<< args.num_args;
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
});
} // namespace runtime
} // namespace tvm
import os
import tvm
import numpy as np
import json
from tvm.contrib.debugger import debug_runtime as graph_runtime
def test_graph_simple():
n = 4
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
node0 = {"op": "null", "name": "x", "inputs": []}
node1 = {"op": "tvm_op", "name": "add",
"inputs": [[0, 0, 0]],
"attrs": {"func_name": "myadd",
"flatten_data": "1",
"num_inputs" : "1",
"num_outputs" : "1"}}
nodes = [node0, node1]
arg_nodes = [0]
node_row_ptr = [0, 1, 2]
outputs = [[1, 0, 0]]
shape = (4,)
attrs = {
"shape" : ["list_shape", [shape, shape]],
"dltype" : ["list_str", ["float32", "float32"]],
"storage_id" : ["list_int", [0, 1]],
}
graph = {"nodes": nodes,
"arg_nodes": arg_nodes,
"node_row_ptr": node_row_ptr,
"heads": outputs,
"attrs": attrs}
graph = json.dumps(graph)
def check_verify():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
try:
mod = graph_runtime.create(graph, mlib, tvm.cpu(0))
except ValueError:
return
a = np.random.uniform(size=(n,)).astype(A.dtype)
mod.set_input(x=a)
#verify dumproot created
directory = mod._dump_path
assert(os.path.exists(directory))
#verify graph is there
GRAPH_DUMP_FILE_NAME = '_tvmdbg_graph_dump.json'
assert(len(os.listdir(directory)) == 1)
#verify the file name is proper
assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME)))
mod.run()
#Verify the tensors are dumped
assert(len(os.listdir(directory)) > 1)
#verify the output is correct
out = mod.get_output(0, tvm.nd.empty((n,)))
np.testing.assert_equal(out.asnumpy(), a + 1)
mod.exit()
#verify dump root delete after cleanup
assert(not os.path.exists(directory))
check_verify()
if __name__ == "__main__":
test_graph_simple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment