Commit bbf82e0e by bindog Committed by Zhi

[Fix] Add more pad_mode support for onnx converter (#4029)

* [Fix] Add more pad_mode support for onnx converter

* robustness fix
parent 4f712c79
...@@ -326,15 +326,20 @@ class Pad(OnnxOpConverter): ...@@ -326,15 +326,20 @@ class Pad(OnnxOpConverter):
for i in range(dims): for i in range(dims):
pad_width.append((pads[i], pads[i+dims])) pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
else:
raise tvm.error.OpAttributeInvalid(
'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
return AttrCvt( return AttrCvt(
_op.nn.pad, _op.nn.pad,
transforms={ transforms={
'value': 'pad_value', 'value': 'pad_value',
}, },
ignores=['mode'], )(inputs, attr, params)
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
@classmethod @classmethod
def _impl_v2(cls, inputs, attr, params): def _impl_v2(cls, inputs, attr, params):
...@@ -344,15 +349,20 @@ class Pad(OnnxOpConverter): ...@@ -344,15 +349,20 @@ class Pad(OnnxOpConverter):
for i in range(dims): for i in range(dims):
pad_width.append((pads[i], pads[i+dims])) pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
else:
raise tvm.error.OpAttributeInvalid(
'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
return AttrCvt( return AttrCvt(
'pad', 'pad',
transforms={ transforms={
'value': 'pad_value', 'value': 'pad_value',
}, },
ignores=['mode'], )(inputs, attr, params)
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
class ParametricSoftPlus(OnnxOpConverter): class ParametricSoftPlus(OnnxOpConverter):
......
...@@ -781,13 +781,23 @@ def test_constantfill(): ...@@ -781,13 +781,23 @@ def test_constantfill():
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6)) verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
def verify_pad(indata, pads, value=0.0): def verify_pad(indata, pads, mode='constant', value=0.0):
indata = np.array(indata).astype(np.float32) indata = np.array(indata).astype(np.float32)
# numpy expect result # numpy expect result
len_dim = len(pads) // 2 len_dim = len(pads) // 2
np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
# onnx graph # onnx graph
if mode in ['edge', 'reflect']:
outdata = np.pad(indata, pad_width=np_pads, mode=mode)
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode=mode,
pads=pads,
)
else:
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
node = helper.make_node( node = helper.make_node(
'Pad', 'Pad',
inputs=['input'], inputs=['input'],
...@@ -809,9 +819,11 @@ def verify_pad(indata, pads, value=0.0): ...@@ -809,9 +819,11 @@ def verify_pad(indata, pads, value=0.0):
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_pad(): def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0) verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0) verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0) verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0)
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
def verify_reduce_x(name, indata, axis, keepdims): def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32) indata = np.array(indata).astype(np.float32)
...@@ -1266,7 +1278,6 @@ if __name__ == '__main__': ...@@ -1266,7 +1278,6 @@ if __name__ == '__main__':
test_forward_arg_min_max() test_forward_arg_min_max()
test_softmax() test_softmax()
test_constantfill() test_constantfill()
test_pad()
test_reduce_max() test_reduce_max()
test_reduce_min() test_reduce_min()
test_reduce_sum() test_reduce_sum()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment