Commit 7182201d by lixiaoquan Committed by Tianqi Chen

Fix a bug in nnvm to relay converter. (#2756)

parent cc112c10
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument
"""Convert an NNVM graph to Relay."""
import json
import numpy
from tvm import relay, nd
......@@ -241,7 +240,7 @@ def _split(children, attrs, odtype='float32'):
axis = attrs.get_int('axis', 0)
return op.split(children[0], indices_or_sections, axis).astuple()
return op.split(children[0], indices_or_sections, axis)
def _squeeze(children, attrs, odtype='float32'):
axis = attrs.get_int_tuple('axis', None)
......@@ -441,12 +440,10 @@ def to_relay(graph, shape_dict, dtype_dict, params):
graph = graph.apply(["InferShape", "InferType"])
shape = graph.json_attr("shape")
dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")]
heads = [x[0] for x in json.loads(graph.json())['heads']]
gidx = graph.index
relay_map = {}
fn_params = []
output_ids = []
for nid, node in enumerate(gidx.nodes):
children = []
......@@ -468,9 +465,6 @@ def to_relay(graph, shape_dict, dtype_dict, params):
fn_params.append(v)
relay_map[nid] = v
else:
if nid in heads:
output_ids.append(nid)
if op_name in NNVM_OP_2_RELAY_OP:
str_attrs = StrAttrsDict(attrs)
call = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype)
......@@ -479,7 +473,14 @@ def to_relay(graph, shape_dict, dtype_dict, params):
raise Exception(
"nnvm.to_relay: unsupported operator: {0}".format(op_name))
outputs = [relay_map[nid] for nid in output_ids]
outputs = []
for nid, idx, _ in gidx.output_entries:
output = relay_map[nid]
if isinstance(output, expr.TupleWrapper):
outputs.append(output[idx])
else:
outputs.append(output)
if len(outputs) == 1:
body = outputs[0]
else:
......
......@@ -72,6 +72,23 @@ def test_forward_dqn():
verify_nnvm_to_relay(model, params, data_shape=(1, 4, 84, 84))
def test_forward_split_concatenate():
shape = (2, 16)
tensor = nnvm.sym.Variable("data", shape=shape)
splited = nnvm.sym.split(tensor, indices_or_sections=2, axis=1)
concatenated = nnvm.sym.concatenate(*splited, axis=1)
params = {}
verify_nnvm_to_relay(splited[0], params, data_shape=shape)
verify_nnvm_to_relay(splited[1], params, data_shape=shape)
verify_nnvm_to_relay(splited, params, data_shape=shape)
verify_nnvm_to_relay(concatenated, params, data_shape=shape)
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -80,3 +97,4 @@ if __name__ == '__main__':
test_forward_inception_v3()
test_forward_densenet()
test_forward_dqn()
test_forward_split_concatenate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment