Unverified commit 0cfdecda by Yao Wang, committed by GitHub

Fix intel conv2d auto tune (#5200)

* Fix x86 conv2d and depthwise conv2d auto tuning

* Fix depthwise conv2d infer layout

* Use random data instead of empty data for autotvm

* Fix pylint

* Keep empty array for now for autotvm
parent b41f4e55
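
The gist of the change: while AutoTVM is tuning, the x86 conv2d_NCHWc and depthwise_conv2d_NCHWc compute definitions now create data/kernel placeholders that are already in the blocked NCHW[x]c / OIHW[x]i[x]o layouts, instead of packing a 4-D tensor and then skipping the packing stage with a "debug_skip_region" pragma in the schedule. Below is a minimal, self-contained sketch of that gating pattern; the helper name make_nchwc_inputs and the concrete sizes are illustrative, not part of the patch.

    import tvm
    from tvm import te, autotvm

    def make_nchwc_inputs(n, in_channel, ih, iw, num_filter, kh, kw,
                          ic_bn, oc_bn, dtype="float32"):
        # Hypothetical helper: placeholders already in the blocked layouts,
        # so the tuner never measures the layout-packing stage.
        dshape = (n, in_channel // ic_bn, ih, iw, ic_bn)                           # NCHW[ic_bn]c
        kshape = (num_filter // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn)  # OIHW[ic_bn]i[oc_bn]o
        return (te.placeholder(dshape, dtype, name="data"),
                te.placeholder(kshape, dtype, name="kernel"))

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # During tuning, measure only the convolution itself.
        data, kernel = make_nchwc_inputs(1, 64, 56, 56, 128, 3, 3, ic_bn=16, oc_bn=16)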
...@@ -185,7 +185,19 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
     # Pack data if raw 4-D data is provided.
     # This can only happen when autotuning.
     if len(data.shape) == 4:
-        data, kernel = _pack_data(cfg, data, kernel)
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # Directly use modified data layout placeholder.
+            dshape = (n, in_channel // cfg["tile_ic"].size[-1],
+                      ih, iw, cfg["tile_ic"].size[-1])
+            data = tvm.te.placeholder(dshape, data.dtype, name="data")
+            kshape = (num_filter // cfg["tile_oc"].size[-1],
+                      in_channel // cfg["tile_ic"].size[-1],
+                      kernel_height, kernel_width,
+                      cfg["tile_ic"].size[-1],
+                      cfg["tile_oc"].size[-1])
+            kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel")
+        else:
+            data, kernel = _pack_data(cfg, data, kernel)
     return nn.conv2d_NCHWc(data,
                            kernel,
...
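
For a concrete sense of the shapes this branch produces (the numbers here are an example, not taken from the PR): with batch 1, in_channel = 64, num_filter = 128, a 3x3 kernel on a 56x56 input, and tile_ic = tile_oc = 16, the placeholders are:

    dshape = (1, 64 // 16, 56, 56, 16)            # (1, 4, 56, 56, 16), layout NCHW16c
    kshape = (128 // 16, 64 // 16, 3, 3, 16, 16)  # (8, 4, 3, 3, 16, 16), layout OIHW16i16o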
...@@ -19,7 +19,6 @@
 from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
-from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..nn.pad import pad
...@@ -69,17 +68,12 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
             and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
         data_vec = data_vec.op.input_tensors[0]

-    if autotvm.GLOBAL_SCOPE.in_tuning:
-        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
-        # skip this part during tuning to make records accurate.
-        # this part will be folded during Relay fold_constant pass.
-        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
-        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
-    elif isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
             kernel_vec.name == 'kernel_vec':
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
...
...@@ -17,7 +17,6 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
 """Conv2D schedule on for Intel CPU"""
 import tvm
-from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..generic import conv2d as conv2d_generic
...@@ -91,17 +90,12 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
             and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
         data_vec = data_vec.op.input_tensors[0]

-    if autotvm.GLOBAL_SCOPE.in_tuning:
-        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
-        # skip this part during tuning to make records accurate.
-        # this part will be folded during Relay fold_constant pass.
-        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
-        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
-    elif isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
             kernel_vec.name == 'kernel_vec':
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
...
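
The same schedule-side edit lands in both copies of _schedule_conv_NCHWc above: the debug_skip_region pragmas, which only made sense when tuning fed 4-D inputs, are removed, and the pad stage now vectorizes the innermost channel-block axis before parallelizing the fused outer axes. A standalone sketch of that scheduling pattern follows; the elementwise stage and shapes are placeholders standing in for the pad compute, not the PR's code.

    import tvm
    from tvm import te

    A = te.placeholder((1, 4, 58, 58, 16), name="A")   # NCHW16c tensor, e.g. already padded
    B = te.compute(A.shape, lambda n, c, h, w, vc: A[n, c, h, w, vc] * 2.0, name="B")

    s = te.create_schedule(B.op)
    batch, ic_chunk, ih, iw, ic_block = s[B].op.axis
    s[B].vectorize(ic_block)                           # SIMD over the 16-wide channel block
    parallel_axis = s[B].fuse(batch, ic_chunk, ih)
    s[B].parallel(parallel_axis)
    print(tvm.lower(s, [A, B], simple_mode=True))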
...@@ -43,7 +43,6 @@ def _fallback_schedule(cfg, wkl):
     HPAD, WPAD = wkl.hpad, wkl.wpad
     HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
     out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

     oc_bn = 1
...@@ -148,10 +147,21 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
     # Pack data if raw 4-D data is provided.
     # This can only happen when autotuning.
     if len(data.shape) == 4:
-        data, kernel = _pack_data(cfg, data, kernel)
-    _, _, _, _, in_channel_block = get_const_tuple(data.shape)
-    out_channel_chunk, _, _, _, _, out_channel_block \
-        = get_const_tuple(kernel.shape)
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # Directly use modified data layout placeholder.
+            in_channel_block = cfg["tile_ic"].size[-1]
+            in_channel_chunk = in_channel // in_channel_block
+            out_channel_block = cfg["tile_oc"].size[-1]
+            out_channel_chunk = out_channel // out_channel_block
+            dshape = (batch, in_channel_chunk, in_height, in_width, in_channel_block)
+            data = tvm.te.placeholder(dshape, data.dtype, name="data")
+            kshape = (out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block)
+            kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel")
+        else:
+            data, kernel = _pack_data(cfg, data, kernel)
+            _, _, _, _, in_channel_block = get_const_tuple(data.shape)
+            out_channel_chunk, _, _, _, _, out_channel_block \
+                = get_const_tuple(kernel.shape)

     # padding stage
     DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0)
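
Note that the depthwise kernel placeholder keeps the two unit axes of the packed depthwise layout (the channel-multiplier chunk and block), so only the output channels are actually blocked. As an illustration (sizes are an example, not from the PR), with in_channel = out_channel = 64, a 3x3 filter, and tile_ic = tile_oc = 16:

    dshape = (1, 64 // 16, 112, 112, 16)   # (1, 4, 112, 112, 16), NCHW16c input
    kshape = (64 // 16, 1, 3, 3, 1, 16)    # (4, 1, 3, 3, 1, 16): (oc_chunk, 1, kh, kw, 1, oc_block)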
...@@ -207,16 +217,9 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
     if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
             and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
-        data_vec = data_vec.op.input_tensors[0]

-    if autotvm.GLOBAL_SCOPE.in_tuning:
-        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
-        # skip this part during tuning to make recrods accurate.
-        # this part will be folded during Relay fold_constant pass.
-        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
-        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")

     C, O = conv_out, output
     CC = s.cache_write(C, 'global')
...@@ -264,12 +267,12 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
 @depthwise_conv2d_infer_layout.register("cpu")
 def _depthwise_conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, dtype = workload
+    _, data, kernel, strides, padding, dilation, _, _, dtype = workload
     batch_size, in_channel, in_height, in_width = data[1]
     filter_channel, channel_multiplier, k_height, k_width = kernel[1]
     out_channel = filter_channel * channel_multiplier
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    out_height = (in_height + padding[0] + padding[2] - k_height) // strides[0] + 1
+    out_width = (in_width + padding[1] + padding[3] - k_width) // strides[1] + 1
     tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
     in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
     in_layout = "NCHW%dc" % tile_ic
...
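
The infer-layout fix tracks two workload changes visible in the diff: the depthwise workload now carries two extra fields before dtype, and padding is a 4-tuple (top, left, bottom, right) rather than a symmetric pair, so the output size must add the two opposite pads instead of doubling one. A quick arithmetic check with illustrative numbers:

    # 3x3 filter, stride 2, 224x224 input, asymmetric padding (top, left, bottom, right)
    in_height, k_height, stride = 224, 3, 2
    pad_top, pad_left, pad_down, pad_right = 0, 0, 1, 1

    new_out = (in_height + pad_top + pad_down - k_height) // stride + 1  # 112
    old_out = (in_height + 2 * pad_top - k_height) // stride + 1         # 111, off by one
    print(new_out, old_out)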