Commit 04fb5509 by Yizhi Liu Committed by Tianqi Chen

[TOPI] conv2d avx (#883)

* conv2d schedules for Intel CPU (AVX2 & AVX512)

* fix lint

* remove override register
parent ba7b9ddd
......@@ -11,7 +11,7 @@ src/llvm/* @aatluri
src/runtime/rocm/* @aatluri
# JVM language
jvm/* @javelinjs
jvm/* @yzhliu
# TOPI
topi/python/topi/* @Laurawly @Huyuwei
......@@ -26,7 +26,7 @@ and are qualified to lead development and review changes of the owned module.
- [Aditya Atluri](https://github.com/adityaatluri) ROCM
- [Leyuan Wang](https://github.com/Laurawly) TOPI
- [Yuwei Hu](https://github.com/Huyuwei) TOPI
- [Yizhi Liu](https://github.com/javelinjs) JVM package
- [Yizhi Liu](https://github.com/yzhliu) JVM package
List of Contributors
--------------------
......
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Conv2D schedule on x86"""
import tvm
from .. import generic
from .. import tag
from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, _get_workload, _get_schedule, _WORKLOADS
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
from .conv2d_avx_1x1 import AVXConv1x1Fwd
# Dispatch tables: map an AVX schedule-template type to the implementation
# that declares / schedules the corresponding convolution.
_AVX_SCH_TO_DECL_FUNC = {
    AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv,
    AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
}

_AVX_SCH_TO_SCH_FUNC = {
    AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv,
    AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
}
@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
fp32_vec_len = 8
target = tvm.target.current_target(allow_none=False)
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
_SCHEDULES_AVX_NCHW = [
# float32 resnet-18
AVXConvCommonFwd(3, fp32_vec_len, 28, False),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConvCommonFwd(16, fp32_vec_len, 14, False),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConvCommonFwd(16, fp32_vec_len, 14, True),
AVXConvCommonFwd(16, 32, 7, True),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
AVXConvCommonFwd(16, fp32_vec_len, 7, True),
# float32 mobilenet
AVXConvCommonFwd(3, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
]
sch = _SCHEDULES_AVX_NCHW[idx]
return sch
@conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
target = tvm.target.current_target(allow_none=False)
if 'avx' in str(target) and layout == 'NCHW':
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
elif layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
elif layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
elif layout == 'NHWC':
return nn.conv2d_nhwc(data, kernel, stride, padding, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
@generic.schedule_conv2d_nchw.register(["cpu"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target(allow_none=False)
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -16,7 +93,7 @@ def schedule_conv2d(outs):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
if len(op.axis) == 4 and 'avx' not in str(target): # schedule bias + bn + relu
n, c, h, w = op.axis
fused = s[op].fuse(n, c)
s[op].parallel(fused)
......@@ -26,27 +103,50 @@ def schedule_conv2d(outs):
traverse(tensor.op)
if 'conv2d_nchw' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
if 'avx' in str(target):
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, c_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, c, h, w = C.op.axis
rc, ry, rx = C.op.reduce_axis
fused = s[C].fuse(n, c)
s[C].parallel(fused)
wo, wi = s[C].split(w, factor=16)
s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
s[C].unroll(rx)
s[C].unroll(ry)
s[C].vectorize(wi)
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
kernel, kernel_vec, conv_out, output, outs[0])
else:
conv = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, c_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, c, h, w = C.op.axis
rc, ry, rx = C.op.reduce_axis
fused = s[C].fuse(n, c)
s[C].parallel(fused)
wo, wi = s[C].split(w, factor=16)
s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
s[C].unroll(rx)
s[C].unroll(ry)
s[C].vectorize(wi)
traverse(outs[0].op)
return s
......
# pylint: disable=invalid-name,unused-variable,invalid-name
"""1x1 Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
# Schedule template for the AVX 1x1 convolution:
#   ic_bn / oc_bn         -- input/output channel blocking factors
#   oh_factor / ow_factor -- output height/width split factors
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd',
                           'ic_bn oc_bn oh_factor ow_factor')
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a blocked 1x1 NCHW convolution for AVX.

    Packs data into an (n, C, h, w, c) channel-blocked layout and kernel
    into (CO, CI, ci, co, 1, 1), computes the convolution in the blocked
    layout, then unpacks back to NCHW (tagged 'conv2d_nchw' so the generic
    scheduler can locate it).
    """
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # Fix: pad whenever EITHER spatial dimension needs it.  The original
    # `HPAD != 0 and WPAD != 0` skipped padding when only one dimension
    # required it, leaving the padded-size shape computations above
    # inconsistent with the unpadded buffer.
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # pack input into (n, C, h, w, c) with ic_bn-wide channel blocks
    shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width, sch.ic_bn)
    data_vec = tvm.compute(shape, lambda n, C, h, w, c: data_pad[n, C * sch.ic_bn + c, h, w])

    # pack kernel into (CO, CI, ci, co, 1, 1)
    shape = (num_filter // sch.oc_bn, in_channel // sch.ic_bn, sch.ic_bn, sch.oc_bn, 1, 1)
    kernel_vec = tvm.compute(shape, lambda CO, CI, ci, co, h, w:
                             kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w],
                             name='kernel_vec')

    # blocked convolution: reduce over the full input-channel axis
    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width, sch.oc_bn)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR, ow*WSTR, ic%sch.ic_bn] *
                               kernel_vec[oc_chunk, ic//sch.ic_bn, ic%sch.ic_bn, oc_block, 0, 0],
                               axis=[ic]), name='conv')

    # unpack back to NCHW
    oshape = (batch_size, num_filter, out_height, out_width)
    unpack = tvm.compute(oshape, lambda n, oc, oh, ow:
                         conv[n, oc // sch.oc_bn, oh, ow, oc % sch.oc_bn],
                         tag='conv2d_nchw')
    return unpack
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last):
    """Schedule the blocked 1x1 AVX convolution declared above.

    `last` is the outermost op of the compute graph (the output itself, or
    a following elementwise stage fused by the caller).  Returns the
    schedule `s` mutated in place.
    """
    # No stride/padding info is passed in; recover it from tensor shapes.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data packing; inline the pad stage (when one exists) into it.
    # Fix: guard on the pad tensor itself rather than re-deriving a DOPAD
    # predicate from the workload -- if the two ever disagree (e.g. only one
    # spatial dimension padded), `s[None]` would crash.
    if A0 is not None:
        s[A0].compute_inline()
    # data_vec layout here is (n, C, h, w, c) -- see _declaration_conv;
    # only the batch and the fused (C, h) axes are actually scheduled.
    batch, ic_chunk, ih, iw, ic_block = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule kernel pack
    # NOTE(review): kernel_vec axes are (CO, CI, ci, co, 1, 1); these unpack
    # names appear carried over from the common (non-1x1) layout, so e.g.
    # `oc_block` binds the trailing unit axis -- confirm intent upstream.
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if sch.oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    s[W].pragma(parallel_axis, "parallel_launch_point")
    s[W].pragma(parallel_axis, "parallel_stride_pattern")
    s[W].pragma(parallel_axis, "parallel_barrier_when_finish")

    # schedule the blocked convolution through a global cache-write stage
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], oh_outer)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    # fold the unpack stage into `last` when they are separate ops
    if O0 != O:
        s[O0].compute_inline()

    batch, oc, oh, ow = s[O].op.axis
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")

    return s
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
# Schedule template for the generic (non-1x1) AVX convolution:
#   ic_bn / oc_bn -- input/output channel blocking factors
#   reg_n         -- output-width register-tiling factor
#   unroll_kw     -- whether to unroll the kernel-width reduction loop
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd',
                              'ic_bn oc_bn reg_n unroll_kw')
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a blocked NCHW convolution for AVX (general kernel size).

    Packs data into an (n, C, h, c, w) channel-blocked layout and kernel
    into (CO, CI, kh, kw, ci, co), computes the convolution in the blocked
    layout, then unpacks back to NCHW (tagged 'conv2d_nchw' so the generic
    scheduler can locate it).
    """
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # pack data
    # Fix: pad whenever EITHER spatial dimension needs it.  The original
    # `HPAD != 0 and WPAD != 0` skipped padding when only one dimension
    # required it, leaving the padded-size shape computations above
    # inconsistent with the unpadded buffer.
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    shape = (batch_size, in_channel // sch.ic_bn, pad_height, sch.ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * sch.ic_bn + c, h, w],
                           name='data_vec')

    # pack kernel
    shape = (num_filter//sch.oc_bn, in_channel//sch.ic_bn,
             kernel_height, kernel_width, sch.ic_bn, sch.oc_bn)
    kernel_vec = tvm.compute(shape, lambda CO, CI, h, w, ci, co:
                             kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w],
                             name='kernel_vec')

    # convolution: reduce over input channel and the kernel window
    oshape = (batch_size, num_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR+kh, ic%sch.ic_bn, ow*WSTR+kw] *
                               kernel_vec[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block],
                               axis=[ic, kh, kw]),
                       name='conv')

    # unpack back to NCHW
    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last):
    """Schedule the blocked AVX convolution (general kernel size).

    `last` is the outermost op of the compute graph (the output itself, or
    a following elementwise stage fused by the caller).  Returns the
    schedule `s` mutated in place.
    """
    # No stride/padding info is passed in; recover it from tensor shapes.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data packing; inline the pad stage (when one exists) into it.
    # Fix: guard on the pad tensor itself rather than re-deriving a DOPAD
    # predicate from the workload -- if the two ever disagree (e.g. only one
    # spatial dimension padded), `s[None]` would crash.
    if A0 is not None:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if sch.oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    s[W].pragma(parallel_axis, "parallel_launch_point")
    s[W].pragma(parallel_axis, "parallel_stride_pattern")
    s[W].pragma(parallel_axis, "parallel_barrier_when_finish")

    # schedule conv through a global cache-write stage
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    # unroll_kw moves kw inside ic_block and unrolls it
    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)

    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # fold the unpack stage into `last` when they are separate ops
    if O0 != O:
        s[O0].compute_inline()

    batch, oc, oh, ow = s[O].op.axis
    ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")

    return s
......@@ -11,8 +11,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_nchw(A, W, stride, padding)
C = topi.nn.relu(B)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
......@@ -35,6 +33,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment