Commit 3272e6cb by Yuta Hinokuma Committed by Tianqi Chen

[WIP] [Relay] [NNVM] [Frontend] implement MaxPool-8 and MaxPool-10 (#3114)

parent 09960e30
......@@ -27,6 +27,13 @@ from .onnx_caffe2_utils import dimension_picker, dimension_constraint, \
__all__ = ['from_onnx']
def onnx_storage_order2layout(storage_order):
if storage_order not in (0, 1):
raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')
return 'NCHW' if sotrage_order == 0 else 'NHWC'
class OnnxOpConverter(object):
""" A helper class for holding onnx op converters.
"""
......@@ -207,8 +214,38 @@ class Gemm(OnnxOpConverter):
class MaxPool(Pool):
""" Operator converter for MaxPool
"""
name = 'max_pool'
@classmethod
def _impl_v8(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
# TODO(higumachan): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)
@classmethod
def _impl_v10(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
'ceil_mode': 'ceil_mode'
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
custom_check=dimension_constraint())(inputs, attr, params)
class Mul(Elemwise):
name = 'mul'
......
......@@ -52,6 +52,15 @@ def revert_caffe2_pad(pads):
'Number of pads must be either 2 or 4.')
return pads
def onnx_storage_order2layout(storage_order):
"""converter of onnx storage order parameter to tvm storage order format"""
if storage_order not in (0, 1):
raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')
return 'NCHW' if sotrage_order == 0 else 'NHWC'
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
......@@ -60,6 +69,7 @@ def dimension_constraint():
return _dim_check, "Only 2d kernel supported."
class OnnxOpConverter(object):
""" A helper class for holding onnx op converters.
"""
......@@ -108,6 +118,7 @@ class Elemwise(OnnxOpConverter):
inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_relay_op(op_name)(*inputs)
class Pool(OnnxOpConverter):
""" A helper class for pool op converters.
"""
......@@ -247,6 +258,7 @@ class Gemm(OnnxOpConverter):
inputs[1], units=channels)
return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
class MatMul(OnnxOpConverter):
""" Operator converter for MatMul.
"""
......@@ -257,9 +269,40 @@ class MatMul(OnnxOpConverter):
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
class MaxPool(Pool):
""" Operator converter for MaxPool
"""
name = 'max_pool'
@classmethod
def _impl_v8(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
# TODO(higumachan): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)
@classmethod
def _impl_v10(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
'ceil_mode': 'ceil_mode'
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
custom_check=dimension_constraint())(inputs, attr, params)
class Mul(Elemwise):
name = 'multiply'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment