Commit 5712ea6b by Yizhi Liu Committed by Tianqi Chen

[TOPI] depthwise-conv2d in NCHW[x]c layout for x86 (#2045)

parent b88065a8
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import tvm
import topi
from topi.util import get_const_int
from topi.util import get_const_int, get_const_tuple
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
......@@ -164,16 +164,22 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_channel = attrs.get_int("channels")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
_, in_channel_chunk, _, _, in_channel_block = get_const_tuple(inputs[0].shape)
in_channel = in_channel_chunk * in_channel_block
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
layout, out_layout, out_dtype)
elif groups == in_channel and groups == out_channel:
out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
dilation, layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
......@@ -187,9 +193,12 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(outs)
elif groups == out_channel:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
else:
raise ValueError("not support group number > 1 for now")
......
......@@ -82,7 +82,9 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
wshape[kernel_layout.indexof('O')] *= param.groups;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
}
if (param.use_bias) {
static const Layout default_bias_layout("C");
TShape bias_shape({param.channels});
......
......@@ -175,6 +175,23 @@ def schedule_depthwise_conv2d_nhwc(outs):
@tvm.target.generic_func
def schedule_depthwise_conv2d_NCHWc(outs):
"""Schedule for depthwise_conv2d_NCHWc
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_group_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
......
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Depthwise convolution operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .dilate import dilate
......@@ -8,6 +9,27 @@ from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
# workload description of depthwise-conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def _get_workload(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, in_channel, height, width = [x.value for x in data.shape]
channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape]
out_channel = channel * channel_multiplier
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, height, width, in_channel,
out_channel, kh, kw, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
......@@ -258,3 +280,44 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
tag='depthwise_conv2d_backward_weight_nhwc')
return Weight_grad
@tvm.target.generic_func
def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
layout, out_layout, out_dtype=None):
"""Depthwise convolution NCHW[x]c forward operator.
Parameters
----------
Input : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
Filter : tvm.Tensor
4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
In NCHWc depthwise convolution,
we group kernel's in_channel and channel_multiplier together then do the tiling.
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype: str, optional
Output data type
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
......@@ -9,3 +9,4 @@ from .nn import *
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
......@@ -7,36 +7,30 @@ from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload
from ..nn.dilate import dilate
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad
from . import conv2d_avx_1x1, conv2d_avx_common
def _get_fp32_len():
fp32_vec_len = 8
target = tvm.target.current_target()
if target is not None:
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
return fp32_vec_len
def _get_default_config(cfg, workload):
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
"""
Get default schedule config for the workload
Parameters
----------
workload : topi.nn.conv2d.Workload
Convolution workload
"""
fp32_vec_len = _get_fp32_len()
is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len)
if is_depthwise:
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from depthwise_conv2d import _fallback_schedule
_fallback_schedule(cfg, wkl)
else:
conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len)
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, wkl)
else:
conv2d_avx_common._fallback_schedule(cfg, wkl)
def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
......@@ -74,10 +68,9 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
if layout == 'NCHW':
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
if cfg.is_fallback:
wkl = _get_workload(data, kernel, strides, padding, out_dtype)
_get_default_config(cfg, wkl)
return _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout,
out_dtype)
_get_default_config(cfg, data, kernel, strides, padding, out_dtype)
return _declaration_conv_impl(cfg, data, kernel, strides,
padding, dilation, layout, out_dtype)
elif layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'NHWC':
......@@ -295,44 +288,69 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
data, kernel = tinfo[0], tinfo[1]
# only optimize for NCHW, groups=1 conv
if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
return None
batch_size, in_channel, height, width = get_const_tuple(data.shape)
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs['layout']
kh, kw = attrs.get_int_tuple("kernel_size")
dtype = data.dtype
out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"]
is_depthwise = groups == in_channel and groups == out_channel
# only optimize for NCHW
if layout != 'NCHW':
return None
if groups != 1 and not is_depthwise:
return None
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
# query schedule and fallback if necessary
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, out_dtype], depthwise_conv2d_nchw) \
if is_depthwise else \
autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
wkl = _get_workload(data, kernel, strides, padding, out_dtype)
_get_default_config(cfg, wkl)
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_attrs['layout'] = 'NCHW%dc' % ic_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store the same config for the altered operator (workload)
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if is_depthwise:
# channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
# in which out_channel = merge(channel, channel_multiplier)
kernel_sym = copy_inputs[1]
kernel_sym = sym.reshape(kernel_sym, shape=(out_channel//oc_bn, oc_bn, kh, kw))
kernel_sym = sym.transpose(kernel_sym, axes=(0, 2, 3, 1))
copy_inputs[1] = kernel_sym
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
......@@ -354,13 +372,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# get workload and related schedule config
wkl = _get_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
if cfg.is_fallback:
_get_default_config(cfg, wkl)
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
# output shape
out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
......@@ -386,7 +402,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
n_elems = 4
assert ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//ic_bn), name='ic_outer')
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
......
......@@ -8,8 +8,10 @@ from ..nn.util import infer_pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len
def _fallback_schedule(cfg, wkl, simd_width):
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
......
......@@ -8,8 +8,10 @@ from ..nn.util import infer_pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len
def _fallback_schedule(cfg, wkl, simd_width):
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
......
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Depthwise Conv2D schedule on x86"""
import tvm
from tvm import autotvm
from tvm.autotvm.task import get_config
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.nnvm_integration import deserialize_args
from .. import generic, tag
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload
from .util import get_fp32_len
def _fallback_schedule(cfg, wkl):
"""
Get default schedule for the workload
Parameters
----------
cfg : tvm.autotvm.task.space.FallbackConfigEntity
Fallback config to be updated
wkl : topi.nn.depthwise_conv2d.Workload
Convolution workload
"""
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if wkl.out_filter % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
layout, out_layout, out_dtype=None):
out_dtype = data.dtype if out_dtype is None else out_dtype
batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
out_channel_chunk, filter_height, filter_width, out_channel_block \
= get_const_tuple(kernel.shape)
strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HSTR, WSTR = strides
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (filter_height, filter_width))
dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
assert (dh, dw) == (1, 1), "Does not support dilation"
in_channel = in_channel_chunk * in_channel_block
out_channel = out_channel_chunk * out_channel_block
channel_multiplier = out_channel // in_channel
out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1
out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1
# get workload and related schedule config
wkl = _get_workload(tvm.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype),
tvm.placeholder((out_channel, in_channel, filter_height, filter_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
if cfg.is_fallback:
_fallback_schedule(cfg, wkl)
# padding stage
DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0)
if DOPAD:
pad_before = [0, 0, pad_top, pad_left, 0]
pad_after = [0, 0, pad_down, pad_right, 0]
data_pad = pad(data, pad_before, pad_after, name="PaddedInput")
else:
data_pad = data
# depthconv stage
kh = tvm.reduce_axis((0, filter_height), name='kh')
kw = tvm.reduce_axis((0, filter_width), name='kw')
Output = tvm.compute(
(batch, out_channel_chunk, out_height, out_width, out_channel_block),
lambda b, oco, oh, ow, oci: tvm.sum(
(data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block,
oh*HSTR+kh, ow*WSTR+kw,
((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
.astype(out_dtype) *
kernel[oco, kh, kw, oci].astype(out_dtype)),
axis=[kh, kw]),
name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc")
return Output
@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_NCHWc, 'cpu', ['direct'])
def schedule_depthwise_conv2d_NCHWc(cfg, outs):
"""CPU schedule for depthwise conv2d in NCHW[x]c layout"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'depthwise_conv2d_NCHWc' in op.tag:
conv_out = op.output(0)
data = conv_out.op.input_tensors[0]
kernel = conv_out.op.input_tensors[1]
_schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output):
tile_ow = cfg["tile_ow"].size[-1]
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
p = s[A].fuse(ic_chunk, ih)
s[A].parallel(p)
C, O = conv_out, output
CC = s.cache_write(C, 'global')
_, ic_chunk, oh, ow, ic_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=tile_ow)
s[C].reorder(ic_chunk, oh, ow_chunk, ow_block, ic_block)
parallel_axis = s[C].fuse(ic_chunk, oh)
s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
_, ic_chunk, oh, ow, ic_block = s[CC].op.axis
kh, kw = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=tile_ow)
s[CC].reorder(ic_chunk, oh, kh, kw, ow_block, ic_block)
s[CC].vectorize(ic_block)
s[CC].unroll(ow_block)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=tile_ow)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
@autotvm.task.register("topi_x86_depthwise_conv2d_NCHWc_from_nchw")
def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
data, kernel, strides, padding, dilation, dtype = deserialize_args(args)
batch, in_channel, height, width = get_const_tuple(data.shape)
filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape)
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
out_height = (height - kh + 2 * ph) // sh + 1
out_width = (width - kw + 2 * pw) // sw + 1
out_channel = filter_channel * channel_multiplier
# get config here
cfg = get_config()
cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", out_channel, num_outputs=2)
cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
# change shape with the value in config
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
data_layout = "NCHW%dc" % ic_bn
out_layout = "NCHW%dc" % oc_bn
C = _depthwise_conv2d_NCHWc_cpu(cfg, new_data, new_kernel, strides, padding, dilation,
data_layout, out_layout, dtype)
s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
return s, [new_data, new_kernel, C]
"""Common x86 related utilities"""
from __future__ import absolute_import as _abs
import tvm
def get_fp32_len():
fp32_vec_len = 8
target = tvm.target.current_target()
if target is not None:
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
return fp32_vec_len
......@@ -13,27 +13,22 @@ from common import get_all_backend
def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.transpose(data, (0, 2, 3, 1))
data = np.reshape(data, (batch_size, height, width, channel//bn, bn))
data = np.transpose(data, (0, 3, 1, 2, 4))
data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
data = np.transpose(data, (0, 1, 3, 4, 2))
return data
def _transform_kernel(kernel, ic_bn, oc_bn):
# OIHW -> OIHW[x]i[x]o
out_channel, in_channel, kh, kw = kernel.shape
kernel = np.transpose(kernel, (1, 2, 3, 0))
kernel = np.reshape(kernel, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel = np.transpose(kernel, (1, 2, 3, 4, 0))
kernel = np.reshape(kernel, (kh, kw, out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn))
kernel = np.transpose(kernel, (2, 4, 0, 1, 5, 3))
kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn, kh, kw))
kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1))
return kernel
def _transform_bias(bias, bn):
# [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn]
num_filter, h, w = bias.shape
bias = np.transpose(bias, (1, 2, 0))
bias = np.reshape(bias, (h, w, num_filter//bn, bn))
bias = np.transpose(bias, (2, 0, 1, 3))
bias = np.reshape(bias, (num_filter//bn, bn, h, w))
bias = np.transpose(bias, (0, 2, 3, 1))
return bias
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
......@@ -86,6 +81,7 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
(dilation, dilation),
layout='NCHW%dc'%ic_block,
out_layout="NCHW%dc"%oc_block,
out_dtype=dtype)
......@@ -117,7 +113,7 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
check_device(device)
if __name__ == "__main__":
def test_conv2d_NCHWc():
# ResNet18 workloads
verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
......@@ -204,3 +200,6 @@ if __name__ == "__main__":
verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_NCHWc()
\ No newline at end of file
......@@ -203,6 +203,115 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
data = np.transpose(data, (0, 1, 3, 4, 2))
return data
def _transform_kernel(kernel, bn):
# channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
channel, channel_multiplier, kh, kw = kernel.shape
out_channel = channel * channel_multiplier
kernel = np.reshape(kernel, (out_channel//bn, bn, kh, kw))
kernel = np.transpose(kernel, (0, 2, 3, 1))
return kernel
def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
in_width = in_height
filter_channel = in_channel
filter_width = filter_height
stride_h = stride_w = stride
assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
padding_args = (pad_h, pad_w)
out_channel = filter_channel * channel_multiplier
# for testing functionality,
# we choose arbitrary block size that can divide the channel,
# regardless of the performance.
oc_block = 1
for bn in range(16, 0, -1):
if out_channel % bn == 0:
oc_block = bn
break
ic_block = 1
for bn in range(oc_block, 0, -1):
if in_channel % bn == 0:
ic_block = bn
break
# placeholder
Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input')
Filter = tvm.placeholder((out_channel//oc_block, filter_height, filter_width, oc_block), name='Filter')
in_layout = "NCHW%dc" % ic_block
out_layout = "NCHW%dc" % oc_block
dtype = 'float32'
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter,
(stride_h, stride_w),
padding_args,
(dilation, dilation),
in_layout,
out_layout, dtype)
# TODO: add scale_shift implement for NCHWc and add test here
Relu = topi.nn.relu(DepthwiseConv2d)
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Relu], device)
# Prepare pod type for test data closure
input_shape = (batch, in_channel, in_height, in_width)
filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width)
# Use memoize, pickle the test data for next time use.
@memoize("topi.tests.test_topi_depthwise_conv2d.NCHWc")
def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
input_np, filter_np, stride, padding)
relu_scipy = np.maximum(depthwise_conv2d_scipy, 0)
return (_transform_data(input_np, ic_block),
_transform_kernel(filter_np, oc_block),
_transform_data(depthwise_conv2d_scipy, oc_block),
_transform_data(relu_scipy, oc_block))
# Get the test data
(input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data()
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
dtype=DepthwiseConv2d.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# launch kernel 1 (depthwise_conv2d)
f1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
# launch kernel 2 (depthwise_conv2d + relu)
f2(input_tvm, filter_tvm, relu_tvm)
tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
# test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend.
for device in ["llvm"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_depthwise_conv2d():
# mobilenet workloads
......@@ -233,5 +342,12 @@ def test_depthwise_conv2d():
# disabled because it uses too large shared memory on cuda
# depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
# NCHW[x]c
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "VALID")
if __name__ == "__main__":
test_depthwise_conv2d()
......@@ -117,7 +117,15 @@ def tune_kernels(tasks,
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# converting conv2d tasks to conv2d_NCHWc tasks
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args,
op_name = tsk.workload[0]
if op_name == 'conv2d':
func_create = 'topi_x86_conv2d_NCHWc'
elif op_name == 'depthwise_conv2d_nchw':
func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
else:
raise ValueError("Tuning {} is not supported on x86".format(op_name))
task = autotvm.task.create(func_create, args=tsk.args,
target=target, template_key='direct')
task.workload = tsk.workload
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment