Commit 71ffc2b6 by Tang, Cheng Committed by Tianqi Chen

don't rename onnx input tensors (#491)

parent 9f6e0c59
...@@ -558,17 +558,13 @@ class GraphProto(object): ...@@ -558,17 +558,13 @@ class GraphProto(object):
i_name = self._parse_value_proto(i) i_name = self._parse_value_proto(i)
if i_name in self._params: if i_name in self._params:
# i is a param instead of input # i is a param instead of input
name_param = 'param_{}'.format(self._num_param)
self._num_param += 1 self._num_param += 1
self._params[name_param] = self._params.pop(i_name) self._params[i_name] = self._params.pop(i_name)
self._nodes[name_param] = _sym.Variable( self._nodes[i_name] = _sym.Variable(
name=name_param, shape=self._params[name_param].shape) name=i_name, shape=self._params[i_name].shape)
self._renames[i_name] = name_param
else: else:
name_input = 'input_{}'.format(self._num_input)
self._num_input += 1 self._num_input += 1
self._nodes[name_input] = _sym.Variable(name=name_input) self._nodes[i_name] = _sym.Variable(name=i_name)
self._renames[i_name] = name_input
# construct nodes, nodes are stored as directed acyclic graph # construct nodes, nodes are stored as directed acyclic graph
for node in graph.node: for node in graph.node:
op_name = node.op_type op_name = node.op_type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment