Commit 37414470 by Siva Committed by Tianqi Chen

[GRAPH] Include default metadata description in graph. (#2770)

parent df722397
......@@ -20,7 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import
import json
from collections import defaultdict
from collections import defaultdict, OrderedDict
import attr
from . import _backend
from . import compile_engine
......@@ -348,13 +348,31 @@ class GraphRuntimeCodegen(ExprFunctor):
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes]
json_dict = {
# Metadata definitions
def nested_defaultdict():
return defaultdict(nested_defaultdict)
metadata = nested_defaultdict()
for node_id in arg_nodes:
node_name = nodes[node_id]['name']
if node_name not in self.params:
metadata['signatures']['default']['inputs'][node_name]['id'] = node_id
metadata['signatures']['default']['inputs'][node_name]['dtype'] = dltypes[node_id]
metadata['signatures']['default']['inputs'][node_name]['shape'] = shapes[node_id]
for node_id in heads:
node_name = nodes[node_id[0]]['name']
metadata['signatures']['default']['outputs'][node_name]['id'] = node_id[0]
metadata['signatures']['default']['outputs'][node_name]['dtype'] = dltypes[node_id[0]]
metadata['signatures']['default']['outputs'][node_name]['shape'] = shapes[node_id[0]]
# Keep 'metadata' always at end
json_dict = OrderedDict({
"nodes": nodes,
"arg_nodes": arg_nodes,
"heads": heads,
"attrs": attrs,
"node_row_ptr": node_row_ptr
}
"node_row_ptr": node_row_ptr,
"metadata": metadata
})
return json.dumps(json_dict, indent=2)
......
......@@ -318,6 +318,8 @@ class GraphRuntime : public ModuleNode {
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else if (key == "metadata") {
break;
} else {
LOG(FATAL) << "key " << key << " is not supported";
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment