Commit e05f54be by ziheng, committed by GitHub

[TOPI] Add topi.target; Schedule for raspberry pi (#406)

* CPU Schedule for raspberry pi

* Update

* Update

* Add topi.target

* Refactor

* Update

* Make python3 happy

* Improve

* Improve

* Improve

* Use get_const_int
parent f6bb7aba
@@ -14,5 +14,7 @@ from .reduction import *
from .broadcast import *
from . import nn
from . import cuda
from . import rasp
from . import target
from . import testing
from . import util
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .batch_norm import *
from .convolution import *
from .depthwise_convolution import *
from .elemwise import *
from .dilate import *
from .flatten import *
......
# pylint: disable=invalid-name, unused-variable, too-many-locals
"""Depthwise Convolution operators"""
from __future__ import absolute_import as _abs
import tvm
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
"""Depthwise convolution nchw forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = Input.shape
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
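# Editor's usage sketch, not part of the commit: driving the operator above
# with the tvm 0.x API already used in this diff. Shapes are illustrative;
# channel_multiplier=1 means one filter per input channel.
def _example_depthwise_nchw():
    Input = tvm.placeholder((1, 32, 112, 112), name='Input')
    Filter = tvm.placeholder((32, 1, 3, 3), name='Filter')
    # 'SAME' with a 3x3 kernel pads by one on every side, so the
    # 112x112 spatial extent is preserved
    Output = depthwise_conv2d_nchw(Input, Filter, (1, 1), 'SAME')
    s = tvm.create_schedule(Output.op)
    return tvm.build(s, [Input, Filter, Output], 'llvm')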
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = Input.shape
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
@@ -3,51 +3,6 @@ from __future__ import absolute_import as _abs
import tvm
from ..util import equal_const_int
def _spatial2d_pad_option(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int or str
Padding size, or ['VALID', 'SAME']
kernel : tuple of int
Conv kernel size
Returns
-------
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
@tvm.tag_scope(tag="pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
"""Dilate Input with zeros.
......
"""TVM operator pooling compute."""
from __future__ import absolute_import
import tvm
+from .pad import pad
+from .util import get_pad_tuple
from .. import util
-from .pad import pad, _spatial2d_pad_option
def max_pool(data, kernel, stride, padding):
"""Perform max pooling on the data
@@ -32,7 +33,7 @@ def max_pool(data, kernel, stride, padding):
stride_height, stride_width = stride
batch, channel, height, width = data.shape
-pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
+pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
......
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
from __future__ import absolute_import
from ..util import get_const_int
def infer_pad(data, data_pad):
"""Infer the padding from stages in reverse.
Parameters
----------
data : Tensor
data stage.
data_pad : Tensor
pad stage.
Returns
-------
hpad : int
padding size on height
wpad : int
padding size on width
"""
if data_pad is None:
return 0, 0
_, _, IH, IW = data.shape
_, _, TH, TW = data_pad.shape
hpad = (TH - IH) // 2
wpad = (TW - IW) // 2
return get_const_int(hpad), get_const_int(wpad)
def infer_stride(data, kernel, out):
"""Infer the stride from stages in reverse.
Parameters
----------
data : Tensor
data stage.
kernel : Tensor
kernel stage.
out : Tensor
output stage.
Returns
-------
hstride : int
stride size on height
wstride : int
stride size on width
"""
_, _, IH, IW = data.shape
_, _, KH, KW = kernel.shape
_, _, OH, OW = out.shape
hstride = (IH - KH) // (OH - 1)
wstride = (IW - KW) // (OW - 1)
return get_const_int(hstride), get_const_int(wstride)
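# Editor's worked example, not part of the commit: the inference arithmetic
# above on illustrative constant shapes (a 3x3, stride-1, pad-1 workload).
def _example_infer():
    IH, TH = 56, 58                    # data height vs. pad-stage height
    hpad = (TH - IH) // 2              # (58 - 56) // 2 == 1
    KH, OH = 3, 56                     # kernel height, output height
    hstride = (TH - KH) // (OH - 1)    # (58 - 3) // 55 == 1
    return hpad, hstride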
def get_pad_tuple(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int or str
Padding size, or ['VALID', 'SAME']
kernel : tuple of int
Conv kernel size
Returns
-------
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_down : int
Padding size on bottom.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
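# Editor's worked cases, not part of the commit: get_pad_tuple is pure
# Python, so its padding arithmetic can be checked directly.
def _example_get_pad_tuple():
    # 'SAME' with a 3x3 kernel: pad_h = pad_w = 2, split evenly per side
    assert get_pad_tuple('SAME', (3, 3)) == (1, 1, 1, 1)
    # an even kernel splits unevenly; top/left receive the extra pixel
    assert get_pad_tuple('SAME', (2, 2)) == (1, 1, 0, 0)
    # an int pads symmetrically on all four sides
    assert get_pad_tuple(2, (3, 3)) == (2, 2, 2, 2)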
# pylint: disable=redefined-builtin, wildcard-import
"""Raspberry pi specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .convolution import *
# pylint: disable=invalid-name, unused-variable
"""Convolution schedule on Raspberry Pi"""
from __future__ import absolute_import as _abs
import tvm
from .. import target as _target
from ..nn.convolution import SpatialPack, Im2ColPack
from ..nn.convolution import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.convolution import _WORKLOADS, _SCH_TO_DECL_FUNC
from ..nn.convolution import _get_workload, _get_schedule
from ..nn.util import infer_pad, infer_stride
_SCHEDULES = [
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
]
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
sch = _SCHEDULES[idx]
return sch
_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d
def _declaration_conv2d(data, kernel, stride, padding, layout):
assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d
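# Editor's note, not part of the commit: a hypothetical sketch of how a
# generic front end (e.g. topi.nn.convolution, whose definition is not in
# this diff) could consume the target-keyed tables populated above.
def _example_dispatch(data, kernel, stride, padding, layout='NCHW'):
    # inside `with topi.target.rasp():`, current_target() hashes and
    # compares equal to _target.rasp(), so the lookup resolves to the
    # _declaration_conv2d registered above
    decl = _CONV_DECLARATION[_target.current_target()]
    return decl(data, kernel, stride, padding, layout)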
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
VH = sch.vh
VW = sch.vw
VC = sch.vc
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1 = data_pad, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
_, co, oh, ow, vh, vw, vc = s[C0].op.axis
if UNROLL:
s[C0].unroll(vw)
s[C0].vectorize(vc)
s[CC].compute_at(s[C0], ow)
_, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, vc)
if UNROLL:
s[CC].unroll(vw)
s[CC].vectorize(vc)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
_, h, _, _, _, _ = s[A1].op.axis
if sch.ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[A1].split(h, sch.ba)
oaxis = oh
paxis = ih
s[A1].parallel(paxis)
s[A1].pragma(oaxis, "parallel_launch_point")
s[A1].pragma(paxis, "parallel_stride_pattern")
s[A1].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C
n, co, h, w = s[C].op.axis
co, vc = s[C].split(co, VC)
oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
s[C].reorder(n, co, oh, ow, vh, vw, vc)
if C != C1:
s[C1].compute_inline()
s[C0].compute_at(s[C], ow)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
return s
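# Editor's note, not part of the commit: the split/parallel/pragma sequence
# recurs once per stage above; as a hypothetical refactoring it is one helper.
def _parallelize(s, stage, axis, factor):
    if factor == 1:
        oaxis = paxis = axis                         # nothing to split
    else:
        oaxis, paxis = s[stage].split(axis, factor)  # outer, inner
    s[stage].parallel(paxis)
    s[stage].pragma(oaxis, "parallel_launch_point")
    s[stage].pragma(paxis, "parallel_stride_pattern")
    s[stage].pragma(oaxis, "parallel_barrier_when_finish")
    return oaxis, paxis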
def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
H, W = wkl.height, wkl.width
CI = wkl.in_filter
CO = wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
P = sch.vp
Q = sch.vq
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1, A2 = data_pad, data_col, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
AA = s.cache_read(A2, "global", [CC])
BB = s.cache_read(B0, "global", [CC])
##### Schedule CC
_, co, im, vim, vco = s[C0].op.axis
s[C0].unroll(vim)
s[C0].vectorize(vco)
s[CC].compute_at(s[C0], im)
_, co, im, vim, vco = s[CC].op.axis
ci, hk, wk = s[CC].op.reduce_axis
s[CC].reorder(ci, hk, wk, vim, vco)
s[CC].unroll(vim)
s[CC].vectorize(vco)
# s[CC].unroll(ccr)
##### Schedule C
_, co, h, w = s[C].op.axis
im = s[C].fuse(h, w)
im, vim = s[C].split(im, P)
co, vco = s[C].split(co, Q)
s[C].reorder(co, im, vim, vco)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
if C1 != C:
s[C1].compute_inline()
s[C0].compute_at(s[C], paxis)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
s[A1].compute_inline()
s[AA].compute_at(s[CC], wk)
s[AA].unroll(AA.op.axis[4])
_, im, _, _, _, _ = s[A2].op.axis
if sch.ba == 1:
oaxis = im
paxis = im
else:
oim, iim = s[A2].split(im, sch.ba)
oaxis = oim
paxis = iim
s[A2].parallel(paxis)
s[A2].pragma(oaxis, "parallel_launch_point")
s[A2].pragma(paxis, "parallel_stride_pattern")
s[A2].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
s[BB].compute_at(s[CC], wk)
s[BB].vectorize(BB.op.axis[4])
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
return s
def schedule_convolution(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in op.tag or 'bcast' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'spatial_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
if 'im2col_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_col = data_vec.op.input_tensors[0]
data = data_col.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
traverse(outs[0].op)
return s
"""Target management API of topi"""
from __future__ import absolute_import
class Target(object):
"""A Target describes the target type on which computation should be carried on"""
default_target = None
str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
def __init__(self, target_type):
"""Constructs a context."""
if isinstance(target_type, Target):
self.target_typeid = target_type.target_typeid
else:
self.target_typeid = Target.str2type[target_type]
@property
def target_type(self):
"""Returns the target type of current target."""
return Target.type2str[self.target_typeid]
def __hash__(self):
"""Compute hash value of target for dictionary lookup"""
return hash(self.target_typeid)
def __eq__(self, other):
"""Compares two targets. Two targets are equal if they
have the same target type.
"""
return isinstance(other, Target) and \
self.target_typeid == other.target_typeid
def __str__(self):
return '%s' % (self.target_type)
def __repr__(self):
return self.__str__()
def __enter__(self):
self._old_target = Target.default_target
Target.default_target = self
return self
def __exit__(self, ptype, value, trace):
Target.default_target = self._old_target
Target.default_target = Target('x86')
def x86():
"""Returns a x86 target."""
return Target('x86')
def cuda():
"""Returns a cuda target."""
return Target('cuda')
def rasp():
"""Returns a rasp target."""
return Target('rasp')
def current_target():
"""Returns the current target."""
return Target.default_target
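# Editor's usage sketch, not part of the commit: the scoping behavior
# defined above, assuming the module-level default has not been changed.
def _example_target_scope():
    assert str(current_target()) == 'x86'   # module-level default
    with rasp():
        # Target hashes and compares by typeid, so dict lookups keyed
        # by Target (e.g. _CONV_DECLARATION) resolve to the rasp entries
        assert current_target() == rasp()
    assert str(current_target()) == 'x86'   # restored by __exit__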
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, padding):
in_height = in_width = in_size
with topi.target.rasp():
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.convolution(A, W, stride, padding)
s = topi.rasp.schedule_convolution([B])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_convolution():
verify_convolution(1, 56, 64, 64, 3, 1, 1)
if __name__ == "__main__":
test_convolution()