Commit e545c9a6 by Takeo Imai Committed by Tianqi Chen

Change as graph.input can be either ValueInfoProto or string (#186) (#192)

* graph.input can be either ValueInfoProto or string

* pylint
parent e7bbb32c
......@@ -245,18 +245,24 @@ class GraphProto(object):
raise ValueError("Tensor's name is required.")
self._params[init_tensor.name] = self._parse_array(init_tensor)
for i in graph.input:
if i in self._params:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
try:
i_name = i.name
except AttributeError:
i_name = i
if i_name in self._params:
# i is a param instead of input
name_param = 'param_{}'.format(self._num_param)
self._num_param += 1
self._params[name_param] = self._params.pop(i)
self._params[name_param] = self._params.pop(i_name)
self._nodes[name_param] = _sym.Variable(name=name_param)
self._renames[i] = name_param
self._renames[i_name] = name_param
else:
name_input = 'input_{}'.format(self._num_input)
self._num_input += 1
self._nodes[name_input] = _sym.Variable(name=name_input)
self._renames[i] = name_input
self._renames[i_name] = name_input
# construct nodes, nodes are stored as directed acyclic graph
for idx, node in enumerate(graph.node):
op_name = node.op_type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment