Commit 35a7eac5 by Joshua Z. Zhang Committed by Tianqi Chen

[Frontend] Onnx improvement (#165)

* fix recently released layers

* fix fc layers with partial infer_shape
parent 51e78516
......@@ -3,6 +3,8 @@
from __future__ import absolute_import as _abs
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from .common import Renamer, AttrConverter as AttrCvt
__all__ = ['from_onnx']
......@@ -60,9 +62,9 @@ def _pooling(name):
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), _revert_caffe2_pad)},
# very weird attributes here in onnx, force check
excludes=['dilations'],
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': True},
extras={'ceil_mode': False},
custom_check=_dimension_constraint())
def _conv():
......@@ -90,7 +92,7 @@ def _batch_norm():
return AttrCvt(
op_name='batch_norm',
disables=['momentum'],
ignores=['spatial', 'is_test'])
ignores=['spatial', 'is_test', 'consumed_inputs'])
# compatible operators that do NOT require any conversion.
......@@ -100,6 +102,7 @@ _identity_list = []
_convert_map = {
# defs/experimental
'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
'SpatialBN' : _batch_norm(),
# defs/generator
# 'Constant'
......@@ -200,7 +203,7 @@ def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs)
else:
_raise_not_supported('Operator: ' + op_name)
raise NotImplementedError("Operator {} not implemented.".format(op_name))
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
......@@ -267,10 +270,11 @@ class GraphProto(object):
new_attr = self._fix_channels(new_op, new_attr, list(node.input))
self._fix_bias_shape(node.op_type, graph.node[idx-1].op_type, node.input)
op = new_op(name=node_name, *inputs, **new_attr)
assert len(node.output) == len(op.list_output_names()), (
"Number of output mismatch {} vs {}.".format(
len(node.output), len(op.list_output_names())))
for k, i in zip(list(node.output), range(len(node.output))):
node_output = self._fix_outputs(op_name, node.output)
assert len(node_output) == len(op.list_output_names()), (
"Number of output mismatch {} vs {} in {}.".format(
len(node_output), len(op.list_output_names()), op_name))
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
# now return the outputs
out = [self._nodes[i] for i in graph.output]
......@@ -310,6 +314,15 @@ class GraphProto(object):
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
def _fix_outputs(self, op, outputs):
"""A hack to handle dropout or similar operator that have more than one out
in ONNX.
"""
if op == 'Dropout':
assert len(outputs) == 2, "ONNX have two outputs for dropout layer."
outputs = outputs[:-1]
return outputs
def _fix_bias(self, op, attrs, num_inputs):
"""A hack for 'use_bias' attribute since onnx don't provide this attribute,
we have to check the number of inputs to decide it."""
......@@ -340,17 +353,24 @@ class GraphProto(object):
"""
if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
return attrs
weight_name = self._renames[inputs[1]]
if not weight_name in self._params:
raise ValueError("Unable to get channels/units attr from onnx graph.")
if inputs[1] not in self._renames:
assert inputs[1] in self._nodes
g = _graph.create(self._nodes[inputs[1]])
shape_dict = {k: v.shape for k, v in self._params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0]
else:
wshape = self._params[weight_name].shape
assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
channels = wshape[0]
if op in [_sym.dense]:
attrs['units'] = channels
weight_name = self._renames[inputs[1]]
if not weight_name in self._params:
raise ValueError("Unable to get channels/units attr from onnx graph.")
else:
attrs['channels'] = channels
wshape = self._params[weight_name].shape
assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
channels = wshape[0]
if op in [_sym.dense]:
attrs['units'] = channels
else:
attrs['channels'] = channels
return attrs
def from_onnx(graph):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment