Commit 3df42cd7 by masahi Committed by Tianqi Chen

[TOPI] Basic x86 schedules (#775)

* add basic x86 schedules

* parallelize & vectorize batchnorm + relu

* fuse conv into bn + relu

* move rc loop to outer

* add nhwc conv

* change weight layout to hwcf

* conv + bn + relu fusion for nhwc conv

* fix conv_nhwc schedule when no fusion

* clean up default parallel schedules

* simplify elemwise parallel

* fix elemwise parallel for batch == 1

* update nhwc conv test

* fix and add comment

* fix lint

* remove redundant import

* remove default multithreading for some ops

* remove default multithreading for global pool
parent 7ca44d7a
...@@ -36,6 +36,24 @@ def schedule_conv2d_nchw(outs): ...@@ -36,6 +36,24 @@ def schedule_conv2d_nchw(outs):
@tvm.target.generic_func @tvm.target.generic_func
def schedule_conv2d_nhwc(outs):
"""Schedule for conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs): def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw """Schedule for conv2d_transpose_nchw
......
...@@ -337,6 +337,57 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'): ...@@ -337,6 +337,57 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
name="Conv2dOutput", tag="conv2d_hwcn") name="Conv2dOutput", tag="conv2d_hwcn")
return Output return Output
def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
"""Convolution operator in NHWC layout.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
assert isinstance(stride, int) or len(stride) == 2
batch, in_height, in_width, in_channel = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry, xx * stride_w + rx, rc].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_nhwc")
return Output
# map from schedule type to declaration function # map from schedule type to declaration function
_SCH_TO_DECL_FUNC = { _SCH_TO_DECL_FUNC = {
SpatialPack: _spatial_pack, SpatialPack: _spatial_pack,
......
...@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs ...@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python from .dilate_python import dilate_python
......
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Convolution in python"""
import numpy as np
import scipy.signal
def conv2d_nhwc_python(a_np, w_np, stride, padding):
"""Convolution operator in NHWC layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
w_np : numpy.ndarray
4-D with shape [num_filter, filter_height, filter_width, in_channel]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [out_height, out_width, out_channel, batch]
"""
batch, in_height, in_width, in_channel = a_np.shape
kernel_h, kernel_w, _, num_filter = w_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == 'VALID':
pad_h = 0
pad_w = 0
else: # 'SAME'
pad_h = kernel_h - 1
pad_w = kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
# compute the output shape
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
# change the layout from NHWC to NCHW
at = a_np.transpose((0, 3, 1, 2))
wt = w_np.transpose((3, 2, 0, 1))
bt = np.zeros((batch, out_channel, out_height, out_width))
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
else:
apad = at[n, c]
out = scipy.signal.convolve2d(
apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
bt[n, f] += out[::stride, ::stride]
return bt.transpose((0, 2, 3, 1))
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
"""x86 specific declaration and schedules.""" """x86 specific declaration and schedules."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .conv2d import schedule_conv2d from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .binarize_pack import schedule_binarize_pack from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense from .binary_dense import schedule_binary_dense
from .nn import *
from .injective import *
...@@ -15,6 +15,12 @@ def schedule_conv2d(outs): ...@@ -15,6 +15,12 @@ def schedule_conv2d(outs):
if tag.is_broadcast(op.tag): if tag.is_broadcast(op.tag):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
n, c, h, w = op.axis
fused = s[op].fuse(n, c)
s[op].parallel(fused)
s[op].vectorize(w)
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors:
traverse(tensor.op) traverse(tensor.op)
...@@ -28,10 +34,68 @@ def schedule_conv2d(outs): ...@@ -28,10 +34,68 @@ def schedule_conv2d(outs):
data_pad = data data_pad = data
data = data_pad.op.input_tensors[0] data = data_pad.op.input_tensors[0]
n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, c_pad)
s[data_pad].parallel(pad_fused)
C = conv C = conv
n, c, h, w = C.op.axis n, c, h, w = C.op.axis
s[C].parallel(c) rc, ry, rx = C.op.reduce_axis
s[C].pragma(n, "parallel_launch_point") fused = s[C].fuse(n, c)
s[C].parallel(fused)
wo, wi = s[C].split(w, factor=16)
s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
s[C].unroll(rx)
s[C].unroll(ry)
s[C].vectorize(wi)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
@generic.schedule_conv2d_nhwc.register(["cpu"])
def schedule_conv2d_nhwc(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
n, h, w, c = op.axis
fused = s[op].fuse(n, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'conv2d_nhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, h_pad, w_pad, c_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, h_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, h, w, c = C.op.axis
ry, rx, rc = C.op.reduce_axis
n_out, h_out, w_out, c_out = output_op.axis
s[C].vectorize(c)
if op != output_op: # fuse bias + bn + relu into conv
s[C].compute_at(s[output_op], c_out)
else:
fused = s[C].fuse(n, h, w)
s[C].parallel(fused)
traverse(output_op)
return s
# pylint: disable=invalid-name
"""x86 declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
@generic.schedule_injective.register(["cpu"])
def schedule_injective(outs):
"""X86 schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) == 4:
n, c, _, _ = s[x].op.axis
fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
s[x].parallel(fused)
else:
s[x].parallel(s[x].op.axis[0])
return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
"""x86 nn operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
def _default_schedule(outs, auto_inline):
"""Default schedule for x86."""
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
if len(s[x].op.axis) == 4:
n, c, _, _ = s[x].op.axis
fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
s[x].parallel(fused)
else:
s[x].parallel(s[x].op.axis[0])
return s
@generic.schedule_softmax.register(["cpu"])
def schedule_softmax(outs):
"""Schedule for softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of softmax
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@generic.schedule_pool.register(["cpu"])
def schedule_pool(outs):
"""Schedule for pool
Parameters
----------
outs: Array of Tensor
The computation graph description of pool
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_nhwc(A, W, stride, padding)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_conv2d_nhwc([B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm']:
check_device(device)
def test_conv2d_nhwc():
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "SAME")
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
if __name__ == "__main__":
test_conv2d_nhwc()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment