Commit 35a7eac5 by Joshua Z. Zhang, committed by Tianqi Chen

[Frontend] Onnx improvement (#165)

* fix recently released layers

* fix fc layers with partial infer_shape
parent 51e78516
@@ -3,6 +3,8 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import symbol as _sym
+from .. import graph as _graph
+from ..compiler import graph_util
 from .common import Renamer, AttrConverter as AttrCvt

 __all__ = ['from_onnx']
@@ -60,9 +62,9 @@ def _pooling(name):
                    'kernel_shape': 'pool_size',
                    'pads': ('padding', (0, 0), _revert_caffe2_pad)},
        # very weird attributes here in onnx, force check
-       excludes=['dilations'],
+       ignores=['dilations'],
        # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
-       extras={'ceil_mode': True},
+       extras={'ceil_mode': False},
        custom_check=_dimension_constraint())

 def _conv():
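To make the converter table above concrete, here is a minimal, self-contained sketch of what an AttrCvt-style mapping does for MaxPool: rename ONNX attribute names, silently drop ignored ones, and inject extras such as `ceil_mode`. The function names, defaults, and the pad-halving rule below are illustrative assumptions, not nnvm's actual implementation.

```python
# Sketch only: mimics the attribute translation described by the mapping above.
def _revert_caffe2_pad_sketch(pads):
    # ONNX lists begin/end padding per axis; keep one copy per axis (assumed behaviour).
    return tuple(pads[:len(pads) // 2]) if len(pads) == 4 else tuple(pads)

def convert_max_pool(onnx_attrs):
    nnvm_attrs = {}
    nnvm_attrs['pool_size'] = onnx_attrs['kernel_shape']            # plain rename
    nnvm_attrs['strides'] = onnx_attrs.get('strides', (1, 1))
    # rename with a default value and a transform function
    nnvm_attrs['padding'] = _revert_caffe2_pad_sketch(onnx_attrs.get('pads', (0, 0)))
    # 'dilations', if present, is ignored (dropped) rather than excluded (an error);
    # extras are added unconditionally -- this commit flips the default to False
    nnvm_attrs['ceil_mode'] = False
    return 'max_pool2d', nnvm_attrs

print(convert_max_pool({'kernel_shape': (2, 2), 'strides': (2, 2),
                        'pads': (0, 0, 0, 0), 'dilations': (1, 1)}))
# ('max_pool2d', {'pool_size': (2, 2), 'strides': (2, 2),
#                 'padding': (0, 0), 'ceil_mode': False})
```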
@@ -90,7 +92,7 @@ def _batch_norm():
     return AttrCvt(
         op_name='batch_norm',
         disables=['momentum'],
-        ignores=['spatial', 'is_test'])
+        ignores=['spatial', 'is_test', 'consumed_inputs'])

 # compatible operators that do NOT require any conversion.
@@ -100,6 +102,7 @@ _identity_list = []
 _convert_map = {
     # defs/experimental
     'FC'        : AttrCvt('dense', ignores=['axis', 'axis_w']),
+    'SpatialBN' : _batch_norm(),

     # defs/generator
     # 'Constant'
@@ -200,7 +203,7 @@ def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
    elif op_name in convert_map:
        op_name, attrs = convert_map[op_name](attrs)
    else:
-       _raise_not_supported('Operator: ' + op_name)
+       raise NotImplementedError("Operator {} not implemented.".format(op_name))
    op = getattr(_sym, op_name, None)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
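The dispatch this hunk touches has three paths: identity operators pass through unchanged, mapped operators go through `_convert_map`, and everything else now raises `NotImplementedError`. A stand-alone sketch of that flow, using a fabricated symbol namespace and tables rather than nnvm's own:

```python
# Sketch of the _convert_operator dispatch; fake_sym and both tables are invented.
import types

fake_sym = types.SimpleNamespace(
    relu=lambda *args, **kwargs: ('relu', args, kwargs),
    dense=lambda *args, **kwargs: ('dense', args, kwargs),
)

identity_list = ['relu']                                  # usable as-is
convert_map = {'FC': lambda attrs: ('dense', {'units': attrs['units']})}

def convert_operator(op_name, attrs):
    if op_name in identity_list:
        pass                                              # keep name and attrs unchanged
    elif op_name in convert_map:
        op_name, attrs = convert_map[op_name](attrs)
    else:
        # the commit replaces a helper with a plain NotImplementedError
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    op = getattr(fake_sym, op_name, None)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
    return op, attrs

op, attrs = convert_operator('FC', {'units': 10})
print(op(name='fc0', **attrs))    # ('dense', (), {'name': 'fc0', 'units': 10})
```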
@@ -267,10 +270,11 @@ class GraphProto(object):
            new_attr = self._fix_channels(new_op, new_attr, list(node.input))
            self._fix_bias_shape(node.op_type, graph.node[idx-1].op_type, node.input)
            op = new_op(name=node_name, *inputs, **new_attr)
-           assert len(node.output) == len(op.list_output_names()), (
-               "Number of output mismatch {} vs {}.".format(
-                   len(node.output), len(op.list_output_names())))
-           for k, i in zip(list(node.output), range(len(node.output))):
+           node_output = self._fix_outputs(op_name, node.output)
+           assert len(node_output) == len(op.list_output_names()), (
+               "Number of output mismatch {} vs {} in {}.".format(
+                   len(node_output), len(op.list_output_names()), op_name))
+           for k, i in zip(list(node_output), range(len(node_output))):
                self._nodes[k] = op[i]
        # now return the outputs
        out = [self._nodes[i] for i in graph.output]
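In isolation, the loop this hunk rewrites binds each ONNX output tensor name to the corresponding indexed output of the converted symbol, so that later nodes can look up their inputs by name. A toy illustration of that binding; `FakeOp` and all tensor names below are invented:

```python
# Toy illustration of output-name binding; FakeOp stands in for an nnvm symbol.
class FakeOp(object):
    def __init__(self, names):
        self._names = names
    def list_output_names(self):
        return self._names
    def __getitem__(self, i):
        return "{}[{}]".format(self._names[i], i)

nodes = {}
op = FakeOp(['bn_out', 'bn_mean', 'bn_var'])
node_output = ['out', 'saved_mean', 'saved_var']          # names from the ONNX node
assert len(node_output) == len(op.list_output_names())
for k, i in zip(list(node_output), range(len(node_output))):
    nodes[k] = op[i]                                       # register each output by name
print(nodes['saved_mean'])                                 # bn_mean[1]
```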
@@ -310,6 +314,15 @@ class GraphProto(object):
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

+   def _fix_outputs(self, op, outputs):
+       """A hack to handle dropout or similar operator that have more than one out
+       in ONNX.
+       """
+       if op == 'Dropout':
+           assert len(outputs) == 2, "ONNX have two outputs for dropout layer."
+           outputs = outputs[:-1]
+       return outputs
+
    def _fix_bias(self, op, attrs, num_inputs):
        """A hack for 'use_bias' attribute since onnx don't provide this attribute,
        we have to check the number of inputs to decide it."""
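ONNX's Dropout declares two outputs (the result and a mask), while the nnvm dropout symbol exposes only one, so the mask name is stripped before the output-count assertion in the node loop. A quick stand-alone check of the same hack, with invented tensor names:

```python
# Mirrors the _fix_outputs hack: drop the ONNX Dropout mask output name.
def fix_outputs(op, outputs):
    if op == 'Dropout':
        assert len(outputs) == 2, "ONNX have two outputs for dropout layer."
        outputs = outputs[:-1]
    return outputs

print(fix_outputs('Dropout', ['drop0_out', 'drop0_mask']))   # ['drop0_out']
print(fix_outputs('Relu', ['relu0_out']))                    # ['relu0_out']
```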
@@ -340,17 +353,24 @@ class GraphProto(object):
        """
        if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
            return attrs
-       weight_name = self._renames[inputs[1]]
-       if not weight_name in self._params:
-           raise ValueError("Unable to get channels/units attr from onnx graph.")
-       else:
-           wshape = self._params[weight_name].shape
-           assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
-           channels = wshape[0]
-           if op in [_sym.dense]:
-               attrs['units'] = channels
-           else:
-               attrs['channels'] = channels
+       if inputs[1] not in self._renames:
+           assert inputs[1] in self._nodes
+           g = _graph.create(self._nodes[inputs[1]])
+           shape_dict = {k: v.shape for k, v in self._params.items()}
+           _, out_shapes = graph_util.infer_shape(g, **shape_dict)
+           channels = out_shapes[0][0]
+       else:
+           weight_name = self._renames[inputs[1]]
+           if not weight_name in self._params:
+               raise ValueError("Unable to get channels/units attr from onnx graph.")
+           else:
+               wshape = self._params[weight_name].shape
+               assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
+               channels = wshape[0]
+       if op in [_sym.dense]:
+           attrs['units'] = channels
+       else:
+           attrs['channels'] = channels
        return attrs

 def from_onnx(graph):
...
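The new `_fix_channels` branch covers weights that are produced by another symbol rather than stored directly in `self._params` (for example a reshaped weight feeding an FC layer), which is what the commit message's "partial infer_shape" bullet refers to; the channel/unit count is then recovered by shape inference. Below is a rough sketch of the same idea, assuming a working nnvm/tvm installation and the `graph_util.infer_shape` return convention used in the diff, i.e. `(input_shapes, output_shapes)`; the variable names are illustrative only.

```python
# Rough sketch: recover a weight's leading dimension via nnvm shape inference.
import numpy as np
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_util

w = sym.Variable('w')
w_flat = sym.flatten(data=w)            # weight produced by another op, not a raw param

params = {'w': np.zeros((64, 3, 7, 7), dtype='float32')}
g = nnvm.graph.create(w_flat)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)

channels = out_shapes[0][0]             # leading dim of the inferred weight shape
print(channels)                         # 64 -> used as 'units' for dense layers
```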