Commit 866d458c by Lianmin Zheng, committed by Yizhi Liu

[TOPI][AUTOTVM] Improve style (#2034)

* [TOPI] Improve the style of using autotvm

* fix
parent bc48811f
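
The diff replaces the old two-decorator pattern (conv2d.register(...) stacked on autotvm.task.dispatcher, plus hand-written *_arg_to_workload helpers) with the autotvm.register_topi_compute / autotvm.register_topi_schedule helpers. A minimal sketch of the two styles, written against a hypothetical 'mydevice' target key so it does not collide with the real registrations below:

# --- sketch: old vs. new registration style (hypothetical 'mydevice' key) ---
import tvm
from tvm import autotvm
from topi.nn import conv2d
from topi.generic import schedule_conv2d_nchw

# Old style (removed in this commit): a dispatcher converts the raw arguments
# into a workload tuple, and each template is registered on it separately.
#
#   @conv2d.register('mydevice')
#   @autotvm.task.dispatcher
#   def conv2d_mydevice(data, kernel, strides, padding, layout, out_dtype):
#       return ('conv2d',) + autotvm.task.args_to_workload(
#           [data, kernel, strides, padding, layout, out_dtype])
#
#   @conv2d_mydevice.register(['direct'])
#   def decl_direct(cfg, data, kernel, strides, padding, layout, out_dtype):
#       ...

# New style (added in this commit): one decorator registers the TOPI compute
# for a target and a list of template keys; cfg is injected by the dispatcher
# and the workload is derived from the call arguments automatically.
@autotvm.register_topi_compute(conv2d, 'mydevice', ['direct'])
def conv2d_mydevice(cfg, data, kernel, strides, padding, layout, out_dtype):
    """Template body goes here; cfg carries the tuning knobs."""

@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mydevice', ['direct'])
def schedule_conv2d_mydevice(cfg, outs):
    """Schedule body goes here."""
# ---------------------------------------------------------------------------
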
@@ -12,34 +12,40 @@ from ..util import traverse_inline, get_const_tuple, const_matrix
from ..nn import pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
from ..nn.util import get_const_int, get_pad_tuple
def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
"""convert argument to workload"""
if len(kernel.shape) == 4:
raw_kernel = kernel
else: # the input kernel is transformed by alter_op_layout
shape = get_const_tuple(kernel.shape)
raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]),
dtype=kernel.dtype)
return ('conv2d', ) + autotvm.task.args_to_workload(
[data, raw_kernel, strides, padding, layout, out_dtype])
@conv2d.register('arm_cpu')
@autotvm.task.dispatcher
def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype):
"""TOPI compute callback. Mark this function as a dispatcher, so
this template can assign config according to workload
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, layout, out_dtype):
"""TOPI compute callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
filter_width, num_filter_block]
strides : list of two ints
[stride_height, stride_width]
padding : list of two ints
[pad_height, pad_width]
layout : str
layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
workload: Tuple
Dispatcher will use this workload to query corresponding config.
Then use cfg.template_key to call a registered template.
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
@conv2d_arm_cpu.register(['direct'])
def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
"""spatial packing template"""
return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=2)
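
For context (not part of the commit): as the removed docstring above describes, the dispatcher hashes the call arguments into a workload, queries the active DispatchContext for a config, and then calls the template registered under cfg.template_key. A rough sketch of that lookup, with made-up shapes:

# --- sketch: how a config is looked up for a workload (made-up shapes) ------
import tvm
from tvm import autotvm
from topi.nn import conv2d

data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

target = tvm.target.create('llvm -device=arm_cpu')
workload = autotvm.task.args_to_workload(
    [data, kernel, (1, 1), (1, 1), 'NCHW', 'float32'], conv2d)

# With no tuning log loaded, the default context returns a fallback config
# (cfg.is_fallback is True). Wrapping the query in
# "with autotvm.apply_history_best('tuning.log'):" would return the best
# measured config for this workload instead.
cfg = autotvm.DispatchContext.current.query(target, workload)
print(cfg.is_fallback)
# ---------------------------------------------------------------------------
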
@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd'])
@@ -93,8 +99,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
assert layout == "NCHW", "Only support NCHW"
# create workload according to raw arguments
wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
out_dtype = out_dtype or data.dtype
N, CI, IH, IW = get_const_tuple(data.shape)
if len(kernel.shape) == 4:
@@ -177,8 +181,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_conv2d_output',
attrs={'workload': wkl})
name='output_unpack', tag='spatial_conv2d_output')
return output
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
@@ -238,16 +241,13 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
return s
@conv2d_arm_cpu.register('winograd')
def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
tile_size = 4
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
# create workload according to raw arguments
wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout,
out_dtype, tile_size)
N, CI, IH, IW = get_const_tuple(data.shape)
if len(kernel.shape) == 4:
pre_computed = False
@@ -368,8 +368,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
# unpack output
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
name='output', tag='winograd_conv2d_output',
attrs={'workload': wkl})
name='output', tag='winograd_conv2d_output')
# we have to manually assign effective GFLOP for winograd
cfg.add_flop(2 * N * K * H * W * KH * KW * C)
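
The flop count registered here is the one an equivalent direct convolution would have; winograd itself performs fewer multiplications, so using the direct count keeps the GFLOPS numbers the tuner reports comparable across templates. For a made-up 1x64x56x56 input with a 64x64x3x3 kernel and a same-size output:

# --- illustration only: "effective" FLOPs for made-up shapes ----------------
N, K, H, W, KH, KW, C = 1, 64, 56, 56, 3, 3, 64
effective_flop = 2 * N * K * H * W * KH * KW * C
print(effective_flop / 1e9)   # ~0.23 GFLOP, the count a direct conv2d would report
# ---------------------------------------------------------------------------
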
@@ -458,36 +457,11 @@ def _schedule_winograd(cfg, s, output, last):
s[output].compute_inline()
def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype, tile_size):
"""convert argument to workload"""
K = 3
shape = get_const_tuple(kernel.shape)
alpha = tile_size + K - 1
if len(kernel.shape) == 4:
assert shape[2:] == (K, K)
CO, CI = shape[:2]
else:
assert shape[:2] == (alpha, alpha)
CO, CI, VCO = shape[2:]
CO *= VCO
raw_kernel = tvm.placeholder((CO, CI, K, K), dtype=kernel.dtype)
return ('conv2d', ) + autotvm.task.args_to_workload(
[data, raw_kernel, strides, padding, layout, out_dtype])
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@conv2d_winograd_without_weight_transform.register(['arm_cpu'])
@autotvm.task.dispatcher
def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype,
tile_size)
@winograd_ww_config_dispatcher_.register(['winograd'])
def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
tile_size)
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'arm_cpu', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
"""TOPI compute callback"""
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
@@ -514,8 +488,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
new_attrs = {k: attrs[k] for k in attrs.keys()}
assert attrs.get_int_tuple("dilation") == (1, 1), "Does not support dilation " \
"when alter_op_layout is enabled"
dilation = attrs.get_int_tuple("dilation")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
groups = attrs.get_int('groups')
@@ -523,21 +496,38 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
out_dtype = attrs["out_dtype"]
out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype
if groups == 1:
if layout != 'NCHW' or groups != 1 or dilation != (1, 1):
return None
data, kernel = tinfos[0:2]
N, CI, H, W = get_const_tuple(data.shape)
CO, _, KH, KW = get_const_tuple(kernel.shape)
# query config of this workload
workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding,
layout, out_dtype)
cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
autotvm.task.clear_fallback_cache(target, workload)
return None
if cfg.template_key == 'direct': # packing weight tensor
new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
if cfg.template_key == 'direct': # pack weight tensor
VC = cfg['tile_co'].size[-1]
new_attrs['kernel_layout'] = 'OIHW%do' % VC
# Store the same config for the altered operator (workload)
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)
return sym.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd
if "-device=arm_cpu" in tvm.target.current_target().options:
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
else:
@@ -545,16 +535,21 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
CO, CI, KH, KW = get_const_tuple(tinfos[1].shape)
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
weight = sym.reshape(weight,
shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size
return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
# do nothing for depthwise convolution
return None
# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size - 1, CO // VC, CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, new_attrs['layout'], out_dtype, tile_size],
conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
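
In both branches of _alter_conv2d_layout_arm, the config found for the original workload is stored again (dispatch_ctx.update) under the workload of the altered operator, because the packed or pre-transformed kernel has a different shape and therefore a different workload. Two small checks of that reasoning, with made-up shapes, VC = 8 and tile_size = 4:

# --- sketch: why the config is re-registered for the altered workload -------
import numpy as np
import tvm
from tvm import autotvm
from topi.nn import conv2d

VC, tile_size = 8, 4
data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')                  # OIHW
packed = tvm.placeholder((64 // VC, 64, 3, 3, VC), name='kernel_vec')    # OIHW8o

wkl = autotvm.task.args_to_workload(
    [data, kernel, (1, 1), (1, 1), 'NCHW', 'float32'], conv2d)
new_wkl = autotvm.task.args_to_workload(
    [data, packed, (1, 1), (1, 1), 'NCHW', 'float32'], conv2d)
# Different kernel shapes give different workload tuples; without the
# dispatch_ctx.update(...) calls above, the altered graph would fall back to
# a default config at compile time.
assert wkl != new_wkl

# In the winograd branch, the reshape + transpose turn the transformed weight
# (shape (alpha, alpha, CO, CI), as implied by the reshape above) into the
# (alpha, alpha, CO//VC, CI, VC) layout that new_weight declares, with
# alpha = KH + tile_size - 1 = 6 for a 3x3 kernel.
alpha, CO, CI = tile_size + 3 - 1, 64, 64
w = np.zeros((alpha, alpha, CO, CI))
w = w.reshape(alpha, alpha, CO // VC, VC, CI).transpose(0, 1, 2, 4, 3)
assert w.shape == (alpha, alpha, CO // VC, CI, VC)
# ---------------------------------------------------------------------------
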
@@ -12,27 +12,43 @@ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
get_pad_tuple, pad, conv2d_alter_layout
# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d import _conv_arg_to_workload, _decl_spatial_pack,\
_winograd_conv_arg_to_workload, _alter_conv2d_layout_arm
from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
@conv2d.register('mali')
@autotvm.task.dispatcher
def conv2d_mali(data, kernel, strides, padding, layout, out_dtype):
"""TOPI compute callback. Mark this function as a dispatcher, so
this template can assign config according to workload
@autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
"""TOPI compute callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
filter_width, num_filter_block]
strides : list of two ints
[stride_height, stride_width]
padding : list of two ints
[pad_height, pad_width]
layout : str
layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
workload: Tuple
Dispatcher will use this workload to query corresponding config.
Then use cfg.template_key to call a registered template.
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
@conv2d_mali.register(['direct'])
def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
"""spatial packing template"""
return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=3)
@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
@@ -158,8 +174,8 @@ def _pick_tile_size(data, kernel):
else:
return 2
@conv2d_mali.register('winograd')
def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
@autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
def conv2d_mali_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
tile_size = _pick_tile_size(data, kernel)
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
@@ -305,9 +321,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
# the following term is used to make the padding effective,
# otherwise the padding will be eliminated by bound inference
+ tvm.const(0, out_dtype) * M[alpha-1][alpha-1][CO-1][P_round-1],
name='output', tag='winograd_conv2d_output',
attrs={'workload': _winograd_conv_arg_to_workload(
data, kernel, strides, padding, layout, out_dtype, tile_size)})
name='output', tag='winograd_conv2d_output')
# we have to manually assign effective GFLOP for winograd
cfg.add_flop(2 * N * CO * H * W * KH * KW * CI)
@@ -410,28 +424,14 @@ def _schedule_winograd(cfg, s, op):
s[Y].compute_at(s[output], tt)
@conv2d_alter_layout.register(["mali"])
def _alter_conv2d_layout(attrs, inputs, tinfos):
try:
return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
except KeyError: # to filter out fallback opencl templates
return None
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@conv2d_winograd_without_weight_transform.register(['mali'])
@autotvm.task.dispatcher
def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype,
tile_size)
@winograd_ww_config_dispatcher_.register(['winograd'])
def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
tile_size)
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
"""TOPI compute callback"""
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
@autotvm.task.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
'mali', ['winograd'])
def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
"""TOPI schedule callback"""
@@ -445,6 +445,15 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
return s
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["mali"])
def _alter_conv2d_layout(attrs, inputs, tinfos):
try:
return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
except KeyError: # to filter out fallback opencl templates
return None
##### SCHEDULE UTILITIES #####
def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
""" tile and bind to GPU threads """
@@ -85,17 +85,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout.
@@ -3,7 +3,7 @@
import tvm
from tvm import autotvm
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.task import register, get_config
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
@@ -145,7 +145,7 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtyp
return unpack
@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
def schedule_conv2d(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
@@ -248,7 +248,7 @@ def schedule_conv2d_nhwc(outs):
# We define the schedule template in this function instead of the
# declaration function, since the actual input arguments need
# to be altered by the selected schedule.
@register("topi_x86_conv2d_NCHWc")
@autotvm.task.register("topi_x86_conv2d_NCHWc")
def _topi_nn_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args)
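
The NCHWc template above is registered under a plain task name instead of through register_topi_compute because, as the comment notes, the selected schedule alters the actual input arguments (the layout). Arguments reach such a template in serialized form; a rough sketch of the round trip, assuming serialize_args sits next to deserialize_args in tvm.autotvm.task.nnvm_integration:

# --- sketch: serialized-argument round trip for a name-registered task ------
import tvm
from topi.util import get_const_tuple
from tvm.autotvm.task.nnvm_integration import serialize_args, deserialize_args

data = tvm.placeholder((1, 3, 224, 224), name='data', dtype='float32')
kernel = tvm.placeholder((64, 3, 7, 7), name='kernel', dtype='float32')

# Tensors are flattened into plain tuples so the arguments can be stored with
# the task; deserialize_args rebuilds placeholders from them, which is why
# _topi_nn_conv2d_NCHWc starts with a deserialize_args call.
args = serialize_args([data, kernel, (2, 2), (3, 3), 'NCHW', 'float32'])
data2, kernel2, strides, padding, layout, dtype = deserialize_args(args)

assert get_const_tuple(data2.shape) == (1, 3, 224, 224)
assert strides == (2, 2) and layout == 'NCHW'
# ---------------------------------------------------------------------------
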
@@ -311,7 +311,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
# Store the same config for the altered operator (workload)
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
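
For completeness, nothing changes in how these registrations are consumed; the usual flow from the autotvm tutorials of this era is roughly the sketch below (the network, shapes, and log file are placeholders, and the log is assumed to come from an earlier tuning run):

# --- end-to-end usage sketch (placeholders; assumes a prior tuning run) -----
import tvm
import nnvm
import nnvm.testing
import nnvm.compiler
from tvm import autotvm

target = tvm.target.create('llvm -device=arm_cpu')
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)

# Task extraction calls the conv2d computes registered with
# register_topi_compute and records one task per distinct workload.
tasks = autotvm.task.extract_from_graph(
    net, target=target, shape={'data': (1, 3, 224, 224)}, dtype='float32',
    symbols=(nnvm.sym.conv2d,))

# ... tune `tasks` here, logging the results to 'conv2d.log' ...

# At build time the dispatcher looks up the tuned configs by workload, and
# alter_op_layout re-registers them for the packed/transformed kernels as
# shown in the diff above.
with autotvm.apply_history_best('conv2d.log'):
    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            net, target=target, shape={'data': (1, 3, 224, 224)}, params=params)
# ---------------------------------------------------------------------------
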