Commit a7f01253 by ziheng, committed by Tianqi Chen

[TOPI] Update depthwise conv2d schedule on rasp (#500)

parent 9e7a6674
# pylint: disable=invalid-name,unused-variable
"""Schedule for depthwise_conv2d with auto fusion"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple

_Workload = namedtuple('Workload',
                       ['height', 'width', 'channel', 'multiplier',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])
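# _Schedule fields as used in _schedule below: vh/vw give the VH x VW output
# tile, vc is the channel vectorization width, and unroll toggles unrolling
# of the inner width loop. bc is only referenced from a commented-out line in
# this revision. (Meanings inferred from usage; not documented in the commit.)
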
# workloads of depthwise conv mobile net on imagenet
_WORKLOADS = [
    _Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
    _Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
    _Workload( 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
    _Workload( 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
    _Workload( 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
    _Workload( 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
    _Workload( 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
    _Workload( 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
    _Workload( 14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
]

_SCHEDULES = [
    _Schedule(2, 1, 4, 1, True),
    _Schedule(2, 4, 4, 2, True),
    _Schedule(2, 1, 4, 2, False),
    _Schedule(2, 4, 4, 1, True),
    _Schedule(4, 1, 4, 8, True),
    _Schedule(1, 1, 4, 2, True),
    _Schedule(1, 1, 8, 8, True),
    _Schedule(1, 1, 4, 1, False),
    _Schedule(2, 1, 4, 16, False),
]
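# The two tables pair by position: _SCHEDULES[i] is the hand-tuned schedule
# for _WORKLOADS[i], looked up via _WORKLOADS.index(wkl) in _schedule below.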

def _get_workload(data, kernel, stride, padding):
    # data is NCHW, kernel is (channel, multiplier, kh, kw)
    _, C, IH, IW = [x.value for x in data.shape]
    _, MT, KH, KW = [x.value for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
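
# For example, the first MobileNet layer (112x112 input, 32 channels, 3x3
# kernel, pad 1, stride 1) maps to _WORKLOADS[0] =
# _Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1), and hence to _SCHEDULES[0].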

def _schedule(s, data, data_pad, kernel, output, last):
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)

    # leave the schedule untouched for workloads we have not tuned
    if wkl not in _WORKLOADS:
        return s

    # use the pre-tuned schedule for this workload
    sch = _SCHEDULES[_WORKLOADS.index(wkl)]

    H, W = wkl.height, wkl.width
    CN = wkl.channel
    MT = wkl.multiplier
    HK, WK = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    VH, VW = sch.vh, sch.vw
    BC = sch.bc
    VC = sch.vc

    TH = H + 2*HPAD
    TW = W + 2*WPAD
    # integer division keeps the output extents ints on Python 3 as well
    OH = (H + 2*HPAD - HK) // HSTR + 1
    OW = (W + 2*WPAD - WK) // WSTR + 1

    A, B, C = data, kernel, output
    A0 = data_pad
    C0 = last

    # stage the padded input through a cache whose channel axis is split by VC
    A1 = s.cache_read(A0, "global", C)
    _, c, h, w = s[A1].op.axis
    c, vc = s[A1].split(c, VC)
    s[A1].reorder(c, h, w, vc)

    A2 = s.cache_write(A1, 'global')
    s[A0].compute_inline()
    s[A1].compute_inline()
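    # A0 (pad) and A1 are inlined away, so A2 is the only stage that
    # materializes the repacked input; it is attached to C0's outer
    # height-tile axis further below.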

    # pack the kernel the same way: channel axis split by VC, vc innermost
    B0 = s.cache_read(B, "global", C)
    c, m, h, w = s[B0].op.axis
    c, vc = s[B0].split(c, VC)
    s[B0].reorder(c, m, h, w, vc)

    B1 = s.cache_write(B0, 'global')
    s[B0].compute_inline()

    _, c, h, w = s[C].op.axis
    dh, dw = s[C].op.reduce_axis
    c, vc = s[C].split(c, VC)
    s[C].reorder(c, h, w, vc)

    # move the computation into a cached stage C0 and tile it VH x VW
    C0 = s.cache_write(C, 'global')
    _, c, h, w, vc = s[C0].op.axis
    dh, dw = s[C0].op.reduce_axis
    oh, ow, ih, iw = s[C0].tile(h, w, VH, VW)
    s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
    if sch.unroll:
        s[C0].unroll(iw)
    s[C0].vectorize(vc)

    oh, ow, ih, iw = s[C].tile(h, w, 2, 4)
    s[C].reorder(oh, ow, dh, dw, ih, iw)
    s[C].unroll(ih)
    s[C].vectorize(iw)
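
    # C0 now carries the actual computation, tiled VH x VW and vectorized
    # over the channel sub-axis vc; the outer stage C gets a fixed 2x4
    # tiling with its inner loops unrolled and vectorized.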
    # s[C0].compute_at(s[C0], ow)

    # parallelize over channels; the pragmas tune TVM's parallel codegen
    # (launch point, strided work partition, barrier on finish)
    launch, c, _, _ = s[C].op.axis
    s[C].pragma(launch, "parallel_launch_point")

    s[C].parallel(c)
    s[C].pragma(c, "parallel_launch_point")
    s[C].pragma(c, "parallel_stride_pattern")
    s[C].pragma(c, "parallel_barrier_when_finish")

    s[C0].compute_at(s[C], launch)
    _, c, h, w, vc = s[C0].op.axis
    s[C0].parallel(c)
    s[C0].pragma(c, "parallel_stride_pattern")
    s[C0].pragma(c, "parallel_barrier_when_finish")
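
    # attach the packed input A2 inside C0's outer height tile, but hoist the
    # packed kernel B1 to the batch axis of C: the kernel is reused across
    # the whole image, so it is repacked once per launch.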
    s[A2].compute_at(s[C0], oh)
    # parallel(s[A2], s[A2].op.axis[1], BC)

    # s[B0].compute_at(s[C0], ow)
    s[B1].compute_at(s[C], launch)
    c, m, h, w, vc = s[B1].op.axis
    s[B1].parallel(c)
    s[B1].pragma(c, "parallel_stride_pattern")
    s[B1].pragma(c, "parallel_barrier_when_finish")

    return s
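
The diff above is truncated by the viewer. For orientation, here is a minimal
sketch of how a helper like _schedule is typically driven from a TOPI
schedule_depthwise_conv2d entry point of this era: traverse the op graph,
inline injective stages, peel off the pad stage, then dispatch. The traversal
pattern, the 'depthwise_conv2d_nchw' tag check, and the pad-stage detection
are assumptions modeled on neighboring TOPI schedules, not code visible in
this diff.

# Hypothetical driver (not part of this commit's visible diff): shows how
# _schedule would be invoked from a generic schedule entry point.
def schedule_depthwise_conv2d(outs):
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(op):
        # inline injective (elementwise/broadcast) ops so they fuse away
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        # dispatch to the hand-written schedule for the conv itself
        if op.tag == 'depthwise_conv2d_nchw':
            output = op.output(0)
            data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule(s, data, data_pad, kernel, output, outs[0])

    traverse(outs[0].op)
    return s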