Commit 8acc413c authored by Neo Chien, committed by Zhi

[Relay][Frontend][ONNX] Support auto_pad in Conv and ConvTranspose (#4563)

parent f076c839
......@@ -66,6 +66,17 @@ def revert_caffe2_pad(pads):
    return pads

def get_pad_pair(input1d, kernel1d, stride1d):
    """Infer the (before, after) pad sizes for SAME padding in one dimension."""
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    # split the total pad as evenly as possible, putting the extra pixel after
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]
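A quick sanity check of the arithmetic above (illustrative values, not part of the diff):

# 5-wide input, kernel 3, stride 1: total pad = max(3 - 1, 0) = 2 -> [1, 1]
assert get_pad_pair(5, 3, 1) == [1, 1]
# 7-wide input, kernel 3, stride 2: 7 % 2 = 1, pad = max(3 - 1, 0) = 2 -> [1, 1],
# which yields ceil(7 / 2) = 4 output positions
assert get_pad_pair(7, 3, 2) == [1, 1]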
def onnx_storage_order2layout(storage_order):
    """Convert the ONNX storage_order attribute to the TVM storage order format."""
    if storage_order not in (0, 1):
......@@ -202,13 +213,36 @@ class Conv(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # infer pads for auto_pad
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                # pad so the output spatial size is ceil(input / stride)
                input_shape = infer_shape(inputs[0])
                in_h, in_w = input_shape[2], input_shape[3]
                stride_h, stride_w = attr['strides']
                kernel_h, kernel_w = attr['kernel_shape']
                dilation_h, dilation_w = attr['dilations']
                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0, 0)
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'group': ('groups', 1)},
            ignores=['auto_pad'],
            custom_check=dimension_constraint())(inputs[:2], attr, params)
        use_bias = len(inputs) == 3
        if use_bias:
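To make the pads ordering concrete: for 2-D inputs ONNX stores pads as (top, left, bottom, right), which is exactly the tuple assembled from pad_v and pad_h above. A worked example with assumed shapes:

# Assumed: 5x5 input, 3x3 kernel, strides (1, 1), dilations (1, 1), SAME_UPPER
# dilated_kernel = (3 - 1) * 1 + 1 = 3
# get_pad_pair(5, 3, 1) -> [1, 1] on both axes
# attr['pads'] = (1, 1, 1, 1), so the 5x5 spatial size is preserved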
......@@ -226,6 +260,29 @@ class ConvTranspose(OnnxOpConverter):
        attr['channels'] = channels
        groups = attr.pop('group')
        attr['groups'] = groups
        # infer pads for auto_pad
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                input_shape = infer_shape(inputs[0])
                in_h, in_w = input_shape[2], input_shape[3]
                stride_h, stride_w = attr['strides']
                kernel_h, kernel_w = attr['kernel_shape']
                dilation_h, dilation_w = attr['dilations']
                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0, 0)
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator ConvTranspose is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
......
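As a usage sketch (not part of this diff; the model file and input name are assumptions), a graph that relies on auto_pad can now be imported directly:

import onnx
from tvm import relay

# hypothetical model whose Conv/ConvTranspose nodes use auto_pad='SAME_UPPER'
onnx_model = onnx.load('model_with_auto_pad.onnx')
mod, params = relay.frontend.from_onnx(onnx_model, shape={'x': (1, 1, 5, 5)})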
......@@ -77,10 +77,13 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
    return tvm_output.asnumpy()

def get_onnxruntime_output(model, inputs, dtype='float32'):
    import onnxruntime.backend
    rep = onnxruntime.backend.prepare(model, 'CPU')
    if isinstance(inputs, list) and len(inputs) > 1:
        # multiple inputs: feed the list through as-is and return every output
        ort_out = rep.run(inputs)
    else:
        # single input: cast to the requested dtype and return the first output
        x = inputs.astype(dtype)
        ort_out = rep.run(x)[0]
    return ort_out
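For clarity, the two call styles this helper now supports (names mirror the tests below):

# single input: the helper casts to dtype and returns the first output
y = get_onnxruntime_output(model, x)
# multiple inputs: pass a list; the helper returns the full output list,
# so callers index [0] themselves, as verify_conv does below
y = get_onnxruntime_output(model, [x, W])[0]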
......@@ -1746,6 +1749,83 @@ def test_or():
    verify_or(indata=[x, y], dtype=bool)

def verify_conv(x_shape, w_shape, y_shape, p):
    node = helper.make_node('Conv',
                            inputs=['x', 'W'],
                            outputs=['y'],
                            kernel_shape=[3, 3],
                            # Default values for other attributes:
                            # strides=[1, 1],
                            # dilations=[1, 1],
                            # group=1
                            pads=p)
    graph = helper.make_graph([node],
                              'conv_test',
                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
    model = helper.make_model(graph, producer_name='conv_test')

    for target, ctx in ctx_list():
        x = np.random.uniform(size=x_shape).astype('float32')
        W = np.random.uniform(size=w_shape).astype('float32')
        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_conv():
    # Convolution with padding:
    # (1, 1, 5, 5) input tensor
    # (1, 1, 3, 3) tensor for convolution weights
    # (1, 1, 5, 5) output tensor
    # [1, 1, 1, 1] list for pads
    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1])
    # Convolution without padding:
    # (1, 1, 5, 5) input tensor
    # (1, 1, 3, 3) tensor for convolution weights
    # (1, 1, 3, 3) output tensor
    # [0, 0, 0, 0] list for pads
    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0])
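The cases above still use explicit pads; a node exercising the new auto_pad path could be built the same way (a sketch, not a test added by this commit):

node = helper.make_node('Conv',
                        inputs=['x', 'W'],
                        outputs=['y'],
                        kernel_shape=[3, 3],
                        auto_pad='SAME_UPPER')
# with a (1, 1, 5, 5) input and stride 1, SAME_UPPER preserves the spatial
# size, so the expected output shape is (1, 1, 5, 5)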
def verify_convtranspose(x_shape, w_shape, y_shape, p):
    node = onnx.helper.make_node("ConvTranspose",
                                 inputs=["x", "W"],
                                 outputs=['y'],
                                 strides=[3, 2],
                                 group=1,
                                 kernel_shape=[3, 3],
                                 pads=p)
    graph = helper.make_graph([node],
                              'verify_convtranspose_test',
                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
    model = helper.make_model(graph, producer_name='convtranspose_test')

    for target, ctx in ctx_list():
        x = np.random.uniform(size=x_shape).astype('float32')
        W = np.random.uniform(size=w_shape).astype('float32')
        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_convtranspose():
    # Transposed convolution with padding:
    # (1, 1, 3, 3) input tensor
    # (1, 2, 3, 3) tensor for convolution weights
    # (1, 2, 7, 3) output tensor
    # [1, 2, 1, 2] list for pads
    verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2])
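The expected (1, 2, 7, 3) shape follows from the ConvTranspose output formula, out = (in - 1) * stride + kernel - pad_begin - pad_end:

# H: (3 - 1) * 3 + 3 - 1 - 1 = 7
# W: (3 - 1) * 2 + 3 - 2 - 2 = 3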
if __name__ == '__main__':
    test_flatten()
    test_reshape()
......@@ -1800,3 +1880,5 @@ if __name__ == '__main__':
    test_or()
    test_depth_to_space()
    test_space_to_depth()
    test_conv()
    test_convtranspose()