Commit a1f59908 by Tianqi Chen

also save graph attributes (#78)

parent 1b1590b5
......@@ -199,6 +199,7 @@ Graph LoadJSON(Graph src) {
// save a graph to json
Graph SaveJSON(Graph src) {
JSONGraph jgraph;
jgraph.attrs = src.attrs;
std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
......
......@@ -11,6 +11,20 @@ def test_json_pass():
g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
def test_json_pass_with_attr():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
g._set_json_attr('version', '0.1.0')
ret = g.apply('SaveJSON')
json_str = ret.json_attr('json')
print(json_str)
ret._set_json_attr('json', json_str)
g2 = ret.apply('LoadJSON')
assert g2.json_attr('version') == '0.1.0'
def test_graph_json_attr():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
......@@ -129,6 +143,7 @@ def test_plan_memory():
if __name__ == "__main__":
test_json_pass_with_attr()
test_order_mutation_pass()
test_graph_json_attr()
test_json_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment