Commit c5e1da93 by Rasterer Committed by Tianqi Chen

[TOPI] Improve performance for dilated convolution (#2107)

parent 59c70a0e
......@@ -113,11 +113,6 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
dilation_args = (1, 1, dilation_h, dilation_w) if len(kernel.shape) == 4\
else (1, 1, dilation_h, dilation_w, 1)
kernel = dilate(kernel, dilation_args)
if len(kernel.shape) == 4:
pre_packed = False
CO, _, KH, KW = get_const_tuple(kernel.shape)
......@@ -126,11 +121,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
CO = CO * VC
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW))
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1
OW = (IW + pad_left + pad_right - KW) // WSTR + 1
OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
# ==================== define configuration space ====================
......@@ -171,14 +168,22 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
kvshape = (CO // VC, CI, KH, KW, VC)
ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
oshape = (N, CO, OH, OW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
name='data_vec')
if dilation_h != 1 or dilation_w != 1:
# undilate input data
dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
[(w*VW+vw)*WSTR+kw*dilation_w],
name='data_vec_undilated')
else:
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
name='data_vec')
if pre_packed:
kernel_vec = kernel
......@@ -191,10 +196,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
......@@ -240,7 +251,10 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
# mark parallel
s[last].parallel(co)
_, h, _, _, _, _ = s[data_vec].op.axis
if data_vec.op.name == 'data_vec_undilated':
_, h, _, _, _, _, _, _ = s[data_vec].op.axis
else:
_, h, _, _, _, _ = s[data_vec].op.axis
s[data_vec].parallel(h)
if kernel_vec.op.name == 'kernel_vec':
......
......@@ -118,7 +118,10 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
s[data_pad].compute_inline()
# schedule data packing
_, h, w, ci, vh, vw = s[data_vec].op.axis
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and data_vec.op.name == 'data_vec_undilated':
_, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
else:
_, h, w, ci, vh, vw = s[data_vec].op.axis
tile_and_bind3d(s, data_vec, h, w, ci, 1)
if vh.dom.extent.value < max_unroll:
s[data_vec].unroll(vh)
......
......@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np
import tvm
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify, const_matrix, get_const_tuple
......@@ -128,17 +127,16 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# compute graph
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
......@@ -150,7 +148,8 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum(
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
temp[nn, rc, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="conv2d_nchw")
......@@ -195,17 +194,16 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
in_height, in_width, in_channel, batch = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [pad_top, pad_left, 0, 0]
pad_after = [pad_down, pad_right, 0, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
......@@ -215,7 +213,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
Output = tvm.compute(
(out_height, out_width, out_channel, batch),
lambda yy, xx, ff, nn: tvm.sum(
PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn].astype(out_dtype) *
PaddedInput[yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w,
rc, nn].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_hwcn")
return Output
......@@ -259,17 +258,16 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
batch, in_height, in_width, in_channel = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
......@@ -279,7 +277,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry, xx * stride_w + rx, rc].astype(out_dtype) *
PaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_nhwc")
return Output
......
......@@ -72,18 +72,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
batch, in_channel, in_height, in_width = Input.shape
# shape of dilated kernel
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
dilated_kernel_h = (filter_height - 1) * dilation_h + 1
dilated_kernel_w = (filter_width - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, 0, pad_top, pad_left]
......@@ -95,7 +94,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h+di, j*stride_w+dj].astype(out_dtype) *
(PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h,
j*stride_w+dj*dilation_w].astype(out_dtype) *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
......@@ -143,18 +143,17 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
batch, in_height, in_width, in_channel = Input.shape
# shape of dilated kernel
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
dilated_kernel_h = (filter_height - 1) * dilation_h + 1
dilated_kernel_w = (filter_width - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, pad_top, pad_left, 0]
......@@ -166,8 +165,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier].astype(
out_dtype) *
(PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
c/channel_multiplier].astype(out_dtype) *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
......
......@@ -9,7 +9,6 @@ from .. import nn
from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload
from ..nn.dilate import dilate
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad
......@@ -89,9 +88,6 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
HPAD, WPAD = padding
HSTR, WSTR = strides
......@@ -101,8 +97,10 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
pad_height = in_height + 2 * HPAD
pad_width = in_width + 2 * WPAD
out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1
# pack data
DOPAD = (HPAD != 0 or WPAD != 0)
......@@ -136,8 +134,8 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh, ic%ic_bn,
ow*WSTR+kw].astype(out_dtype) *
tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
ow*WSTR+kw*dilation_w].astype(out_dtype) *
kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
oc_block].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment