Commit d43aab07 by Yao Wang, committed by Yizhi Liu

Support x86 dilation conv2d and improve multi-batch conv2d (#3308)

* Support x86 dilation conv2d and improve multi-batch conv2d

* Fix lint
parent bfa966a8
@@ -500,8 +500,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
# we keep them for debug convenience when dumping autotvm workload
HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
-assert (dh, dw) == (1, 1), "Does not support dilation"
+dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
+    else (dilation, dilation)
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
@@ -514,6 +514,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
+dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
+dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
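Note: the dilated kernel extent added above follows the standard formula: a k-tap filter with dilation d spans (k - 1) * d + 1 input pixels. A minimal sanity check in plain Python (illustrative only, not part of the patch):

def dilated_extent(kernel_size, dilation):
    # effective footprint of a dilated filter along one spatial axis
    return (kernel_size - 1) * dilation + 1

assert dilated_extent(3, 1) == 3   # ordinary 3x3 kernel
assert dilated_extent(3, 2) == 5   # dilation 2 widens a 3-tap filter to a 5-pixel span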
@@ -521,8 +524,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
strides, padding, out_dtype)
# output shape
-out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
-out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1
+out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
+out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
# DOPAD
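Note: with the dilated kernel substituted into the output-shape formula, larger dilation shrinks the output for the same padding. A quick check of the arithmetic (plain Python, values are illustrative and not taken from the patch):

ih, HPAD, HSTR, kernel_height = 56, 1, 1, 3
for dilation_h in (1, 2):
    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    print(dilation_h, out_height)   # prints "1 56" then "2 54"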
@@ -548,8 +551,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
-ic_f_inner * n_elems + ic_s_inner]
+tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh*dilation_h,
+ow*WSTR+kw*dilation_w,
+ic_f_inner * n_elems + ic_s_inner]
.astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
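Note: the index change above (oh*HSTR+kh becomes oh*HSTR+kh*dilation_h, and likewise for the width axis) is the whole of dilation support in the compute rule. A plain NumPy reference of that indexing for a single channel with no padding (a sketch for illustration only; the helper name is made up and this is not the topi implementation):

import numpy as np

def conv2d_dilated_ref(data, kernel, stride, dilation):
    # out[oh, ow] = sum over kh, kw of data[oh*stride + kh*dilation, ow*stride + kw*dilation] * kernel[kh, kw]
    kh_size, kw_size = kernel.shape
    dkh = (kh_size - 1) * dilation + 1
    dkw = (kw_size - 1) * dilation + 1
    oh_size = (data.shape[0] - dkh) // stride + 1
    ow_size = (data.shape[1] - dkw) // stride + 1
    out = np.zeros((oh_size, ow_size), dtype=data.dtype)
    for oh in range(oh_size):
        for ow in range(ow_size):
            for kh in range(kh_size):
                for kw in range(kw_size):
                    out[oh, ow] += data[oh * stride + kh * dilation,
                                        ow * stride + kw * dilation] * kernel[kh, kw]
    return out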
@@ -575,7 +579,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
# else: fp implementation
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
+tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh*dilation_h,
+ow*WSTR+kw*dilation_w,
ic%ic_bn].astype(out_dtype) *
kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
axis=[ic, kh, kw]),
@@ -134,7 +134,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-parallel_axis = s[A].fuse(ic_chunk, ih)
+parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
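Note: fusing batch into the parallel axis is what makes multi-batch workloads scale: previously the outer parallel loop only covered ic_chunk * ih (or oc_chunk * oh below), so batch sizes above 1 gained no extra parallelism. A minimal standalone sketch of the scheduling idea (toy shapes and tensor names are invented; this is not the topi schedule itself):

import tvm

A = tvm.placeholder((4, 8, 56, 56, 16), name='A')   # hypothetical NCHWc-like tensor
B = tvm.compute(A.shape, lambda n, c, h, w, vc: A[n, c, h, w, vc] + 1, name='B')
s = tvm.create_schedule(B.op)
n, c, h, w, vc = s[B].op.axis
parallel_axis = s[B].fuse(n, c, h)   # include the batch axis, as this patch does
s[B].parallel(parallel_axis)
print(tvm.lower(s, [A, B], simple_mode=True))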
@@ -146,7 +146,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
-parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
@@ -172,7 +172,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
@@ -203,7 +203,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-parallel_axis = s[A].fuse(ic_chunk, ih)
+parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
@@ -215,7 +215,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
-parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
@@ -246,7 +246,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
@@ -143,10 +143,10 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
C, O = conv_out, last
CC = s.cache_write(C, 'global')
-_, oc_chunk, oh, ow, oc_block = s[C].op.axis
+batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-parallel_axis = s[C].fuse(oc_chunk, oh)
+parallel_axis = s[C].fuse(batch, oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
@@ -171,7 +171,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-parallel_axis = s[O].fuse(oc_chunk, oh)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
@@ -214,10 +214,10 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
C, O = conv_out, last
CC = s.cache_write(C, 'global')
-_, oc_chunk, oh, ow, oc_block = s[C].op.axis
+batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-parallel_axis = s[C].fuse(oc_chunk, oh)
+parallel_axis = s[C].fuse(batch, oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
@@ -251,7 +251,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-parallel_axis = s[O].fuse(oc_chunk, oh)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
@@ -49,7 +49,6 @@ def _transform_bias(bias, bn):
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
-assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding))
@@ -79,7 +78,8 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
-c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
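Note: topi.testing.dilate_python with strides (1, 1, dilation, dilation) inserts dilation - 1 zeros between kernel taps along H and W, so the ordinary NCHW reference convolution then produces the dilated result. A NumPy sketch of the same transform (hypothetical helper, not the topi function):

import numpy as np

def dilate_hw(w, dilation):
    # spread a (O, I, KH, KW) kernel so taps sit dilation pixels apart, zeros in between
    o, i, kh, kw = w.shape
    out = np.zeros((o, i, (kh - 1) * dilation + 1, (kw - 1) * dilation + 1), dtype=w.dtype)
    out[:, :, ::dilation, ::dilation] = w
    return out

w = np.ones((1, 1, 3, 3), dtype="float32")
print(dilate_hw(w, 2).shape)   # (1, 1, 5, 5): a 3x3 kernel dilated by 2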
@@ -149,8 +149,8 @@ def test_conv2d_NCHWc():
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
-# disable dilation test since it is not supported by NCHW[x]c conv for now.
-# verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
+# dilation
+verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
# batch size
verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
@@ -47,9 +47,6 @@ from gluoncv import model_zoo, data, utils
#
# To get best performance fo SSD on Intel graphics,
# change target argument to 'opencl -device=intel_graphics'
-#
-# SSD with VGG as body network is not supported yet since
-# x86 conv2d schedule doesn't support dilation.
supported_model = [
'ssd_512_resnet50_v1_voc',
@@ -57,6 +54,8 @@ supported_model = [
'ssd_512_resnet101_v2_voc',
'ssd_512_mobilenet1.0_voc',
'ssd_512_mobilenet1.0_coco',
+'ssd_300_vgg16_atrous_voc',
+'ssd_512_vgg16_atrous_coco',
]
model_name = supported_model[0]
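Note: with dilation now handled by the x86 conv2d schedule, the atrous VGG variants in supported_model become usable. For example (gluoncv usage as in the rest of this tutorial; the variable name is illustrative):

model_name = 'ssd_512_vgg16_atrous_coco'
block = model_zoo.get_model(model_name, pretrained=True)   # model_zoo is imported above from gluoncv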