Commit 4d875d1f by Leyuan Wang Committed by Yao Wang

[TOPI] Add valid auto tvm for Intel Graphics (#4078)

* add valid autotune

* fix pylint
parent f2abd9f6
......@@ -23,6 +23,8 @@ import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
......@@ -153,6 +155,38 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
s[tensor].bind(xi, thread_x)
return xi, thread_z, thread_y, thread_x
# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc")
def __topi_nn_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args)
raw_data_shape = get_const_tuple(data.shape)
raw_kernel_shape = get_const_tuple(kernel.shape)
# get config here
cfg = get_config()
_create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout)
cfg.add_flop(1)
# change shape with the value in config
ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
raw_data_shape[2], raw_data_shape[3], ic_bn)
new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
C = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation, dtype)
s = _schedule_conv2d_NCHWc(cfg, [C])
return s, [new_data, new_kernel, C]
@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment