Commit 04fb5509 by Yizhi Liu, committed by Tianqi Chen

[TOPI] conv2d avx (#883)

* conv2d schedules for Intel CPU (AVX2 & AVX512)

* fix lint

* remove override register
parent ba7b9ddd
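
For context, a minimal usage sketch (not part of the commit) of how the new schedules are selected. It mirrors the updated test at the bottom of this diff; the ResNet-18-style shapes, stride, padding, and the llvm -mcpu=skylake-avx512 target string are illustrative assumptions, and the workload must match one of the entries predefined in _WORKLOADS:

import tvm
import topi

# Assumed example shapes: ResNet-18's first convolution (64 filters of 7x7 on a
# 3x224x224 input, stride 2, padding 3), which is one of the predefined workloads
# that _get_schedule_conv can map to an AVX schedule.
A = tvm.placeholder((1, 3, 224, 224), name='A')
W = tvm.placeholder((64, 3, 7, 7), name='W')

with tvm.target.create('llvm -mcpu=skylake-avx512'):
    # The target string contains 'avx', so conv2d dispatches to the new x86
    # declaration; the -mcpu=skylake-avx512 option selects fp32_vec_len = 16.
    B = topi.nn.conv2d(A, W, 2, 3, layout='NCHW')
    # Applies the AVXConvCommonFwd / AVXConv1x1Fwd schedule registered for "cpu".
    s = topi.generic.schedule_conv2d_nchw([B])
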
@@ -11,7 +11,7 @@ src/llvm/* @aatluri
 src/runtime/rocm/* @aatluri
 # JVM language
-jvm/* @javelinjs
+jvm/* @yzhliu
 # TOPI
 topi/python/topi/* @Laurawly @Huyuwei
@@ -26,7 +26,7 @@ and are qualified to lead development and review changes of the owned module.
 - [Aditya Atluri](https://github.com/adityaatluri) ROCM
 - [Leyuan Wang](https://github.com/Laurawly) TOPI
 - [Yuwei Hu](https://github.com/Huyuwei) TOPI
-- [Yizhi Liu](https://github.com/javelinjs) JVM package
+- [Yizhi Liu](https://github.com/yzhliu) JVM package
 List of Contributors
 --------------------
...
 # pylint: disable=invalid-name,unused-variable,invalid-name
 """Conv2D schedule on x86"""
 import tvm
-from .. import generic
-from .. import tag
+from .. import generic, tag
+from .. import nn
+from ..nn.util import infer_pad, infer_stride
+from ..nn.conv2d import conv2d, _get_workload, _get_schedule, _WORKLOADS
+from . import conv2d_avx_1x1, conv2d_avx_common
+from .conv2d_avx_common import AVXConvCommonFwd
+from .conv2d_avx_1x1 import AVXConv1x1Fwd
+
+_AVX_SCH_TO_DECL_FUNC = {
+    AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
+    AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
+}
+
+_AVX_SCH_TO_SCH_FUNC = {
+    AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
+    AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
+}
+
+@_get_schedule.register("cpu")
+def _get_schedule_conv(wkl):
+    if wkl not in _WORKLOADS:
+        raise ValueError("no schedule for such workload: {}".format(wkl))
+    idx = _WORKLOADS.index(wkl)
+
+    fp32_vec_len = 8
+    target = tvm.target.current_target(allow_none=False)
+    for opt in target.options:
+        if opt == '-mcpu=skylake-avx512':
+            fp32_vec_len = 16
+
+    _SCHEDULES_AVX_NCHW = [
+        # float32 resnet-18
+        AVXConvCommonFwd(3, fp32_vec_len, 28, False),
+        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
+        AVXConvCommonFwd(16, fp32_vec_len, 14, False),
+        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
+        AVXConvCommonFwd(16, fp32_vec_len, 14, True),
+        AVXConvCommonFwd(16, 32, 7, True),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
+        AVXConvCommonFwd(16, fp32_vec_len, 7, True),
+        # float32 mobilenet
+        AVXConvCommonFwd(3, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
+        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
+        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
+    ]
+    sch = _SCHEDULES_AVX_NCHW[idx]
+    return sch
+
+@conv2d.register("cpu")
+def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
+    target = tvm.target.current_target(allow_none=False)
+    if 'avx' in str(target) and layout == 'NCHW':
+        wkl = _get_workload(data, kernel, stride, padding, out_dtype)
+        sch = _get_schedule(wkl)
+        return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
+    elif layout == 'NCHW':
+        return nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+    elif layout == 'HWCN':
+        return nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
+    elif layout == 'NHWC':
+        return nn.conv2d_nhwc(data, kernel, stride, padding, out_dtype)
+    else:
+        raise ValueError("not support this layout {} yet".format(layout))
 @generic.schedule_conv2d_nchw.register(["cpu"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
+    target = tvm.target.current_target(allow_none=False)

     def traverse(op):
         """Traverse operators from computation graph"""
@@ -16,7 +93,7 @@ def schedule_conv2d(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             else: # inject custom schedule
-                if len(op.axis) == 4: # schedule bias + bn + relu
+                if len(op.axis) == 4 and 'avx' not in str(target): # schedule bias + bn + relu
                     n, c, h, w = op.axis
                     fused = s[op].fuse(n, c)
                     s[op].parallel(fused)
@@ -26,27 +103,50 @@ def schedule_conv2d(outs):
                     traverse(tensor.op)

         if 'conv2d_nchw' in op.tag:
-            conv = op.output(0)
-            kernel = op.input_tensors[1]
-            data = op.input_tensors[0]
-            data_pad = None
-            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
-                data_pad = data
-                data = data_pad.op.input_tensors[0]
-
-            n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
-            pad_fused = s[data_pad].fuse(n_pad, c_pad)
-            s[data_pad].parallel(pad_fused)
-            C = conv
-            n, c, h, w = C.op.axis
-            rc, ry, rx = C.op.reduce_axis
-            fused = s[C].fuse(n, c)
-            s[C].parallel(fused)
-            wo, wi = s[C].split(w, factor=16)
-            s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
-            s[C].unroll(rx)
-            s[C].unroll(ry)
-            s[C].vectorize(wi)
+            if 'avx' in str(target):
+                output = op.output(0)
+                conv_out = op.input_tensors[0]
+                kernel_vec = conv_out.op.input_tensors[1]
+                kernel = kernel_vec.op.input_tensors[0]
+                data_vec = conv_out.op.input_tensors[0]
+                data = data_vec.op.input_tensors[0]
+                data_pad = None
+                if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                    data_pad = data
+                    data = data_pad.op.input_tensors[0]
+
+                padding = infer_pad(data, data_pad)
+                if data_pad is None:
+                    stride = infer_stride(data, kernel, output)
+                else:
+                    stride = infer_stride(data_pad, kernel, output)
+
+                wkl = _get_workload(data, kernel, stride, padding, output.dtype)
+                sch = _get_schedule(wkl)
+                _AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
+                                                kernel, kernel_vec, conv_out, output, outs[0])
+            else:
+                conv = op.output(0)
+                kernel = op.input_tensors[1]
+                data = op.input_tensors[0]
+                data_pad = None
+                if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                    data_pad = data
+                    data = data_pad.op.input_tensors[0]
+
+                n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
+                pad_fused = s[data_pad].fuse(n_pad, c_pad)
+                s[data_pad].parallel(pad_fused)
+                C = conv
+                n, c, h, w = C.op.axis
+                rc, ry, rx = C.op.reduce_axis
+                fused = s[C].fuse(n, c)
+                s[C].parallel(fused)
+                wo, wi = s[C].split(w, factor=16)
+                s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
+                s[C].unroll(rx)
+                s[C].unroll(ry)
+                s[C].vectorize(wi)

     traverse(outs[0].op)
     return s
...
# pylint: disable=invalid-name,unused-variable,invalid-name
"""1x1 Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width, sch.ic_bn)
    data_vec = tvm.compute(shape, lambda n, C, h, w, c: data_pad[n, C * sch.ic_bn + c, h, w])

    shape = (num_filter // sch.oc_bn, in_channel // sch.ic_bn, sch.ic_bn, sch.oc_bn, 1, 1)
    kernel_vec = tvm.compute(shape, lambda CO, CI, ci, co, h, w:
                             kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w],
                             name='kernel_vec')

    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width, sch.oc_bn)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR, ow*WSTR, ic%sch.ic_bn] *
                               kernel_vec[oc_chunk, ic//sch.ic_bn, ic%sch.ic_bn, oc_block, 0, 0],
                               axis=[ic]), name='conv')

    oshape = (batch_size, num_filter, out_height, out_width)
    unpack = tvm.compute(oshape, lambda n, oc, oh, ow:
                         conv[n, oc // sch.oc_bn, oh, ow, oc % sch.oc_bn],
                         tag='conv2d_nchw')
    return unpack
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last):
    # no stride and padding info here
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    DOPAD = (HPAD != 0 and WPAD != 0)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if sch.oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    s[W].pragma(parallel_axis, "parallel_launch_point")
    s[W].pragma(parallel_axis, "parallel_stride_pattern")
    s[W].pragma(parallel_axis, "parallel_barrier_when_finish")

    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], oh_outer)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

    parallel_axis = s[O].fuse(oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")

    return s
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # pack data
    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    shape = (batch_size, in_channel // sch.ic_bn, pad_height, sch.ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * sch.ic_bn + c, h, w],
                           name='data_vec')

    # pack kernel
    shape = (num_filter//sch.oc_bn, in_channel//sch.ic_bn,
             kernel_height, kernel_width, sch.ic_bn, sch.oc_bn)
    kernel_vec = tvm.compute(shape, lambda CO, CI, h, w, ci, co:
                             kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w],
                             name='kernel_vec')

    # convolution
    oshape = (batch_size, num_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR+kh, ic%sch.ic_bn, ow*WSTR+kw] *
                               kernel_vec[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block],
                               axis=[ic, kh, kw]),
                       name='conv')

    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last):
    # no stride and padding info here
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    DOPAD = (HPAD != 0 and WPAD != 0)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if sch.oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    s[W].pragma(parallel_axis, "parallel_launch_point")
    s[W].pragma(parallel_axis, "parallel_stride_pattern")
    s[W].pragma(parallel_axis, "parallel_barrier_when_finish")

    # schedule conv
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)

    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh)

    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")

    return s
@@ -11,8 +11,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
     W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    B = topi.nn.conv2d_nchw(A, W, stride, padding)
-    C = topi.nn.relu(B)

     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -35,6 +33,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
+            C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_nchw([B])
             s2 = topi.generic.schedule_conv2d_nchw([C])
         a = tvm.nd.array(a_np, ctx)
...