Commit ba3ddcd7 by Bob.Liu Committed by Yizhi Liu

[FRONTEND][ONNX]add Pad, ReduceMax, ReduceMin, ReduceMean and ReduceSum OP (#2061)

* add Pad,ReduceMax,ReduceMin,ReduceMean,ReduceSum for onnx frontend

* fixed pylint error and warning for frontend.onnx file

* add implement v2 for Pad in onnx frontend

* compatible with python 3.x

* disable too-many-lines pylint check in frontend onnx

* use random values instead in onnx frontend testing
parent 5712ea6b
# pylint: disable=import-self, invalid-name, unused-argument # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend.""" """ONNX: Open Neural Network Exchange frontend."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as np import numpy as np
...@@ -31,10 +31,9 @@ class OnnxOpConverter(object): ...@@ -31,10 +31,9 @@ class OnnxOpConverter(object):
max([i for i, v in enumerate(versions) if v == opset]) - 1] max([i for i, v in enumerate(versions) if v == opset]) - 1]
if hasattr(cls, '_impl_v{}'.format(version)): if hasattr(cls, '_impl_v{}'.format(version)):
return getattr(cls, '_impl_v{}'.format(version)) return getattr(cls, '_impl_v{}'.format(version))
else: raise NotImplementedError(
raise NotImplementedError( 'opset version {} of {} not implemented'.format(
'opset version {} of {} not implemented'.format( version, cls.__name__))
version, cls.__name__))
class Elemwise(OnnxOpConverter): class Elemwise(OnnxOpConverter):
...@@ -200,22 +199,44 @@ class Mul(Elemwise): ...@@ -200,22 +199,44 @@ class Mul(Elemwise):
class Pad(OnnxOpConverter): class Pad(OnnxOpConverter):
""" Operator converter for Pad.
"""
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
# get number of channels pad_width = []
channels = _infer_channels(inputs[1], params, True) pads = attr.pop('paddings')
attr['channels'] = channels dims = int(len(pads) / 2)
groups = attr.pop('group') for i in range(dims):
attr['groups'] = groups pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
return AttrCvt(
op_name='pad',
transforms={
'value': 'pad_value',
},
ignores=['mode'],
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
@classmethod
def _impl_v2(cls, inputs, attr, params):
pad_width = []
pads = attr.pop('pads')
dims = int(len(pads) / 2)
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
return AttrCvt( return AttrCvt(
op_name='pad', op_name='pad',
transforms={ transforms={
'value': 'pad_value', 'value': 'pad_value',
'pads': 'pad_width'
}, },
custom_check=lambda attrs: attrs.get('mode') == 'constant')( ignores=['mode'],
inputs, attr, params) custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
class ParametricSoftPlus(OnnxOpConverter): class ParametricSoftPlus(OnnxOpConverter):
...@@ -368,8 +389,7 @@ def _dimension_picker(prefix, surfix=''): ...@@ -368,8 +389,7 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape'] kernel = attr['kernel_shape']
if len(kernel) == 2: if len(kernel) == 2:
return prefix + '2d' + surfix return prefix + '2d' + surfix
else: raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl return _impl
...@@ -659,14 +679,13 @@ class ConstantFill(OnnxOpConverter): ...@@ -659,14 +679,13 @@ class ConstantFill(OnnxOpConverter):
transforms={'value': 'fill_value'}, transforms={'value': 'fill_value'},
ignores=['dtype'])(inputs, attr) ignores=['dtype'])(inputs, attr)
return _sym.cast(out, dtype=attr['dtype'].decode("utf-8")) return _sym.cast(out, dtype=attr['dtype'].decode("utf-8"))
else: if 'extra_shape' in attr:
if 'extra_shape' in attr: shape = shape + attr.pop('extra_shape')
shape = shape + attr.pop('extra_shape')
return AttrCvt( return AttrCvt(
op_name='full', op_name='full',
transforms={'value': 'fill_value'}, transforms={'value': 'fill_value'},
extras={'shape':shape})(inputs, attr) extras={'shape':shape})(inputs, attr)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -758,10 +777,10 @@ def _get_convert_map(opset): ...@@ -758,10 +777,10 @@ def _get_convert_map(opset):
'LRN': LRN.get_converter(opset), 'LRN': LRN.get_converter(opset),
# defs/reduction # defs/reduction
'ReduceMax': AttrCvt('max', {'axes', 'axis'}), 'ReduceMax': AttrCvt('max', {'axes': 'axis'}),
'ReduceMin': AttrCvt('min', {'axes', 'axis'}), 'ReduceMin': AttrCvt('min', {'axes': 'axis'}),
'ReduceSum': AttrCvt('sum', {'axes', 'axis'}), 'ReduceSum': AttrCvt('sum', {'axes': 'axis'}),
# 'ReduceMean' 'ReduceMean': AttrCvt('mean', {'axes': 'axis'}),
# 'ReduceProd' # 'ReduceProd'
# 'ReduceLogSumExp' # 'ReduceLogSumExp'
'ArgMax': ArgMax.get_converter(opset), 'ArgMax': ArgMax.get_converter(opset),
......
...@@ -712,6 +712,117 @@ def test_constantfill(): ...@@ -712,6 +712,117 @@ def test_constantfill():
verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32') verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6)) verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
def verify_pad(indata, pads, value=0.0):
indata = np.array(indata).astype(np.float32)
# numpy expect result
len_dim = len(pads) // 2
np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
# onnx graph
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode='constant',
pads=pads,
value=value
)
graph = helper.make_graph([node],
'pad_test',
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='pad_test')
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0)
def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32)
# numpy expect result
if name == 'ReduceMax':
outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMin':
outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceSum':
outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMean':
outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
else:
raise Exception('unsupport op: {}'.format(name))
if len(np.asarray(outdata).shape) == 0:
outdata = np.asarray([outdata])
# onnx graph
if axis is None:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
keepdims=keepdims)
else:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
axis=axis, keepdims=keepdims)
graph = helper.make_graph([node],
'{}_test'.format(name),
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='{}_test'.format(name))
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_reduce_max():
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMax",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_min():
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMin",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_sum():
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceSum",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_mean():
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMean",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def verify_split(indata, outdatas, split, axis=0): def verify_split(indata, outdatas, split, axis=0):
indata = np.array(indata).astype(np.float32) indata = np.array(indata).astype(np.float32)
outdatas = [np.array(o).astype(np.float32) for o in outdatas] outdatas = [np.array(o).astype(np.float32) for o in outdatas]
...@@ -772,4 +883,9 @@ if __name__ == '__main__': ...@@ -772,4 +883,9 @@ if __name__ == '__main__':
test_forward_arg_min_max() test_forward_arg_min_max()
test_softmax() test_softmax()
test_constantfill() test_constantfill()
test_pad()
test_reduce_max()
test_reduce_min()
test_reduce_sum()
test_reduce_mean()
test_split() test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment