Commit c3df7726 by Wenhao Hu Committed by Tianqi Chen

support t attr in onnx (#1300)

parent 59a8d099
...@@ -571,13 +571,20 @@ class GraphProto(object): ...@@ -571,13 +571,20 @@ class GraphProto(object):
op_name = node.op_type op_name = node.op_type
attr = self._parse_attr(node.attribute) attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
op = self._convert_operator(op_name, inputs, attr, opset) if op_name == "Constant":
node_output = self._fix_outputs(op_name, node.output) t_proto = self._parse_attr(node.attribute)["value"]
assert len(node_output) == len(op.list_output_names()), ( self._num_param += 1
"Number of output mismatch {} vs {} in {}.".format( self._params[node.output[0]] = self._parse_array(t_proto)
len(node_output), len(op.list_output_names()), op_name)) self._nodes[node.output[0]] = _sym.Variable(name=node.output[0],
for k, i in zip(list(node_output), range(len(node_output))): shape=list(t_proto.dims))
self._nodes[k] = op[i] else:
op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output)
assert len(node_output) == len(op.list_output_names()), (
"Number of output mismatch {} vs {} in {}.".format(
len(node_output), len(op.list_output_names()), op_name))
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
# now return the outputs # now return the outputs
out = [self._nodes[self._parse_value_proto(i)] for i in graph.output] out = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
if len(out) > 1: if len(out) > 1:
...@@ -615,11 +622,18 @@ class GraphProto(object): ...@@ -615,11 +622,18 @@ class GraphProto(object):
if list(getattr(a, f)): if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed" assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f)) attrs[a.name] = tuple(getattr(a, f))
for f in ['t', 'g']: for f in ['t']:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ['tensors']:
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ['g']:
if a.HasField(f): if a.HasField(f):
raise NotImplementedError( raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f)) "Filed {} is not supported in nnvm.".format(f))
for f in ['tensors', 'graphs']: for f in ['graphs']:
if list(getattr(a, f)): if list(getattr(a, f)):
raise NotImplementedError( raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f)) "Filed {} is not supported in nnvm.".format(f))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment