Commit c5e1da93 by Rasterer, committed by Tianqi Chen

[TOPI] Improve performance for dilated convolution (#2107)

parent 59c70a0e
...@@ -113,11 +113,6 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou ...@@ -113,11 +113,6 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
dilation_args = (1, 1, dilation_h, dilation_w) if len(kernel.shape) == 4\
else (1, 1, dilation_h, dilation_w, 1)
kernel = dilate(kernel, dilation_args)
if len(kernel.shape) == 4: if len(kernel.shape) == 4:
pre_packed = False pre_packed = False
CO, _, KH, KW = get_const_tuple(kernel.shape) CO, _, KH, KW = get_const_tuple(kernel.shape)
...@@ -126,11 +121,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou ...@@ -126,11 +121,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
CO, _, KH, KW, VC = get_const_tuple(kernel.shape) CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
CO = CO * VC CO = CO * VC
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
OW = (IW + pad_left + pad_right - KW) // WSTR + 1
data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
# ==================== define configuration space ==================== # ==================== define configuration space ====================
...@@ -171,14 +168,22 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou ...@@ -171,14 +168,22 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
VH = cfg["tile_oh"].size[-1] VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1] VW = cfg["tile_ow"].size[-1]
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
kvshape = (CO // VC, CI, KH, KW, VC) kvshape = (CO // VC, CI, KH, KW, VC)
ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC) ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
oshape = (N, CO, OH, OW) oshape = (N, CO, OH, OW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: if dilation_h != 1 or dilation_w != 1:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], # undilate input data
name='data_vec') dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
[(w*VW+vw)*WSTR+kw*dilation_w],
name='data_vec_undilated')
else:
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
name='data_vec')
if pre_packed: if pre_packed:
kernel_vec = kernel kernel_vec = kernel
...@@ -191,10 +196,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou ...@@ -191,10 +196,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou
kh = tvm.reduce_axis((0, KH), name='kh') kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw') kw = tvm.reduce_axis((0, KW), name='kw')
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ if dilation_h != 1 or dilation_w != 1:
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) * conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
axis=[ci, kh, kw]), name='conv') kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
...@@ -240,7 +251,10 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, ...@@ -240,7 +251,10 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
# mark parallel # mark parallel
s[last].parallel(co) s[last].parallel(co)
_, h, _, _, _, _ = s[data_vec].op.axis if data_vec.op.name == 'data_vec_undilated':
_, h, _, _, _, _, _, _ = s[data_vec].op.axis
else:
_, h, _, _, _, _ = s[data_vec].op.axis
s[data_vec].parallel(h) s[data_vec].parallel(h)
if kernel_vec.op.name == 'kernel_vec': if kernel_vec.op.name == 'kernel_vec':
......
...@@ -118,7 +118,10 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): ...@@ -118,7 +118,10 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
s[data_pad].compute_inline() s[data_pad].compute_inline()
# schedule data packing # schedule data packing
_, h, w, ci, vh, vw = s[data_vec].op.axis if isinstance(data_vec.op, tvm.tensor.ComputeOp) and data_vec.op.name == 'data_vec_undilated':
_, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
else:
_, h, w, ci, vh, vw = s[data_vec].op.axis
tile_and_bind3d(s, data_vec, h, w, ci, 1) tile_and_bind3d(s, data_vec, h, w, ci, 1)
if vh.dom.extent.value < max_unroll: if vh.dom.extent.value < max_unroll:
s[data_vec].unroll(vh) s[data_vec].unroll(vh)
......
...@@ -6,7 +6,6 @@ from collections import namedtuple ...@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np import numpy as np
import tvm import tvm
from .dilate import dilate
from .pad import pad from .pad import pad
from .util import get_pad_tuple from .util import get_pad_tuple
from ..util import simplify, const_matrix, get_const_tuple from ..util import simplify, const_matrix, get_const_tuple
...@@ -128,17 +127,16 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): ...@@ -128,17 +127,16 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
batch, in_channel, in_height, in_width = Input.shape batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape num_filter, channel, kernel_h, kernel_w = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape # compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# compute graph # compute graph
pad_before = [0, 0, pad_top, pad_left] pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right] pad_after = [0, 0, pad_down, pad_right]
...@@ -150,7 +148,8 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): ...@@ -150,7 +148,8 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
return tvm.compute( return tvm.compute(
(batch, out_channel, out_height, out_width), (batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum( lambda nn, ff, yy, xx: tvm.sum(
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) * temp[nn, rc, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, ry, rx].astype(out_dtype), Filter[ff, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="conv2d_nchw") axis=[rc, ry, rx]), tag="conv2d_nchw")
...@@ -195,17 +194,16 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): ...@@ -195,17 +194,16 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
in_height, in_width, in_channel, batch = Input.shape in_height, in_width, in_channel, batch = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape kernel_h, kernel_w, channel, num_filter = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape # compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [pad_top, pad_left, 0, 0] pad_before = [pad_top, pad_left, 0, 0]
pad_after = [pad_down, pad_right, 0, 0] pad_after = [pad_down, pad_right, 0, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
...@@ -215,7 +213,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): ...@@ -215,7 +213,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
Output = tvm.compute( Output = tvm.compute(
(out_height, out_width, out_channel, batch), (out_height, out_width, out_channel, batch),
lambda yy, xx, ff, nn: tvm.sum( lambda yy, xx, ff, nn: tvm.sum(
PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn].astype(out_dtype) * PaddedInput[yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w,
rc, nn].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_hwcn") name="Conv2dOutput", tag="conv2d_hwcn")
return Output return Output
...@@ -259,17 +258,16 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): ...@@ -259,17 +258,16 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
batch, in_height, in_width, in_channel = Input.shape batch, in_height, in_width, in_channel = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape kernel_h, kernel_w, channel, num_filter = Filter.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape # compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0] pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0] pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
...@@ -279,7 +277,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): ...@@ -279,7 +277,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
Output = tvm.compute( Output = tvm.compute(
(batch, out_height, out_width, out_channel), (batch, out_height, out_width, out_channel),
lambda nn, yy, xx, ff: tvm.sum( lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry, xx * stride_w + rx, rc].astype(out_dtype) * PaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_nhwc") name="Conv2dOutput", tag="conv2d_nhwc")
return Output return Output
......
...@@ -72,18 +72,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -72,18 +72,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
batch, in_channel, in_height, in_width = Input.shape batch, in_channel, in_height, in_width = Input.shape
# shape of dilated kernel # shape of dilated kernel
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
dilated_kernel_h = (filter_height - 1) * dilation_h + 1
dilated_kernel_w = (filter_width - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple( pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width)) padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = simplify(in_channel * channel_multiplier) out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# padding stage # padding stage
pad_before = [0, 0, pad_top, pad_left] pad_before = [0, 0, pad_top, pad_left]
...@@ -95,7 +94,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -95,7 +94,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
Output = tvm.compute( Output = tvm.compute(
(batch, out_channel, out_height, out_width), (batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum( lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h+di, j*stride_w+dj].astype(out_dtype) * (PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h,
j*stride_w+dj*dilation_w].astype(out_dtype) *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)), Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
axis=[di, dj]), axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw") name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
...@@ -143,18 +143,17 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -143,18 +143,17 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
batch, in_height, in_width, in_channel = Input.shape batch, in_height, in_width, in_channel = Input.shape
# shape of dilated kernel # shape of dilated kernel
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
dilated_kernel_h = (filter_height - 1) * dilation_h + 1
dilated_kernel_w = (filter_width - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple( pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width)) padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = simplify(in_channel * channel_multiplier) out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# padding stage # padding stage
pad_before = [0, pad_top, pad_left, 0] pad_before = [0, pad_top, pad_left, 0]
...@@ -166,8 +165,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -166,8 +165,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
Output = tvm.compute( Output = tvm.compute(
(batch, out_height, out_width, out_channel), (batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum( lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier].astype( (PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
out_dtype) * c/channel_multiplier].astype(out_dtype) *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)), Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
axis=[di, dj]), axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc") name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
......
...@@ -9,7 +9,6 @@ from .. import nn ...@@ -9,7 +9,6 @@ from .. import nn
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, \ from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload conv2d_alter_layout, _get_workload as _get_conv2d_workload
from ..nn.dilate import dilate
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad from ..nn.pad import pad
...@@ -89,9 +88,6 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout ...@@ -89,9 +88,6 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
HPAD, WPAD = padding HPAD, WPAD = padding
HSTR, WSTR = strides HSTR, WSTR = strides
...@@ -101,8 +97,10 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout ...@@ -101,8 +97,10 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
pad_height = in_height + 2 * HPAD pad_height = in_height + 2 * HPAD
pad_width = in_width + 2 * WPAD pad_width = in_width + 2 * WPAD
out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1 dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1
# pack data # pack data
DOPAD = (HPAD != 0 or WPAD != 0) DOPAD = (HPAD != 0 or WPAD != 0)
...@@ -136,8 +134,8 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout ...@@ -136,8 +134,8 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
kw = tvm.reduce_axis((0, kernel_width), name='kw') kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh, ic%ic_bn, tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
ow*WSTR+kw].astype(out_dtype) * ow*WSTR+kw*dilation_w].astype(out_dtype) *
kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
oc_block].astype(out_dtype), oc_block].astype(out_dtype),
axis=[ic, kh, kw]), name='conv') axis=[ic, kh, kw]), name='conv')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment