Commit f7278101 by Neo Chien Committed by Zhi

[GraphRuntime] Support parameter out in the graph runtime debug (#4598)

* [GraphRuntime] Support parameter out in the graph runtime debug

* Dummy commit to trigger build
parent 227c7af4
...@@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
Parameters Parameters
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual graph functions. The internal tvm module that holds the actual graph functions.
ctx : TVMContext ctx : TVMContext
The context this module is under. The context this module is under.
...@@ -188,7 +188,7 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -188,7 +188,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
out_tensor = array(out_tensor) out_tensor = array(out_tensor)
self.debug_datum._output_tensor_list.append(out_tensor) self.debug_datum._output_tensor_list.append(out_tensor)
def debug_get_output(self, node, out): def debug_get_output(self, node, out=None):
"""Run graph up to node and get the output to out """Run graph up to node and get the output to out
Parameters Parameters
...@@ -199,12 +199,11 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -199,12 +199,11 @@ class GraphModuleDebug(graph_runtime.GraphModule):
out : NDArray out : NDArray
The output array container The output array container
""" """
ret = None
if isinstance(node, str): if isinstance(node, str):
output_tensors = self.debug_datum.get_output_tensors() output_tensors = self.debug_datum.get_output_tensors()
try: try:
ret = output_tensors[node] out = output_tensors[node]
except: except KeyError:
node_list = output_tensors.keys() node_list = output_tensors.keys()
raise RuntimeError( raise RuntimeError(
"Node " "Node "
...@@ -215,10 +214,10 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -215,10 +214,10 @@ class GraphModuleDebug(graph_runtime.GraphModule):
) )
elif isinstance(node, int): elif isinstance(node, int):
output_tensors = self.debug_datum._output_tensor_list output_tensors = self.debug_datum._output_tensor_list
ret = output_tensors[node] out = output_tensors[node]
else: else:
raise RuntimeError("Require node index or name only.") raise RuntimeError("Require node index or name only.")
return ret return out
def run(self, **input_dict): def run(self, **input_dict):
"""Run forward execution of the graph with debug """Run forward execution of the graph with debug
...@@ -244,7 +243,6 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -244,7 +243,6 @@ class GraphModuleDebug(graph_runtime.GraphModule):
ret = self._run_individual(number, repeat, min_repeat_ms) ret = self._run_individual(number, repeat, min_repeat_ms)
return ret.strip(",").split(",") if ret else [] return ret.strip(",").split(",") if ret else []
def exit(self): def exit(self):
"""Exits the dump folder and all its contents""" """Exits the dump folder and all its contents"""
self._remove_dump_root() self._remove_dump_root()
...@@ -22,6 +22,7 @@ from .._ffi.function import get_global_func ...@@ -22,6 +22,7 @@ from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base from ..rpc import base as rpc_base
def create(graph_json_str, libmod, ctx): def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module. """Create a runtime executor module given a graph and module.
Parameters Parameters
...@@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx): ...@@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx):
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
def get_device_ctx(libmod, ctx): def get_device_ctx(libmod, ctx):
"""Parse and validate all the device context(s). """Parse and validate all the device context(s).
Parameters Parameters
...@@ -112,12 +114,12 @@ class GraphModule(object): ...@@ -112,12 +114,12 @@ class GraphModule(object):
Parameters Parameters
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual graph functions. The internal tvm module that holds the actual graph functions.
Attributes Attributes
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual graph functions. The internal tvm module that holds the actual graph functions.
""" """
def __init__(self, module): def __init__(self, module):
...@@ -142,7 +144,7 @@ class GraphModule(object): ...@@ -142,7 +144,7 @@ class GraphModule(object):
The input key The input key
params : dict of str to NDArray params : dict of str to NDArray
Additonal arguments Additional arguments
""" """
if key is not None: if key is not None:
self._get_input(key).copyfrom(value) self._get_input(key).copyfrom(value)
...@@ -211,7 +213,7 @@ class GraphModule(object): ...@@ -211,7 +213,7 @@ class GraphModule(object):
return self._get_output(index) return self._get_output(index)
def debug_get_output(self, node, out): def debug_get_output(self, node, out):
"""Run graph upto node and get the output to out """Run graph up to node and get the output to out
Parameters Parameters
---------- ----------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment