Commit b9e8826f by Yizhi Liu, committed by Tianqi Chen

Refine porting x86 NCHWc conv to AutoTVM (#1993)

parent e286e637
......@@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "dilation is not supported for now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
strides, padding, layout, out_layout)
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
......@@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding,
layout, out_layout, outs)
return topi.generic.schedule_conv2d_NCHWc(outs)
else:
raise ValueError("not support group number > 1 for now")
......
......@@ -60,6 +60,53 @@ class DispatchContext(object):
ret = self._old_ctx.query(target, workload)
return ret
def update(self, target, workload, cfg):
"""
Update context with a specific config.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
cfg : ConfigSpace
The specific configuration.
Note
----
This interface is for cases when TVM decides to replace an operator in the graph.
For example, the `AlterOpLayout` pass (enabled when `opt_level = 3`) replaces a `NCHW`
convolution with a `NCHW[x]c` implementation on x86 CPUs.
Thus in TOPI we first query the schedule using the original `NCHW` workload,
then update the dispatcher with the new `NCHW[x]c` workload,
so that later the `NCHW[x]c` convolution can fetch its schedule from the dispatcher
using its own workload directly.
.. code-block:: python
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
workload = get_conv2d_workload(...)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
config = dispatch_ctx.query(target, workload)
# Get conv2d_NCHWc workload from config
# new_workload = ...
# new_inputs = ...
# new_attrs = ...
# Store altered operator's config
dispatch_ctx.update(target, new_workload, config)
return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
share the same schedule parameters.
One can construct a new `ConfigEntity` if this is not the case.
"""
raise NotImplementedError()
def _query_inside(self, target, workload):
"""
Query the context to get the specific config for a template.
......@@ -179,6 +226,11 @@ class ApplyConfig(DispatchContext):
self.workload = workload
return self._config
def update(self, target, workload, cfg):
"""Override update"""
self.workload = workload
self._config = cfg
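`ApplyConfig` always answers queries with its single stored configuration; the override above simply lets `AlterOpLayout`-style code push a replacement back through the same interface. A minimal usage sketch (`conv2d_template` is a hypothetical autotvm template):

    with ApplyConfig(cfg):
        # every query inside this scope returns `cfg`
        s, arg_bufs = conv2d_template(N, CI, H, W, CO, KH, KW)  # hypothetical template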
class ApplyHistoryBest(DispatchContext):
"""
......@@ -197,6 +249,7 @@ class ApplyHistoryBest(DispatchContext):
self.best_by_targetkey = {}
self.best_by_model = {}
self._best_user_defined = {}
if records:
self.load(records)
......@@ -264,17 +317,32 @@ class ApplyHistoryBest(DispatchContext):
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model:
return self.best_by_model[key][0].config
# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config
return None
def update(self, target, workload, cfg):
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
self._best_user_defined[key] = cfg
for k in target.keys:
key = (k, workload)
self._best_user_defined[key] = cfg
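Since `_query_inside` checks `_best_user_defined` before the records loaded from disk, an entry stored via `update` takes precedence over the tuning log. A minimal sketch, assuming `from tvm import autotvm`, a hypothetical log file `conv2d_x86.log`, and a previously constructed `workload` tuple and `cfg`:

    ctx = autotvm.apply_history_best("conv2d_x86.log")  # hypothetical log file
    ctx.update(target, workload, cfg)  # overrides whatever the log recorded
    with ctx:
        assert ctx.query(target, workload) is cfg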
class FallbackContext(DispatchContext):
"""
......@@ -324,6 +392,10 @@ class FallbackContext(DispatchContext):
if key in self.memory:
del self.memory[key]
def update(self, target, workload, cfg):
key = (str(target), workload)
self.memory[key] = cfg
DispatchContext.current = FallbackContext()
def clear_fallback_cache(target, workload):
......@@ -391,37 +463,14 @@ class ApplyGraphBest(DispatchContext):
cfg : ConfigSpace
The specific configuration.
"""
cfg = self._records[self._counter][0].config
self._counter += 1
return cfg
def query_global_dict(self, key):
"""
Query the context to get config from global
config dictionary.
Parameters
----------
key : str
Key to query the config.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
if self._counter < len(self._records):
cfg = self._records[self._counter][0].config
self._counter += 1
self.update(target, workload, cfg)
return cfg
key = (str(target), workload)
return self._global_cfg_dict[key]
def update_global_dict(self, key, val):
"""
Update the global config dictionary.
Parameters
----------
key : str
Key of config.
val : ConfigSpace
Value of config.
"""
self._global_cfg_dict[key] = val
def update(self, target, workload, cfg):
key = (str(target), workload)
self._global_cfg_dict[key] = cfg
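Note that, unlike the counter-based `_query_inside`, `update` keys the config on `(target, workload)`; once the sequential records are exhausted, `_query_inside` falls back to this dictionary, so a layout-altered operator can still find its config.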
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate,too-many-lines
"""
Template configuration space.
......@@ -996,5 +996,17 @@ class FallbackConfigEntity(ConfigSpace):
if not isinstance(self.space_map[knob_name], SplitSpace):
self._entity_map[knob_name] = best_match_cfg[knob_name]
def __setitem__(self, name, entity):
"""set the entity(knob) of by name
Parameters
----------
name: str
name of the entity
entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
value of the entity
"""
self._entity_map[name] = entity
def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
......@@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None):
return ret
def args_to_workload(x):
def args_to_workload(x, topi_compute_func=None):
"""Convert argument list to hashable workload tuple.
This function will convert lists to tuples, tvm nodes to python values, and
flatten tvm.tensor.Tensor into a hashable tuple of shape and dtype
......@@ -191,6 +191,8 @@ def args_to_workload(x):
----------
x: primitive hashable types or tensor.Tensor
The original value
topi_compute_func: topi compute function
The function name will be added as the first element of the workload tuple
Returns
-------
......@@ -198,18 +200,19 @@ def args_to_workload(x):
The hashable value
"""
if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, )
workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x])
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
return x
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
workload = x.value
elif x is None:
return 0
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider using '
'primitive types only.' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
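For example (a sketch, assuming `tvm` and `topi` are imported; the exact name string depends on what `get_func_name` returns for the registered compute function), a conv2d argument list is flattened into a single hashable tuple:

    data = tvm.placeholder((1, 64, 56, 56), name="data")
    kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")
    wkl = args_to_workload([data, kernel, (1, 1), (1, 1), "NCHW", "float32"],
                           topi.nn.conv2d)
    # -> ('conv2d', (1, 64, 56, 56, 'float32'), (64, 64, 3, 3, 'float32'),
    #     (1, 1), (1, 1), 'NCHW', 'float32')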
def template(func):
"""
......
# pylint: disable=unused-variable,invalid-name
# pylint: disable=unused-variable,invalid-name,unused-argument
"""
Decorators for registering tunable templates to TOPI.
......@@ -13,7 +13,6 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
from ... import _api_internal, tensor
from ..util import get_func_name
from .task import args_to_workload, dispatcher
......@@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
--------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
fname = get_func_name(topi_compute)
def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets:
......@@ -68,7 +65,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
def config_dispatcher(*args, **kwargs):
"""override topi call as a config dispatcher"""
assert not kwargs, "kwargs are not supported in template function calls"
return (fname, ) + args_to_workload(args)
return args_to_workload(args, topi_compute)
_REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
......@@ -88,7 +85,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
attrs = {}
for k, v in node.op.attrs.items():
attrs[k] = v
attrs['workload'] = (fname, ) + args_to_workload(args)
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
......@@ -153,7 +150,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)
if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
@topi_schedule.register(target_key)
@dispatcher
def config_dispatcher(outs):
def config_dispatcher(outs, *args, **kwargs):
"""override topi call as a workload dispatcher"""
def traverse(tensors):
"""traverse all ops to find attached workload"""
......@@ -179,11 +176,11 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]
@config_dispatcher.register(template_keys)
def template_call(cfg, outs):
def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func"""
if f == topi_schedule.fdefault:
return f(outs)
return f(cfg, outs)
return f(outs, *args, **kwargs)
return f(cfg, outs, *args, **kwargs)
return f
......
......@@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs):
@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_NCHW[x]c
Parameters
----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
Returns
-------
......
......@@ -73,30 +73,11 @@ def schedule_conv2d_nhwc(outs):
@generic.schedule_conv2d_NCHWc.register(["hls"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_NCHW[x]c
Parameters
----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
......
......@@ -61,8 +61,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
@conv2d_NCHWc.register(["intel_graphics"])
def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,\
out_layout, out_dtype='float32'):
def _decl_conv2d(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
"""Conv2D operator for Intel Graphics backend.
Parameters
......@@ -101,7 +100,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,
return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype)
@generic.schedule_conv2d_NCHWc.register(["intel_graphics"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_layout, outs):
def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_nchw for Intel Graphics
Parameters
......
......@@ -84,32 +84,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_workload_int8(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Inputs with different data types are not supported now. " \
"{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule for conv2d_alter_layout. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
@tvm.target.generic_func
def _get_schedule(wkl):
......@@ -122,28 +96,6 @@ def _get_schedule(wkl):
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout.
......@@ -302,8 +254,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
padding, layout, out_layout, out_dtype='float32'):
def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
"""Conv2D operator for nChw[x]c layout.
Parameters
......@@ -316,12 +267,6 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
[num_filter_chunk, in_channel_chunk, filter_height, filter_width,
in_channel_block, num_filter_block]
num_filter : int
number of filters, i.e., output channel size
kernel_size : tuple of two ints
[kernel_height, kernel_width]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
......
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""1x1 Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm.autotvm.task import ConfigEntity
import topi
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..nn.pad import pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
def _get_default_schedule(wkl, simd_width):
def _fallback_schedule(cfg, wkl, simd_width):
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
......@@ -37,45 +31,11 @@ def _get_default_schedule(wkl, simd_width):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor)
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _fallback_schedule(wkl, simd_width):
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_height = (height + 2 * HPAD - hkernel) // HSTR + 1
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_oh", "ot", oh_factor],
["tile_ow", "sp", [out_width // ow_factor,
ow_factor]],],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
return
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
......@@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule
ic_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oh"].val,
cfg["tile_ow"].size[-1])
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
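    # data is in NCHW[x]c layout, so the trailing dimension of its 5-D shape is
    # the input-channel block size; reading ic_bn from the tensor shape removes
    # the old dependence on a schedule namedtuple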
# schedule data
A = data
......@@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
""" Declaration for int8 conv"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
# Intel performs a dot product of two groups of 4 int8 values
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
# Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w)
k_shape = kernel.shape
kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
k_shape[4] * k_shape[5] * k_shape[6]))
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the INT8 schedule for Intel machines
Uses Intel intrinsics for INT8 operations
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
......@@ -260,6 +176,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s
assert int32_lanes != -1
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
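    # both block sizes are recovered from the 5-D NCHW[x]c tensor shapes:
    # ic_bn from the packed input, oc_bn from the packed convolution output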
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
......@@ -271,8 +191,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
......@@ -282,17 +202,17 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
......@@ -303,8 +223,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer)
......
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm.autotvm.task import ConfigEntity
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..nn.pad import pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
def _get_default_schedule(wkl, simd_width):
def _fallback_schedule(cfg, wkl, simd_width):
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
......@@ -36,42 +32,10 @@ def _get_default_schedule(wkl, simd_width):
reg_n = n
break
return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False)
def _fallback_schedule(wkl, simd_width):
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_ow", "sp", [out_width // reg_n, reg_n]],
["unroll_kw", "ot", False]],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
......@@ -147,8 +111,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule
ic_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_ow"].size[-1],
cfg["unroll_kw"].val)
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
# schedule data
A = data
......@@ -197,52 +161,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
"""
This function sets up the compute for INT8 conv 2d
Inputs are in INT8 datatype
Output is in INT32 datatype
"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
# pack data
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
# convolution
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
# Intel performs a dot product of two groups of 4 int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the INT8 schedule for Intel machines
Uses Intel intrinsics for INT8 operations
......@@ -263,6 +182,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s
assert int32_lanes != -1
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
......@@ -274,7 +197,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global')
_, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh)
s[C].vectorize(oc_block)
......@@ -285,14 +208,14 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if sch.unroll_kw:
if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
......@@ -308,7 +231,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
......
......@@ -54,19 +54,11 @@ def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
NUM_VEC_LANES, 4, k_h, k_w)
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
elif out_dtype == 'float32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
NUM_VEC_LANES, k_h, k_w)
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
......@@ -103,8 +95,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
with tvm.target.create(TARGET_NAME):
conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter,
kernel_size=(k_h, k_w), stride=hstride,
conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
padding=hpad, layout='NCHWc',
out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv)
......@@ -114,13 +105,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule
sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter,
kernel_size=(k_h, k_w),
strides=hstride,
padding=hpad,
layout='NCHWc',
out_layout='NCHWc',
outs=[out])
sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch)
......
"""Test for NCHW[x]c convolution"""
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.transpose(data, (0, 2, 3, 1))
data = np.reshape(data, (batch_size, height, width, channel//bn, bn))
data = np.transpose(data, (0, 3, 1, 2, 4))
return data
def _transform_kernel(kernel, ic_bn, oc_bn):
# OIHW -> OIHW[x]i[x]o
out_channel, in_channel, kh, kw = kernel.shape
kernel = np.transpose(kernel, (1, 2, 3, 0))
kernel = np.reshape(kernel, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel = np.transpose(kernel, (1, 2, 3, 4, 0))
kernel = np.reshape(kernel, (kh, kw, out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn))
kernel = np.transpose(kernel, (2, 4, 0, 1, 5, 3))
return kernel
def _transform_bias(bias, bn):
# [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn]
num_filter, h, w = bias.shape
bias = np.transpose(bias, (1, 2, 0))
bias = np.reshape(bias, (h, w, num_filter//bn, bn))
bias = np.transpose(bias, (2, 0, 1, 3))
return bias
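As a quick shape check (hypothetical sizes), with ic_bn = oc_bn = 16 these helpers map a (1, 64, 56, 56) NCHW array to (1, 4, 56, 56, 16) NCHW16c and a (64, 64, 3, 3) OIHW kernel to (4, 4, 3, 3, 16, 16) OIHW16i16o:

    a = np.zeros((1, 64, 56, 56))
    assert _transform_data(a, 16).shape == (1, 4, 56, 56, 16)
    w = np.zeros((64, 64, 3, 3))
    assert _transform_kernel(w, 16, 16).shape == (4, 4, 3, 3, 16, 16)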
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
# for testing functionality,
# we choose an arbitrary block size that divides the channel count,
# regardless of performance.
oc_block = 1
for bn in range(16, 0, -1):
if num_filter % bn == 0:
oc_block = bn
break
ic_block = 1
for bn in range(oc_block, 0, -1):
if in_channel % bn == 0:
ic_block = bn
break
A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A')
W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel, ic_block, oc_block), name='W')
bias = tvm.placeholder((num_filter//oc_block, 1, 1, oc_block), name='bias')
@memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
_transform_bias(b_np, oc_block), _transform_data(c_np, oc_block)
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
layout='NCHW%dc'%ic_block,
out_layout="NCHW%dc"%oc_block,
out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_NCHWc([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
# test llvm only for now since the conv2d_NCHWc implementation is missing in other backends.
for device in ["llvm"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
if __name__ == "__main__":
# ResNet18 workloads
verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_NCHWc(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_NCHWc(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_NCHWc(1, 512, 7, 512, 3, 1, 1)
# bias, relu
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
# disable dilation test since it is not supported by NCHW[x]c conv for now.
# verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
# batch size
verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1)
# weird workloads
verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2)
verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3)
verify_conv2d_NCHWc(4, 4, 4, 4, 4, 4, 4)
verify_conv2d_NCHWc(5, 5, 5, 5, 5, 5, 5)
verify_conv2d_NCHWc(6, 6, 6, 6, 6, 6, 6)
# disable these tests due to bugs in llvm with nvptx
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=1)
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=2)
# verify_conv2d_NCHWc(2, 13, 71, 59, 3, 1, 1)
# inception v3 workloads
verify_conv2d_NCHWc(1, 3, 299, 32, 3, 2, 0)
verify_conv2d_NCHWc(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_NCHWc(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_NCHWc(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 384, 3, 2, 0)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 2, 0)
verify_conv2d_NCHWc(1, 768, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 128, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 160, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 192, 17, 320, 3, 2, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_NCHWc(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 1024, 19, 84, 3, 1, 1)
verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
......@@ -14,7 +14,6 @@ import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from topi.x86.conv2d import conv_NCHWc_arg_to_workload
import tvm.contrib.graph_runtime as runtime
#################################################################
......@@ -118,17 +117,9 @@ def tune_kernels(tasks,
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# converting conv2d tasks to conv2d_NCHWc tasks
# data, kernel are tuples of ("TENSOR", shape, dtype)
data, kernel, strides, padding, layout, dtype = tsk.args
kernel_size = (kernel[1][2], kernel[1][3])
data_plc = tvm.placeholder(data[1], name="data")
kernel_plc = tvm.placeholder(kernel[1], name="kernel")
args = [data_plc, kernel_plc, kernel[1][0], kernel_size, strides,
padding, layout, layout, dtype]
args = autotvm.task.nnvm_integration.serialize_args(args)
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target)
task.workload = conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size,
strides, padding, layout, layout, dtype)
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args,
target=target, template_key='direct')
task.workload = tsk.workload
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
......