Commit 866d458c by Lianmin Zheng, committed by Yizhi Liu

[TOPI][AUTOTVM] Improve style (#2034)

* [TOPI] Improve the style of using autotvm

* fix
parent bc48811f
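
The pattern change, condensed (a sketch only; argument lists are abbreviated, and the names are taken from the diff below):

```python
# Before: a dispatcher plus separately registered templates; the workload
# had to be extracted by hand with helpers like _conv_arg_to_workload.
#
#   @conv2d.register('mali')
#   @autotvm.task.dispatcher
#   def conv2d_mali(data, kernel, *args):
#       return _conv_arg_to_workload(data, kernel, *args)
#
#   @conv2d_mali.register(['direct'])
#   def decl_spatial_pack(cfg, data, kernel, *args):
#       ...
#
# After: a single decorator registers the compute for a target and template
# key, and the callback receives the config (`cfg`) directly.
#
#   @autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
#   def conv2d_mali(cfg, data, kernel, *args):
#       ...
```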
@@ -12,27 +12,43 @@ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
     get_pad_tuple, pad, conv2d_alter_layout
 # reuse some compute declarations from ARM CPU
-from ..arm_cpu.conv2d import _conv_arg_to_workload, _decl_spatial_pack,\
-    _winograd_conv_arg_to_workload, _alter_conv2d_layout_arm
+from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm

-@conv2d.register('mali')
-@autotvm.task.dispatcher
-def conv2d_mali(data, kernel, strides, padding, layout, out_dtype):
-    """TOPI compute callback. Mark this function as a dispatcher, so
-    this template can assign config according to workload
+@autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
+def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
+    """TOPI compute callback for conv2d
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
+        pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
+        filter_width, num_filter_block]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    layout : str
+        layout of data
+
+    out_dtype: str
+        The output type. This is used for mixed precision.

     Returns
     -------
-    workload: Tuple
-        Dispatcher will use this workload to query corresponding config.
-        Then use cfg.template_key to call a registered template.
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)

-@conv2d_mali.register(['direct'])
-def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
-    """spatial packing template"""
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=3)

+@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
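
The `cfg` argument threaded through the new callbacks is where a template declares its tuning knobs. A minimal sketch of the pattern (the knob names and the 3-way split are illustrative, loosely modeled on `_decl_spatial_pack` with `num_tile=3`; this is not the actual mali template):

```python
from tvm import autotvm

def _define_knobs(cfg, co, oh, ow):
    # Wrap loop extents as tunable virtual axes, then expose tiling knobs
    # that the paired schedule callback reads back via cfg['tile_co'] etc.
    co_ax, oh_ax, ow_ax = cfg.axis(co), cfg.axis(oh), cfg.axis(ow)
    cfg.define_split('tile_co', co_ax, num_outputs=3)
    cfg.define_split('tile_oh', oh_ax, num_outputs=3)
    cfg.define_split('tile_ow', ow_ax, num_outputs=3)
```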
@@ -158,8 +174,8 @@ def _pick_tile_size(data, kernel):
     else:
         return 2

-@conv2d_mali.register('winograd')
-def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+@autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
+def conv2d_mali_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
     tile_size = _pick_tile_size(data, kernel)
     return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
@@ -305,9 +321,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
         # the following term is used to make the padding effective,
         # otherwise the padding will be eliminated by bound inference
         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][CO-1][P_round-1],
-        name='output', tag='winograd_conv2d_output',
-        attrs={'workload': _winograd_conv_arg_to_workload(
-            data, kernel, strides, padding, layout, out_dtype, tile_size)})
+        name='output', tag='winograd_conv2d_output')

     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * CO * H * W * KH * KW * CI)
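
The manual `add_flop` is needed because winograd performs fewer raw multiplications than the direct convolution it replaces; the formula charges the template with the equivalent direct-conv cost so tuning reports stay comparable across templates. A worked instance with hypothetical sizes:

```python
# Hypothetical workload: batch 1, 32 input / 64 output channels,
# 56x56 output, 3x3 kernel.
N, CO, H, W, KH, KW, CI = 1, 64, 56, 56, 3, 3, 32
flop = 2 * N * CO * H * W * KH * KW * CI
print(flop)  # 115605504, i.e. ~0.116 GFLOP per inference
```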
@@ -410,29 +424,15 @@ def _schedule_winograd(cfg, s, op):
         s[Y].compute_at(s[output], tt)

-@conv2d_alter_layout.register(["mali"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
-    try:
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
-    except KeyError:  # to filter out fallback opencl templates
-        return None

 ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@conv2d_winograd_without_weight_transform.register(['mali'])
-@autotvm.task.dispatcher
-def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype,
-                                          tile_size)

-@winograd_ww_config_dispatcher_.register(['winograd'])
-def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
-                          tile_size)
+@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
+def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+    """TOPI compute callback"""
+    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)

-@autotvm.task.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                     'mali', ['winograd'])
+@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
+                                'mali', ['winograd'])
 def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -445,6 +445,15 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
     return s

+##### REGISTER ALTER OP LAYOUT #####
+@conv2d_alter_layout.register(["mali"])
+def _alter_conv2d_layout(attrs, inputs, tinfos):
+    try:
+        return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
+    except KeyError:  # to filter out fallback opencl templates
+        return None

 ##### SCHEDULE UTILITIES #####
 def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
     """ tile and bind to GPU threads """
...
@@ -85,17 +85,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)

-@tvm.target.generic_func
-def _get_schedule(wkl):
-    # pylint: disable=unreachable
-    """ Get the platform specific schedule. """
-    target = tvm.target.current_target()
-    raise RuntimeError(
-        "No schedule for current target:{}".format(target))
-    # This return has no use, merely to suppress pylint warning
-    return wkl

 def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """Convolution operator in NCHW layout.
...
@@ -3,7 +3,7 @@
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.nnvm_integration import deserialize_args
-from tvm.autotvm.task import register, get_config
+from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple
@@ -145,7 +145,7 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtyp
     return unpack

-@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
+@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
 def schedule_conv2d(cfg, outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -248,7 +248,7 @@ def schedule_conv2d_nhwc(outs):
 # We define the schedule template in this function instead of the
 # declaration function, since the actual input arguments need
 # to be altered by the schedule selected.
-@register("topi_x86_conv2d_NCHWc")
+@autotvm.task.register("topi_x86_conv2d_NCHWc")
 def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     assert not kwargs, "Do not support kwargs in template function call"
     data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args)
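
Templates registered by name this way are instantiated through autotvm's task interface; tensor arguments travel in the serialized `('TENSOR', shape, dtype)` form that `deserialize_args` unpacks. A hedged usage sketch (the shapes, strides, and target are made up for illustration):

```python
from tvm import autotvm

# Hypothetical 1x3x224x224 input with 64 3x3 filters, stride 1, padding 1.
data = ('TENSOR', (1, 3, 224, 224), 'float32')
kernel = ('TENSOR', (64, 3, 3, 3), 'float32')
task = autotvm.task.create('topi_x86_conv2d_NCHWc',
                           args=(data, kernel, 1, 1, 'NCHW', 'float32'),
                           target='llvm')
print(task.config_space)  # the knobs declared by the template above
```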
@@ -311,7 +311,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
     # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
     new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

-    # Store altered operator's config
+    # Store the same config for the altered operator (workload)
     new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
                                dtype=data.dtype)
     new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
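
The placeholders built here describe the operator after layout alteration; the idea is to register the already-tuned config under that new workload so dispatch finds it at compile time. A sketch of the rest of the step, assuming the era's `args_to_workload` / dispatch-context APIs (names may differ slightly in the real file):

```python
from tvm import autotvm

def _store_config_for_altered_op(dispatch_ctx, target, cfg,
                                 new_data, new_kernel, strides, padding,
                                 layout, out_layout, out_dtype, conv2d_nchwc):
    # Derive the workload of the altered (NCHWc) operator from the new
    # placeholders, then file the already-tuned config under it so the
    # dispatcher can find it when the altered graph is compiled.
    new_workload = autotvm.task.args_to_workload(
        [new_data, new_kernel, strides, padding, layout, out_layout,
         out_dtype], conv2d_nchwc)
    dispatch_ctx.update(target, new_workload, cfg)
```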
...