Commit 37414470 by Siva Committed by Tianqi Chen

[GRAPH] Include default metadata description in graph. (#2770)

parent df722397
...@@ -20,7 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system. ...@@ -20,7 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import from __future__ import absolute_import
import json import json
from collections import defaultdict from collections import defaultdict, OrderedDict
import attr import attr
from . import _backend from . import _backend
from . import compile_engine from . import compile_engine
...@@ -348,13 +348,31 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -348,13 +348,31 @@ class GraphRuntimeCodegen(ExprFunctor):
attrs["device_index"] = ["list_int", device_types] attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes] attrs["dltype"] = ["list_str", dltypes]
json_dict = { # Metadata definitions
def nested_defaultdict():
return defaultdict(nested_defaultdict)
metadata = nested_defaultdict()
for node_id in arg_nodes:
node_name = nodes[node_id]['name']
if node_name not in self.params:
metadata['signatures']['default']['inputs'][node_name]['id'] = node_id
metadata['signatures']['default']['inputs'][node_name]['dtype'] = dltypes[node_id]
metadata['signatures']['default']['inputs'][node_name]['shape'] = shapes[node_id]
for node_id in heads:
node_name = nodes[node_id[0]]['name']
metadata['signatures']['default']['outputs'][node_name]['id'] = node_id[0]
metadata['signatures']['default']['outputs'][node_name]['dtype'] = dltypes[node_id[0]]
metadata['signatures']['default']['outputs'][node_name]['shape'] = shapes[node_id[0]]
# Keep 'metadata' always at end
json_dict = OrderedDict({
"nodes": nodes, "nodes": nodes,
"arg_nodes": arg_nodes, "arg_nodes": arg_nodes,
"heads": heads, "heads": heads,
"attrs": attrs, "attrs": attrs,
"node_row_ptr": node_row_ptr "node_row_ptr": node_row_ptr,
} "metadata": metadata
})
return json.dumps(json_dict, indent=2) return json.dumps(json_dict, indent=2)
......
...@@ -318,6 +318,8 @@ class GraphRuntime : public ModuleNode { ...@@ -318,6 +318,8 @@ class GraphRuntime : public ModuleNode {
} else if (key == "attrs") { } else if (key == "attrs") {
reader->Read(&attrs_); reader->Read(&attrs_);
bitmask |= 16; bitmask |= 16;
} else if (key == "metadata") {
break;
} else { } else {
LOG(FATAL) << "key " << key << " is not supported"; LOG(FATAL) << "key " << key << " is not supported";
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment