Commit c3df7726 by Wenhao Hu Committed by Tianqi Chen

support t attr in onnx (#1300)

parent 59a8d099
...@@ -571,6 +571,13 @@ class GraphProto(object): ...@@ -571,6 +571,13 @@ class GraphProto(object):
op_name = node.op_type op_name = node.op_type
attr = self._parse_attr(node.attribute) attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1
self._params[node.output[0]] = self._parse_array(t_proto)
self._nodes[node.output[0]] = _sym.Variable(name=node.output[0],
shape=list(t_proto.dims))
else:
op = self._convert_operator(op_name, inputs, attr, opset) op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output) node_output = self._fix_outputs(op_name, node.output)
assert len(node_output) == len(op.list_output_names()), ( assert len(node_output) == len(op.list_output_names()), (
...@@ -615,11 +622,18 @@ class GraphProto(object): ...@@ -615,11 +622,18 @@ class GraphProto(object):
if list(getattr(a, f)): if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed" assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f)) attrs[a.name] = tuple(getattr(a, f))
for f in ['t', 'g']: for f in ['t']:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ['tensors']:
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ['g']:
if a.HasField(f): if a.HasField(f):
raise NotImplementedError( raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f)) "Filed {} is not supported in nnvm.".format(f))
for f in ['tensors', 'graphs']: for f in ['graphs']:
if list(getattr(a, f)): if list(getattr(a, f)):
raise NotImplementedError( raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f)) "Filed {} is not supported in nnvm.".format(f))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment