Commit b9e8826f by Yizhi Liu, committed by Tianqi Chen

Refine porting x86 NCHWc conv to AutoTVM (#1993)

parent e286e637
@@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     dilation = attrs.get_int_tuple("dilation")
-    kh, kw = attrs.get_int_tuple('kernel_size')
     groups = attrs.get_int("groups")
-    channels = attrs.get_int("channels")
     layout = attrs.get_string("layout")
     out_layout = attrs.get_string("out_layout")
+    out_dtype = attrs.get_string("out_dtype")
+    out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
     assert dilation == (1, 1), "not support dilate now"
     if groups == 1:
         # pylint: disable=assignment-from-no-return
-        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
-                                   strides, padding, layout, out_layout)
+        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
+                                   layout, out_layout, out_dtype)
         # pylint: enable=assignment-from-no-return
     else:
         raise ValueError("not support arbitrary group number > 1 for now")
@@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
 def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
     """Schedule definition of conv2d NCHWc"""
     groups = attrs.get_int("groups")
-    kh, kw = attrs.get_int_tuple('kernel_size')
-    oc = attrs.get_int("channels")
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    layout = attrs.get_string("layout")
-    out_layout = attrs.get_string("out_layout")
     with tvm.target.create(target):
         if groups == 1:
-            return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding,
-                                                      layout, out_layout, outs)
+            return topi.generic.schedule_conv2d_NCHWc(outs)
         else:
             raise ValueError("not support group number > 1 for now")
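With the shape-related attributes gone from the operator interface, the NNVM hooks now pass only tensors plus layout/dtype information down to TOPI. A minimal sketch of the new call pattern, using only names that appear in the hunks above:

    # sketch: new-style TOPI calls made from the NNVM compute/schedule hooks
    out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
                               layout, out_layout, out_dtype)
    sched = topi.generic.schedule_conv2d_NCHWc([out])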
......
@@ -60,6 +60,53 @@ class DispatchContext(object):
             ret = self._old_ctx.query(target, workload)
         return ret

+    def update(self, target, workload, cfg):
+        """
+        Update context with a specific config.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+        cfg : ConfigSpace
+            The specific configuration.
+
+        Note
+        ----
+        This interface is for cases when TVM decides to replace an operator in the graph.
+        For example, the `AlterOpLayout` pass (enabled when `opt_level = 3`) replaces `NCHW`
+        convolution with an `NCHW[x]c` implementation on x86 CPUs.
+        Thus in TOPI, we first query the schedule using the original `NCHW` workload,
+        then update the dispatcher with the new `NCHW[x]c` workload, so that later on the
+        `NCHW[x]c` convolution can get its schedule from the dispatcher using its own
+        workload directly.
+
+        .. code-block:: python
+
+           @conv2d_alter_layout.register("cpu")
+           def _alter_conv2d_layout(attrs, inputs, tinfo):
+               workload = get_conv2d_workload(...)
+               dispatch_ctx = autotvm.task.DispatchContext.current
+               target = tvm.target.current_target()
+               config = dispatch_ctx.query(target, workload)
+
+               # Get conv2d_NCHWc workload from config
+               # new_workload = ...
+               # new_inputs = ...
+               # new_attrs = ...
+
+               # Store altered operator's config
+               dispatch_ctx.update(target, new_workload, config)
+               return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
+
+        We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
+        share the same schedule parameters.
+        One can construct a new `ConfigEntity` if this is not the case.
+        """
+        raise NotImplementedError()
+
     def _query_inside(self, target, workload):
         """
         Query the context to get the specific config for a template.
@@ -179,6 +226,11 @@ class ApplyConfig(DispatchContext):
         self.workload = workload
         return self._config

+    def update(self, target, workload, cfg):
+        """Override update"""
+        self.workload = workload
+        self._config = cfg
+

 class ApplyHistoryBest(DispatchContext):
     """
@@ -197,6 +249,7 @@ class ApplyHistoryBest(DispatchContext):
         self.best_by_targetkey = {}
         self.best_by_model = {}
+        self._best_user_defined = {}

         if records:
             self.load(records)
@@ -264,17 +317,32 @@ class ApplyHistoryBest(DispatchContext):
             if opt.startswith("-model"):
                 model = opt[7:]
                 key = (model, workload)
+                if key in self._best_user_defined:
+                    return self._best_user_defined[key]
                 if key in self.best_by_model:
                     return self.best_by_model[key][0].config

         # then try matching by target key
         for k in target.keys:
             key = (k, workload)
+            if key in self._best_user_defined:
+                return self._best_user_defined[key]
             if key in self.best_by_targetkey:
                 return self.best_by_targetkey[key][0].config

         return None

+    def update(self, target, workload, cfg):
+        for opt in target.options:
+            if opt.startswith("-model"):
+                model = opt[7:]
+                key = (model, workload)
+                self._best_user_defined[key] = cfg
+
+        for k in target.keys:
+            key = (k, workload)
+            self._best_user_defined[key] = cfg
+

 class FallbackContext(DispatchContext):
     """
@@ -324,6 +392,10 @@ class FallbackContext(DispatchContext):
         if key in self.memory:
             del self.memory[key]

+    def update(self, target, workload, cfg):
+        key = (str(target), workload)
+        self.memory[key] = cfg
+

 DispatchContext.current = FallbackContext()

 def clear_fallback_cache(target, workload):
@@ -391,37 +463,14 @@ class ApplyGraphBest(DispatchContext):
         cfg : ConfigSpace
             The specific configuration.
         """
-        cfg = self._records[self._counter][0].config
-        self._counter += 1
-        return cfg
-
-    def query_global_dict(self, key):
-        """
-        Query the context to get config from global
-        config dictionary.
-
-        Parameters
-        ----------
-        key : str
-            Key to query the config.
-
-        Returns
-        -------
-        cfg : ConfigSpace
-            The specific configuration.
-        """
+        if self._counter < len(self._records):
+            cfg = self._records[self._counter][0].config
+            self._counter += 1
+            self.update(target, workload, cfg)
+            return cfg
+        key = (str(target), workload)
         return self._global_cfg_dict[key]

-    def update_global_dict(self, key, val):
-        """
-        Update the global config dictionary.
-
-        Parameters
-        ----------
-        key : str
-            Key of config.
-        val : ConfigSpace
-            Value of config.
-        """
-        self._global_cfg_dict[key] = val
+    def update(self, target, workload, cfg):
+        key = (str(target), workload)
+        self._global_cfg_dict[key] = cfg
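The new `update` hooks give passes a way to pin a configuration for a workload so that later queries return it directly. A minimal usage sketch, assuming a tuning log `conv2d.log` and a workload tuple `wkl` are available (both are hypothetical here):

    import tvm
    from tvm import autotvm

    ctx = autotvm.apply_history_best("conv2d.log")   # hypothetical log file
    target = tvm.target.create("llvm")
    cfg = ctx.query(target, wkl)          # best tuned record, or a fallback config
    ctx.update(target, wkl, cfg)          # stored in _best_user_defined
    assert ctx.query(target, wkl) is cfg  # the user-defined entry now takes priority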
 # pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
-# pylint: disable=consider-using-enumerate
+# pylint: disable=consider-using-enumerate,too-many-lines
 """
 Template configuration space.
@@ -996,5 +996,17 @@ class FallbackConfigEntity(ConfigSpace):
             if not isinstance(self.space_map[knob_name], SplitSpace):
                 self._entity_map[knob_name] = best_match_cfg[knob_name]

+    def __setitem__(self, name, entity):
+        """Set the entity (knob) of the config by name.
+
+        Parameters
+        ----------
+        name: str
+            name of the entity
+        entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
+            value of the entity
+        """
+        self._entity_map[name] = entity
+
     def __repr__(self):
         return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
@@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None):
     return ret

-def args_to_workload(x):
+def args_to_workload(x, topi_compute_func=None):
     """Convert argument list to hashable workload tuple.
     This function will convert list to tuple, tvm node to python value and
     flatten tvm.tensor.Tensor to a tuple
@@ -191,6 +191,8 @@ def args_to_workload(x):
     ----------
     x: primitive hashable types or tensor.Tensor
         The original value
+    topi_compute_func: topi compute function
+        The function name will be added as first element of the workload tuple

     Returns
     -------
@@ -198,18 +200,19 @@ def args_to_workload(x):
         The hashable value
     """
     if isinstance(x, tensor.Tensor):
-        return get_const_tuple(x.shape) + (x.dtype, )
+        workload = get_const_tuple(x.shape) + (x.dtype, )
     elif isinstance(x, (tuple, list, container.Array)):
-        return tuple([args_to_workload(a) for a in x])
+        workload = tuple([args_to_workload(a) for a in x])
     elif isinstance(x, (str, int, float, np.int, np.float)):
-        return x
+        workload = x
     elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
-        return x.value
+        workload = x.value
     elif x is None:
-        return 0
+        workload = 0
     else:
         raise RuntimeError('Do not support type "%s" in argument. Consider to use'
                            'primitive types only' % type(x))
+    return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload


 def template(func):
     """
......
-# pylint: disable=unused-variable,invalid-name
+# pylint: disable=unused-variable,invalid-name,unused-argument
 """
 Decorators for registering tunable templates to TOPI.
@@ -13,7 +13,6 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
 from ... import _api_internal, tensor

-from ..util import get_func_name
 from .task import args_to_workload, dispatcher
@@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
     --------
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
-    fname = get_func_name(topi_compute)
-
     def _decorator(f):
         targets = [target_keys] if isinstance(target_keys, str) else target_keys
         for target_key in targets:
@@ -68,7 +65,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
             def config_dispatcher(*args, **kwargs):
                 """override topi call as a config dispatcher"""
                 assert not kwargs, "Do not support kwargs in template function call"
-                return (fname, ) + args_to_workload(args)
+                return args_to_workload(args, topi_compute)
             _REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher
         config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
@@ -88,7 +85,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
                 attrs = {}
                 for k, v in node.op.attrs.items():
                     attrs[k] = v
-                attrs['workload'] = (fname, ) + args_to_workload(args)
+                attrs['workload'] = args_to_workload(args, topi_compute)
                 if isinstance(op, tensor.ComputeOp):
                     op = _api_internal._ComputeOp(
                         op.name, op.tag, attrs, op.axis, op.body)
@@ -153,7 +150,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None):
         if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
             @topi_schedule.register(target_key)
             @dispatcher
-            def config_dispatcher(outs):
+            def config_dispatcher(outs, *args, **kwargs):
                 """override topi call as a workload dispatcher"""
                 def traverse(tensors):
                     """traverse all ops to find attached workload"""
@@ -179,11 +176,11 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None):
         config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]

         @config_dispatcher.register(template_keys)
-        def template_call(cfg, outs):
+        def template_call(cfg, outs, *args, **kwargs):
             """call the schedule func"""
             if f == topi_schedule.fdefault:
-                return f(outs)
-            return f(cfg, outs)
+                return f(outs, *args, **kwargs)
+            return f(cfg, outs, *args, **kwargs)

         return f
......
@@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs):
 @tvm.target.generic_func
-def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
-                          padding, layout, out_layout, outs):
+def schedule_conv2d_NCHWc(outs):
     """Schedule for conv2d_NCHW[x]c

     Parameters
     ----------
-    num_filter : int
-        The number of filter, i.e., the output channel.
-
-    kernel_size : tuple of int
-        (kernel_height, kernel_width)
-
-    strides : tuple of int
-        (stride_of_height, stride_of_width)
-
-    padding : tuple of int
-        (pad_of_height, pad_of_width)
-
-    layout : str
-        Input data layout
-
-    out_layout : str
-        Output data layout
-
     outs : Array of Tensor
         The computation graph description of conv2d_NCHWc
         in the format of an array of tensors.
-        The number of filter, i.e., the output channel.

     Returns
     -------
......
@@ -73,30 +73,11 @@ def schedule_conv2d_nhwc(outs):
 @generic.schedule_conv2d_NCHWc.register(["hls"])
-def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
-                          padding, layout, out_layout, outs):
+def schedule_conv2d_NCHWc(outs):
     """Schedule for conv2d_NCHW[x]c

     Parameters
     ----------
-    num_filter : int
-        The number of filter, i.e., the output channel.
-
-    kernel_size : tuple of int
-        (kernel_height, kernel_width)
-
-    strides : tuple of int
-        (stride_of_height, stride_of_width)
-
-    padding : tuple of int
-        (pad_of_height, pad_of_width)
-
-    layout : str
-        Input data layout
-
-    out_layout : str
-        Output data layout
-
     outs : Array of Tensor
         The computation graph description of conv2d_NCHWc
         in the format of an array of tensors.
......
@@ -61,8 +61,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)

 @conv2d_NCHWc.register(["intel_graphics"])
-def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,\
-                 out_layout, out_dtype='float32'):
+def _decl_conv2d(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
     """Conv2D operator for Intel Graphics backend.

     Parameters
@@ -101,7 +100,7 @@ def _decl_conv2d(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
     return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype)

 @generic.schedule_conv2d_NCHWc.register(["intel_graphics"])
-def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_layout, outs):
+def schedule_conv2d_NCHWc(outs):
     """Schedule for conv2d_nchw for Intel Graphics

     Parameters
......
@@ -84,32 +84,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         '{} vs. {}".format(data.dtype, kernel.dtype)
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_workload_int8(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule for conv2d_alter_layout. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
 @tvm.target.generic_func
 def _get_schedule(wkl):
@@ -122,28 +96,6 @@ def _get_schedule(wkl):
     return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
 def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """Convolution operator in NCHW layout.
@@ -302,8 +254,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
 @tvm.target.generic_func
-def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
-                 padding, layout, out_layout, out_dtype='float32'):
+def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
     """Conv2D operator for nChw[x]c layout.

     Parameters
@@ -316,12 +267,6 @@ def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
         [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
          in_channel_block, num_filter_block]

-    num_filter : int
-        number of filters, i.e., output channel size
-
-    kernel_size : tuple of two ints
-        [kernel_height, kernel_width]
-
     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
......
@@ -2,203 +2,15 @@
 """Conv2D schedule on x86"""
 import tvm
 from tvm import autotvm
-from tvm.autotvm.task.dispatcher import ApplyGraphBest
 from tvm.autotvm.task.nnvm_integration import deserialize_args
 from tvm.autotvm.task import register, get_config
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple
-from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
-    _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \
-    _get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload
+from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload
 from ..nn.pad import pad
 from . import conv2d_avx_1x1, conv2d_avx_common
-from .conv2d_avx_common import AVXConvCommonFwd
-from .conv2d_avx_1x1 import AVXConv1x1Fwd
-from .check_targets import check_skylake
@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
_WORKLOADS_AVX = [
# workloads of resnet18_v1 on imagenet
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
Workload('float32', 'float32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]
fp32_vec_len = 8
target = tvm.target.current_target(allow_none=False)
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
_SCHEDULES_AVX = [
# workloads of resnet18_v1 on imagenet
AVXConvCommonFwd(3, fp32_vec_len, 28, False),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]
if wkl not in _WORKLOADS_AVX:
if wkl.hkernel == 1 and wkl.wkernel == 1:
return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
idx = _WORKLOADS_AVX.index(wkl)
sch = _SCHEDULES_AVX[idx]
return sch
def _get_schedule_conv_int8(wkl):
_WORKLOADS_AVX = [
## Following are for INT8 kernels
Workload('uint8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
Workload('uint8', 'int32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
]
fp32_vec_len = 8
target = tvm.target.current_target(allow_none=False)
if check_skylake(target):
fp32_vec_len = 16
_SCHEDULES_AVX = [
# Following are for INT8 operations
# workloads of resnet18_v1 on imagenet
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]
if wkl not in _WORKLOADS_AVX:
if wkl.hkernel == 1 and wkl.wkernel == 1:
return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
idx = _WORKLOADS_AVX.index(wkl)
sch = _SCHEDULES_AVX[idx]
return sch
@_get_schedule_NCHWc.register("cpu")
def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
return _get_schedule_conv(wkl)
@_get_schedule_NCHWc_int8.register("cpu")
def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout):
return _get_schedule_conv_int8(wkl)
@_get_alter_layout_schedule.register("cpu")
def _get_alter_layout_schedule_x86(wkl):
return _get_schedule_conv(wkl)
 def _get_fp32_len():
     fp32_vec_len = 8
@@ -210,18 +22,23 @@ def _get_fp32_len():
     return fp32_vec_len

-def _get_default_sch(workload):
+def _get_default_config(cfg, workload):
+    """
+    Get default schedule config for the workload
+
+    Parameters
+    ----------
+    workload : topi.nn.conv2d.Workload
+        Convolution workload
+    """
     fp32_vec_len = _get_fp32_len()
-    _, _, kh, kw, _ = workload[2]
-    is_kernel_1x1 = kh == 1 and kw == 1
+    is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1
     if is_kernel_1x1:
-        cfg = conv2d_avx_1x1._fallback_schedule(workload, fp32_vec_len)
+        conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len)
     else:
-        cfg = conv2d_avx_common._fallback_schedule(workload, fp32_vec_len)
-    return cfg
+        conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len)
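`_get_default_config` now fills the caller's `cfg` in place and reads the `Workload` namedtuple by field name instead of by tuple index. A minimal sketch of how the fallback path is driven, with made-up shapes (`cfg` would come from the autotvm dispatcher, as in the hunks below):

    data = tvm.placeholder((1, 64, 56, 56), name="data")
    kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")
    wkl = _get_workload(data, kernel, (1, 1), (1, 1), "float32")
    if cfg.is_fallback:
        _get_default_config(cfg, wkl)   # writes the tile_* knob entities into cfg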
-def _create_schedule_template(cfg, data, kernel, strides, padding, layout):
+def _create_tuning_space(cfg, data, kernel, strides, padding, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
@@ -247,38 +64,17 @@ def _create_schedule_template(cfg, data, kernel, strides, padding, layout):
     cfg.define_knob("unroll_kw", [True, False])
-def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
-    """convert argument to workload"""
-    if len(kernel.shape) == 4:
-        raw_kernel = kernel
-    else:  # the input kernel is transformed by alter_op_layout
-        shape = get_const_tuple(kernel.shape)
-        raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]),
-                                     dtype=kernel.dtype)
-    return ('conv2d', ) + autotvm.task.args_to_workload(
-        [data, raw_kernel, strides, padding, layout, out_dtype])
-
-@conv2d.register("cpu")
-@autotvm.task.dispatcher
-def conv2d_x86(data, kernel, strides, padding, layout, out_dtype):
-    """x86 conv2d declaration."""
-    return conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
-
-@conv2d_x86.register(["direct"])
+@autotvm.register_topi_compute(conv2d, 'cpu', 'direct')
 def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     if layout == 'NCHW':
-        _create_schedule_template(cfg, data, kernel, strides, padding, layout)
+        _create_tuning_space(cfg, data, kernel, strides, padding, layout)
         if cfg.is_fallback:
-            workload = conv_arg_to_workload(data, kernel, strides, padding,
-                                            layout, out_dtype)
-            cfg = _get_default_sch(workload)
-        args = [cfg, data, kernel, strides, padding, layout, out_dtype]
-        return _declaration_conv_impl(*args)
+            wkl = _get_workload(data, kernel, strides, padding, out_dtype)
+            _get_default_config(cfg, wkl)
+        return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype)
     elif layout == 'HWCN':
         return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype)
     elif layout == 'NHWC':
@@ -345,11 +141,7 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype):
                          lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
                          .astype(out_dtype),
                          name='output_unpack',
-                         tag='conv2d_nchw',
-                         attrs={'workload':
-                                conv_arg_to_workload(data, kernel, strides,
-                                                     padding, layout,
-                                                     out_dtype)})
+                         tag='conv2d_nchw')
     return unpack
@@ -385,18 +177,7 @@ def schedule_conv2d(cfg, outs):
             _, _, kh, kw = get_const_tuple(kernel.shape)
             is_kernel_1x1 = kh == 1 and kw == 1

-            current_cfg = cfg
-            if cfg.is_fallback:
-                workload_attr = op.attrs["workload"]
-                strides = (int(workload_attr[3][0].value), int(workload_attr[3][1].value))
-                padding = (int(workload_attr[4][0].value), int(workload_attr[4][1].value))
-                layout = workload_attr[5].value
-                out_dtype = workload_attr[6].value
-                workload = conv_arg_to_workload(data, kernel, strides, padding,
-                                                layout, out_dtype)
-                current_cfg = _get_default_sch(workload)
-            args = [s, current_cfg, data, data_pad, data_vec, kernel_vec, conv_out,
-                    output, outs[0]]
+            args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
             if is_kernel_1x1:
                 conv2d_avx_1x1._schedule_conv(*args)
             else:
@@ -470,17 +251,13 @@ def schedule_conv2d_nhwc(outs):
 @register("topi_x86_conv2d_NCHWc")
 def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     assert not kwargs, "Do not support kwargs in template function call"
-    args = deserialize_args(args)
-    data, kernel = args[:2]
-    strides = args[4]
-    padding = args[5]
-    layout = args[6]
+    data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args)
     raw_data_shape = get_const_tuple(data.shape)
     raw_kernel_shape = get_const_tuple(kernel.shape)

     # get config here
     cfg = get_config()
-    _create_schedule_template(cfg, data, kernel, strides, padding, layout)
+    _create_tuning_space(cfg, data, kernel, strides, padding, origin_layout)

     # change shape with the value in config
     ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
@@ -491,50 +268,13 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     out_layout = "NCHW%dc" % oc_bn
     new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                         raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
-    args[0] = tvm.placeholder(new_data_shape, data.dtype)
-    args[1] = tvm.placeholder(new_kernel_shape, kernel.dtype)
-    args[6] = data_layout
-    args[7] = out_layout
-
-    C = _declaration_conv_NCHWc(cfg, *args, **kwargs)
-    s = _schedule_conv2d_NCHWc(cfg, args[2], args[3], args[4], args[5],
-                               args[6], args[7], [C])
-    return s, [args[0], args[1], C]
+    new_data = tvm.placeholder(new_data_shape, data.dtype)
+    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
+
+    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding,
+                                data_layout, out_layout, dtype)
+    s = _schedule_conv2d_NCHWc(cfg, [C])
+    return s, [new_data, new_kernel, C]
def conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides,
padding, layout, out_layout, out_dtype):
"""convert argument to workload"""
dshape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
if len(dshape) > 4:
raw_data = tvm.placeholder((dshape[0], dshape[1] * dshape[4], dshape[2],
dshape[3]), dtype=kernel.dtype)
else:
raw_data = data
if len(kshape) > 4:
raw_kernel = tvm.placeholder((kshape[0] * kshape[5], kshape[1] * kshape[4],
kshape[2], kshape[3]), dtype=kernel.dtype)
else:
raw_kernel = kernel
return ('conv2d_NCHWc', ) + autotvm.task.args_to_workload(
[raw_data, raw_kernel, strides, padding, layout, out_layout,
out_dtype])
def _query_dispatcher(workload, in_alter_op=False):
dispatch_ctx = autotvm.task.DispatchContext.current
if isinstance(dispatch_ctx, ApplyGraphBest):
if in_alter_op:
cfg = dispatch_ctx.query(None, None)
else:
cfg = dispatch_ctx.query_global_dict(workload)
else:
target = tvm.target.current_target()
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
cfg = _get_default_sch(workload)
return cfg
 @conv2d_alter_layout.register("cpu")
@@ -546,87 +286,72 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
     # only optimize for NCHW, groups=1 conv
     if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
         return None

+    batch_size, in_channel, height, width = get_const_tuple(data.shape)
+    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
-    kernel_size = attrs.get_int_tuple("kernel_size")
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     layout = attrs['layout']
-    out_layout = layout if attrs["out_layout"] == "__undef__" else attrs["out_layout"]
     dtype = data.dtype
     out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"]

-    workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides,
-                                          padding, layout, out_layout, out_dtype)
-    cfg = _query_dispatcher(workload, True)
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    new_attrs['layout'] = 'NCHW%dc' % ic_bn
-    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
-
-    # Store global schedule dictionary for ApplyGraphBest dispatcher
+    workload = autotvm.task.args_to_workload(
+        [data, kernel, strides, padding, layout, out_dtype], conv2d)
     dispatch_ctx = autotvm.task.DispatchContext.current
-    if isinstance(dispatch_ctx, ApplyGraphBest):
-        workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides,
-                                              padding, new_attrs['layout'],
-                                              new_attrs['out_layout'], out_dtype)
-        global_dict_key = workload
-        dispatch_ctx.update_global_dict(global_dict_key, cfg)
+    target = tvm.target.current_target()
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:
+        wkl = _get_workload(data, kernel, strides, padding, out_dtype)
+        _get_default_config(cfg, wkl)
+
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    new_attrs['layout'] = 'NCHW%dc' % ic_bn
+    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn

     # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
     new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

+    # Store altered operator's config
+    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                               dtype=data.dtype)
+    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
+                                 dtype=kernel.dtype)
+    new_workload = autotvm.task.args_to_workload(
+        [new_data, new_kernel, strides, padding, new_attrs['layout'],
+         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
+    dispatch_ctx.update(target, new_workload, cfg)
     return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
@conv2d_NCHWc.register("cpu") @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def conv2d_NCHWc_cpu(data, kernel, num_filter, kernel_size, strides, def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding, layout, out_layout, out_dtype): padding, layout, out_layout, out_dtype):
"""x86 conv2d_NCHWc declaration.""" # layout and out_layout are not used here,
dispatch_ctx = autotvm.task.DispatchContext.current # we keep them for debug convenience when dumping autotvm workload
if not isinstance(dispatch_ctx, ApplyGraphBest): HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
layout = out_layout = "NCHW" HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides,
padding, layout, out_layout, out_dtype)
cfg = _query_dispatcher(workload)
return _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides,
padding, layout, out_layout, out_dtype)
def _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides,
padding, layout, out_layout, out_dtype):
n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \
(kernel_size, kernel_size)
is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8': if data.dtype == 'uint8':
wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype), oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
tvm.placeholder((num_filter, ic, kh, kw), else:
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# get workload and related schedule config
wkl = _get_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype), dtype=kernel.dtype),
strides, padding, out_dtype) strides, padding, out_dtype)
sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) if cfg.is_fallback:
return conv2d_avx_1x1._declaration_conv_NCHWc_int8(wkl, sch, data, kernel) \ _get_default_config(cfg, wkl)
if is_kernel_1x1 \
else conv2d_avx_common._declaration_conv_NCHWc_int8(wkl, sch, data, kernel)
args = [cfg, data, kernel, (kh, kw), (sh, sw), (ph, pw), layout, out_layout, out_dtype]
return _declaration_conv_NCHWc_impl(*args)
def _declaration_conv_NCHWc_impl(cfg, data, kernel, kernel_size, strides, padding, layout,
out_layout, out_dtype):
HPAD, WPAD = padding
HSTR, WSTR = strides
n, ic_chunk, ih, iw, ic_block = get_const_tuple(data.shape) # output shape
ic = ic_chunk * ic_block out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
kh, kw = kernel_size out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1
oc_chunk, _, _, _, _, oc_block = get_const_tuple(kernel.shape) oshape = (n, oc_chunk, out_height, out_width, oc_bn)
oc = oc_chunk * oc_block
oh = (ih + 2 * HPAD - kh) // HSTR + 1
ow = (iw + 2 * WPAD - kw) // WSTR + 1
# DOPAD # DOPAD
DOPAD = (HPAD != 0 or WPAD != 0) DOPAD = (HPAD != 0 or WPAD != 0)
@@ -635,51 +360,43 @@ def _declaration_conv_NCHWc_impl(cfg, data, kernel, kernel_size, strides, padding, layout,
     else:
         data_pad = data

-    # fetch schedule
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    if ic_bn != ic_block:
-        raise RuntimeError("ic_bn in config is not equal to actual data ic_block: %d vs %d."
-                           % (ic_bn, ic_block))
-    if oc_bn != oc_block:
-        raise RuntimeError("oc_bn in config is not equal to actual kernel oc_block: %d vs %d."
-                           % (oc_bn, oc_block))
-
-    # convolution
-    oshape = (n, oc//oc_bn, oh, ow, oc_bn)
-
-    ic = tvm.reduce_axis((0, ic), name='ic')
-    kh = tvm.reduce_axis((0, kernel_size[0]), name='kh')
-    kw = tvm.reduce_axis((0, kernel_size[1]), name='kw')
-
-    workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size,
-                                          strides, padding, layout,
-                                          out_layout, out_dtype),
-    attrs = {'workload': workload}
-    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+    ic = tvm.reduce_axis((0, in_channel), name='ic')
+    kh = tvm.reduce_axis((0, kernel_height), name='kh')
+    kw = tvm.reduce_axis((0, kernel_width), name='kw')
+
+    if data.dtype == 'uint8':
+        assert out_dtype == "int32", \
+            "INT8 convolution requires input dtype = uint8 and output dtype=int32"
+        # Intel performs dot product of 2 "4" Int8 values
+        # Current implementation requires ic_bn to be a multiple of 4
+        n_elems = 4
+        assert ic_bn % n_elems == 0
+
+        ic_outer = tvm.reduce_axis((0, wkl.in_filter//ic_bn), name='ic_outer')
+        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
+        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+        return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                           tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                                            ic_f_inner * n_elems + ic_s_inner]
+                                   .astype(out_dtype) *
+                                   kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
+                                          oc_block, ic_s_inner].astype(out_dtype),
+                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+    # else: fp implementation
+    return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                        tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
                                         ic%ic_bn].astype(out_dtype) *
                                kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
                                axis=[ic, kh, kw]),
-                       name='conv2d_NCHWc', tag="conv2d_NCHWc", attrs=attrs)
-
-    return conv
+                       name='conv2d_NCHWc', tag="conv2d_NCHWc")
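In the int8 branch the per-block input-channel axis of length `ic_bn` is split into groups of `n_elems = 4`, so the Intel dot-product intrinsic can consume four int8 values at a time; the flattened channel index is recovered as `ic_f_inner * n_elems + ic_s_inner`. A tiny self-contained check of that index arithmetic (ic_bn = 16 is an assumed example value):

    n_elems, ic_bn = 4, 16
    flat = [ic_f_inner * n_elems + ic_s_inner
            for ic_f_inner in range(ic_bn // n_elems)
            for ic_s_inner in range(n_elems)]
    assert flat == list(range(ic_bn))   # every channel in the block is covered exactly once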
@generic.schedule_conv2d_NCHWc.register("cpu")
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding,
layout, out_layout, outs):
"""x86 conv2d_NCHWc schedule"""
return _schedule_conv2d_NCHWc(None, num_filter, kernel_size, strides, padding,
layout, out_layout, outs)
def _schedule_conv2d_NCHWc(cfg, num_filter, kernel_size, strides, padding, @autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct'])
layout, out_layout, outs): def _schedule_conv2d_NCHWc(cfg, outs):
"""Create schedule for tensors""" """Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = [] scheduled_ops = []
dispatch_ctx = autotvm.task.DispatchContext.current
if not isinstance(dispatch_ctx, ApplyGraphBest):
layout = out_layout = "NCHW"
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -702,34 +419,17 @@ def _schedule_conv2d_NCHWc(cfg, num_filter, kernel_size, strides, padding, ...@@ -702,34 +419,17 @@ def _schedule_conv2d_NCHWc(cfg, num_filter, kernel_size, strides, padding,
data_pad = data data_pad = data
data = data_pad.op.input_tensors[0] data = data_pad.op.input_tensors[0]
kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \ args = [s, cfg, data_vec, conv_out, outs[0]]
(kernel_size, kernel_size)
is_kernel_1x1 = kh == 1 and kw == 1
n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)
kh, kw = kernel_size
original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
dtype=kernel.dtype)
if data.dtype == 'uint8': if data.dtype == 'uint8':
wkl = _get_workload_int8(original_data, original_kernel, # int8 conv kernel is 7-dim
strides, padding, conv_out.dtype) _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) if kh == 1 and kw == 1:
args = [s, wkl, sch, data_vec, kernel, conv_out, outs[0]]
if is_kernel_1x1:
conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args) conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args)
else: else:
conv2d_avx_common._schedule_conv_NCHWc_int8(*args) conv2d_avx_common._schedule_conv_NCHWc_int8(*args)
else: else:
current_cfg = cfg _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
if current_cfg is None: if kh == 1 and kw == 1:
workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides,
padding, layout, out_layout,
conv_out.dtype)
current_cfg = _query_dispatcher(workload)
args = [s, current_cfg, data_vec, conv_out, outs[0]]
if is_kernel_1x1:
conv2d_avx_1x1._schedule_conv_NCHWc(*args) conv2d_avx_1x1._schedule_conv_NCHWc(*args)
else: else:
conv2d_avx_common._schedule_conv_NCHWc(*args) conv2d_avx_common._schedule_conv_NCHWc(*args)
......
 # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
 """1x1 Conv2D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
-from collections import namedtuple
 import tvm
-from tvm.autotvm.task import ConfigEntity
-import topi
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

 from ..nn.util import infer_pad
-from ..nn.pad import pad
+from ..util import get_const_tuple
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
 from .check_targets import check_skylake

-AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
-
-def _get_default_schedule(wkl, simd_width):
+def _fallback_schedule(cfg, wkl, simd_width):
     HPAD, WPAD = wkl.hpad, wkl.wpad
     HSTR, WSTR = wkl.hstride, wkl.wstride
     out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
@@ -37,45 +31,11 @@ def _get_default_schedule(wkl, simd_width):
         if out_width % ow_factor == 0:
             for oh_factor in range(out_height, 0, -1):
                 if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
-                    return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor)
-
-    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
+                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
+                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
+                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
+                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
+                    return
def _fallback_schedule(wkl, simd_width):
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_height = (height + 2 * HPAD - hkernel) // HSTR + 1
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_oh", "ot", oh_factor],
["tile_ow", "sp", [out_width // ow_factor,
ow_factor]],],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
...@@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu ...@@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule # fetch schedule
ic_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oh"].val, oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
cfg["tile_ow"].size[-1]) _, _, _, _, ic_bn = get_const_tuple(data.shape)
# schedule data # schedule data
A = data A = data
...@@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): ...@@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
""" Declaration for int8 conv"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
# Intel performs dot product of 2 "4" Int8 values
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
# Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w)
k_shape = kernel.shape
kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
k_shape[4] * k_shape[5] * k_shape[6]))
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
...@@ -260,6 +176,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s
assert int32_lanes != -1
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
...@@ -271,8 +191,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
...@@ -282,17 +202,17 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
...@@ -303,8 +223,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer)
...
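The INT8 schedules above hand the innermost reduction to dot_16x1x16_int8_int8_int32. As a plain-NumPy illustration of the arithmetic that intrinsic is expected to cover (shapes follow the schedule's n_elems = 4 and 16-lane split; the exact semantics here are an assumption, not the intrinsic's definition):

    import numpy as np

    # For each of 16 int32 output lanes, accumulate 4 int8 x int8 products.
    data = np.random.randint(-128, 128, size=(4,), dtype=np.int8)        # 4 int8 inputs
    kernel = np.random.randint(-128, 128, size=(16, 4), dtype=np.int8)   # 16 lanes x 4 taps
    out = (kernel.astype(np.int32) * data.astype(np.int32)).sum(axis=1, dtype=np.int32)
    assert out.shape == (16,)                  # one int32 accumulator per lane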
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm.autotvm.task import ConfigEntity
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..nn.pad import pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
def _get_default_schedule(wkl, simd_width):
def _fallback_schedule(cfg, wkl, simd_width):
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
...@@ -36,42 +32,10 @@ def _get_default_schedule(wkl, simd_width):
reg_n = n
break
return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False)
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(wkl, simd_width):
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_ow", "sp", [out_width // reg_n, reg_n]],
["unroll_kw", "ot", False]],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
...@@ -147,8 +111,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule
ic_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_ow"].size[-1],
cfg["unroll_kw"].val)
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
# schedule data
A = data
...@@ -197,52 +161,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
"""
This function sets up the compute for INT8 conv 2d
Inputs are in INT8 datatype
Output is in INT32 datatype
"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
# pack data
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
# convolution
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations
...@@ -263,6 +182,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s
assert int32_lanes != -1
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
...@@ -274,7 +197,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global')
_, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh)
s[C].vectorize(oc_block)
...@@ -285,14 +208,14 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if sch.unroll_kw:
if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
...@@ -308,7 +231,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
...
...@@ -54,19 +54,11 @@ def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
NUM_VEC_LANES, 4, k_h, k_w)
elif out_dtype == 'float32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
NUM_VEC_LANES, k_h, k_w)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
...@@ -103,8 +95,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
with tvm.target.create(TARGET_NAME):
conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter,
kernel_size=(k_h, k_w), stride=hstride,
conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
padding=hpad, layout='NCHWc',
out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv)
...@@ -114,13 +105,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule
sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter,
kernel_size=(k_h, k_w),
strides=hstride,
padding=hpad,
layout='NCHWc',
out_layout='NCHWc',
outs=[out])
sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch)
...
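With the kernel-shape and channel arguments gone, calling the NCHWc compute and schedule only needs the packed tensors plus stride, padding, and layout information. A standalone sketch under assumed shapes (NCHW16c data, OIHW16i16o kernel, a 3x3 convolution, a generic llvm target); no tuning log is loaded, so the fallback configuration is used:

    import tvm
    import topi

    data = tvm.placeholder((1, 4, 56, 56, 16), name="data")          # NCHW16c, 64 channels
    kernel = tvm.placeholder((4, 4, 3, 3, 16, 16), name="kernel")    # OIHW16i16o, 64 filters
    with tvm.target.create("llvm"):
        conv = topi.nn.conv2d_NCHWc(data, kernel, (1, 1), (1, 1),
                                    "NCHW16c", "NCHW16c", "float32")
        s = topi.generic.schedule_conv2d_NCHWc([conv])
        mod = tvm.build(s, [data, kernel, conv], name="conv2d_nchwc")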
"""Test for NCHW[x]c convolution"""
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.transpose(data, (0, 2, 3, 1))
data = np.reshape(data, (batch_size, height, width, channel//bn, bn))
data = np.transpose(data, (0, 3, 1, 2, 4))
return data
def _transform_kernel(kernel, ic_bn, oc_bn):
# OIHW -> OIHW[x]i[x]o
out_channel, in_channel, kh, kw = kernel.shape
kernel = np.transpose(kernel, (1, 2, 3, 0))
kernel = np.reshape(kernel, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel = np.transpose(kernel, (1, 2, 3, 4, 0))
kernel = np.reshape(kernel, (kh, kw, out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn))
kernel = np.transpose(kernel, (2, 4, 0, 1, 5, 3))
return kernel
def _transform_bias(bias, bn):
# [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn]
num_filter, h, w = bias.shape
bias = np.transpose(bias, (1, 2, 0))
bias = np.reshape(bias, (h, w, num_filter//bn, bn))
bias = np.transpose(bias, (2, 0, 1, 3))
return bias
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
# For testing functionality we choose an arbitrary block size
# that divides the channel count, regardless of performance.
oc_block = 1
for bn in range(16, 0, -1):
if num_filter % bn == 0:
oc_block = bn
break
ic_block = 1
for bn in range(oc_block, 0, -1):
if in_channel % bn == 0:
ic_block = bn
break
A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A')
W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel, ic_block, oc_block), name='W')
bias = tvm.placeholder((num_filter//oc_block, 1, 1, oc_block), name='bias')
@memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
_transform_bias(b_np, oc_block), _transform_data(c_np, oc_block)
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
layout='NCHW%dc'%ic_block,
out_layout="NCHW%dc"%oc_block,
out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_NCHWc([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
# test llvm only for now since the conv2d_NCHWc implementation is missing in other backends.
for device in ["llvm"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
if __name__ == "__main__":
# ResNet18 workloads
verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_NCHWc(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_NCHWc(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_NCHWc(1, 512, 7, 512, 3, 1, 1)
# bias, relu
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
# disable dilation test since it is not supported by NCHW[x]c conv for now.
# verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
# batch size
verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1)
# weird workloads
verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2)
verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3)
verify_conv2d_NCHWc(4, 4, 4, 4, 4, 4, 4)
verify_conv2d_NCHWc(5, 5, 5, 5, 5, 5, 5)
verify_conv2d_NCHWc(6, 6, 6, 6, 6, 6, 6)
# disable these tests due to some bugs in llvm with nvptx
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=1)
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=2)
# verify_conv2d_NCHWc(2, 13, 71, 59, 3, 1, 1)
# inception v3 workloads
verify_conv2d_NCHWc(1, 3, 299, 32, 3, 2, 0)
verify_conv2d_NCHWc(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_NCHWc(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_NCHWc(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 384, 3, 2, 0)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 2, 0)
verify_conv2d_NCHWc(1, 768, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 128, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 160, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 192, 17, 320, 3, 2, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_NCHWc(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 1024, 19, 84, 3, 1, 1)
verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
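The _transform_data/_transform_kernel/_transform_bias helpers above only permute and reshape axes, so they can be sanity-checked directly in NumPy. A small self-check for the data transform (sizes are arbitrary, chosen so the block size divides the channel count):

    import numpy as np

    # NCHW (1, 8, 4, 4) with bn=4 becomes NCHW4c (1, 2, 4, 4, 4); element
    # (n, c, h, w) ends up at (n, c // bn, h, w, c % bn).
    x = np.arange(1 * 8 * 4 * 4, dtype="float32").reshape(1, 8, 4, 4)
    y = _transform_data(x, bn=4)
    assert y.shape == (1, 2, 4, 4, 4)
    assert y[0, 1, 2, 3, 1] == x[0, 5, 2, 3]   # c=5 -> chunk 1, block offset 1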
...@@ -14,7 +14,6 @@ import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from topi.x86.conv2d import conv_NCHWc_arg_to_workload
import tvm.contrib.graph_runtime as runtime
#################################################################
...@@ -118,17 +117,9 @@ def tune_kernels(tasks,
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# converting conv2d tasks to conv2d_NCHWc tasks
# data, kernel are tuples of ("TENSOR", shape, dtype)
data, kernel, strides, padding, layout, dtype = tsk.args
kernel_size = (kernel[1][2], kernel[1][3])
data_plc = tvm.placeholder(data[1], name="data")
kernel_plc = tvm.placeholder(kernel[1], name="kernel")
args = [data_plc, kernel_plc, kernel[1][0], kernel_size, strides,
padding, layout, layout, dtype]
args = autotvm.task.nnvm_integration.serialize_args(args)
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target)
task.workload = conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size,
strides, padding, layout, layout, dtype)
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args,
target=target, template_key='direct')
task.workload = tsk.workload
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
...
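Once tune_kernels has written NCHW[x]c records for every task, they are applied at build time through the dispatcher. A hedged end-to-end sketch (not part of this diff); the log file name, network choice, and target string are assumptions for illustration:

    import nnvm.compiler
    import nnvm.testing
    import tvm
    from tvm import autotvm

    net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)
    with autotvm.apply_history_best("conv2d_nchwc.log"):     # log written by tune_kernels
        with nnvm.compiler.build_config(opt_level=3):        # opt_level=3 enables AlterOpLayout
            graph, lib, params = nnvm.compiler.build(
                net, target="llvm -mcpu=skylake-avx512",
                shape={"data": (1, 3, 224, 224)}, params=params)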