Commit f7278101 by Neo Chien Committed by Zhi

[GraphRuntime] Support parameter out in the graph runtime debug (#4598)

* [GraphRuntime] Support parameter out in the graph runtime debug

* Dummy commit to trigger build
parent 227c7af4
......@@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under.
......@@ -188,7 +188,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
out_tensor = array(out_tensor)
self.debug_datum._output_tensor_list.append(out_tensor)
def debug_get_output(self, node, out):
def debug_get_output(self, node, out=None):
"""Run graph up to node and get the output to out
Parameters
......@@ -199,12 +199,11 @@ class GraphModuleDebug(graph_runtime.GraphModule):
out : NDArray
The output array container
"""
ret = None
if isinstance(node, str):
output_tensors = self.debug_datum.get_output_tensors()
try:
ret = output_tensors[node]
except:
out = output_tensors[node]
except KeyError:
node_list = output_tensors.keys()
raise RuntimeError(
"Node "
......@@ -215,10 +214,10 @@ class GraphModuleDebug(graph_runtime.GraphModule):
)
elif isinstance(node, int):
output_tensors = self.debug_datum._output_tensor_list
ret = output_tensors[node]
out = output_tensors[node]
else:
raise RuntimeError("Require node index or name only.")
return ret
return out
def run(self, **input_dict):
"""Run forward execution of the graph with debug
......@@ -244,7 +243,6 @@ class GraphModuleDebug(graph_runtime.GraphModule):
ret = self._run_individual(number, repeat, min_repeat_ms)
return ret.strip(",").split(",") if ret else []
def exit(self):
"""Exits the dump folder and all its contents"""
self._remove_dump_root()
......@@ -22,6 +22,7 @@ from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
......@@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx):
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
def get_device_ctx(libmod, ctx):
"""Parse and validate all the device context(s).
Parameters
......@@ -112,12 +114,12 @@ class GraphModule(object):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
Attributes
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
"""
def __init__(self, module):
......@@ -142,7 +144,7 @@ class GraphModule(object):
The input key
params : dict of str to NDArray
Additonal arguments
Additional arguments
"""
if key is not None:
self._get_input(key).copyfrom(value)
......@@ -211,7 +213,7 @@ class GraphModule(object):
return self._get_output(index)
def debug_get_output(self, node, out):
"""Run graph upto node and get the output to out
"""Run graph up to node and get the output to out
Parameters
----------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment