Commit 0708c48d by eqy Committed by Tianqi Chen

[WIP][AUTOTVM][TOPI] Port x86 NCHWc to AutoTVM for Task Extraction (#2664)

[AUTOTVM][TOPI] Port x86 NCHWc to AutoTVM for Task Extraction
parent 7cd986db
......@@ -50,7 +50,7 @@ def extract_from_program(func, params, ops, target, target_host=None):
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
......
......@@ -67,6 +67,7 @@ class TaskExtractEnv:
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
......@@ -80,6 +81,7 @@ class TaskExtractEnv:
topi.generic.schedule_depthwise_conv2d_nhwc],
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
......@@ -108,7 +110,6 @@ class TaskExtractEnv:
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)
......@@ -205,6 +206,15 @@ class TaskExtractEnv:
s = topi.generic.schedule_deformable_conv2d_nchw([C])
return s, [A, Offset, W, C]
@register("topi_nn_conv2d_NCHWc")
def _topi_nn_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.conv2d_NCHWc(*args, **kwargs)
s = topi.generic.schedule_conv2d_NCHWc([C])
return s, [A, W, C]
def reset(self, wanted_topi_funcs):
"""Reset task collections
......
......@@ -329,7 +329,68 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
"""
# search platform specific declaration first
# default declaration
raise ValueError("missing register for topi.nn.conv2d_NCHWc")
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding,
(dilated_kernel_h,
dilated_kernel_w))
HPAD = pad_top + pad_down
WPAD = pad_left + pad_right
HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
assert (dh, dw) == (1, 1), "Does not support dilation"
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8':
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
else:
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# output shape
out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
# DOPAD
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
if data.dtype == 'uint8':
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner]
.astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
# else: fp implementation
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
ic%ic_bn].astype(out_dtype) *
kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
axis=[ic, kh, kw]),
name='conv2d_NCHWc', tag="conv2d_NCHWc")
def conv2d_winograd_weight_transform(kernel, tile_size):
......
......@@ -2,6 +2,7 @@
"""Conv2D schedule on x86"""
import logging
import re
import tvm
from tvm import autotvm
......@@ -41,9 +42,22 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
pat = re.compile(r'NCHW.+(\d+)c')
if layout == 'NCHW':
n, ic, h, w = dshape
oc, _, kh, kw = kshape
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
if data.dtype == 'uint8':
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s
else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk
assert ic_bn == k_ic_bn
ic = ic_chunk*ic_bn
oc = oc_chunk*oc_bn
else:
raise ValueError("Not support this layout {} with "
"schedule template.".format(layout))
......@@ -258,7 +272,14 @@ def schedule_conv2d_nhwc(outs):
@autotvm.task.register("topi_x86_conv2d_NCHWc")
def _topi_nn_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
data, kernel, strides, padding, dilation, origin_layout, dtype = deserialize_args(args)
args = deserialize_args(args)
if len(args) == 7:
data, kernel, strides, padding, dilation, origin_layout, dtype = args
else:
assert len(args) == 8
data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
raw_data_shape = get_const_tuple(data.shape)
raw_kernel_shape = get_const_tuple(kernel.shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment