Commit b9e8826f by Yizhi Liu Committed by Tianqi Chen

Refine porting x86 NCHWc conv to AutoTVM (#1993)

parent e286e637
...@@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): ...@@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding") padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation") dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups") groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs.get_string("layout") layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout") out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "not support dilate now" assert dilation == (1, 1), "not support dilate now"
if groups == 1: if groups == 1:
# pylint: disable=assignment-from-no-return # pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
strides, padding, layout, out_layout) layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return # pylint: enable=assignment-from-no-return
else: else:
raise ValueError("not support arbitrary group number > 1 for now") raise ValueError("not support arbitrary group number > 1 for now")
...@@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): ...@@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
def schedule_contrib_conv2d_NCHWc(attrs, outs, target): def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc""" """Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups") groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
with tvm.target.create(target): with tvm.target.create(target):
if groups == 1: if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, return topi.generic.schedule_conv2d_NCHWc(outs)
layout, out_layout, outs)
else: else:
raise ValueError("not support group number > 1 for now") raise ValueError("not support group number > 1 for now")
......
...@@ -60,6 +60,53 @@ class DispatchContext(object): ...@@ -60,6 +60,53 @@ class DispatchContext(object):
ret = self._old_ctx.query(target, workload) ret = self._old_ctx.query(target, workload)
return ret return ret
def update(self, target, workload, cfg):
"""
Update context with a specific config.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
cfg : ConfigSpace
The specific configuration.
Note
----
This interface is for cases when TVM decides to replace an operator in the graph.
For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW`
convolution with `NCHW[x]c` implementation on x86 CPUs.
Thus in TOPI, we first query schedule using original `NCHW` workload,
then update the dispatcher with the new `NCHW[x]c` workload.
So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using
its own workload directly.
.. code-block:: python
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
workload = get_conv2d_workload(...)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
config = dispatch_ctx.query(target, workload)
# Get conv2d_NCHWc workload from config
# new_workload = ...
# new_inputs = ...
# new_attrs = ...
# Store altered operator's config
dispatch_ctx.update(target, new_workload, config)
return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
share the same schedule parameters.
One can construct a new `ConfigEntity` if this is not the case.
"""
raise NotImplementedError()
def _query_inside(self, target, workload): def _query_inside(self, target, workload):
""" """
Query the context to get the specific config for a template. Query the context to get the specific config for a template.
...@@ -179,6 +226,11 @@ class ApplyConfig(DispatchContext): ...@@ -179,6 +226,11 @@ class ApplyConfig(DispatchContext):
self.workload = workload self.workload = workload
return self._config return self._config
def update(self, target, workload, cfg):
"""Override update"""
self.workload = workload
self._config = cfg
class ApplyHistoryBest(DispatchContext): class ApplyHistoryBest(DispatchContext):
""" """
...@@ -197,6 +249,7 @@ class ApplyHistoryBest(DispatchContext): ...@@ -197,6 +249,7 @@ class ApplyHistoryBest(DispatchContext):
self.best_by_targetkey = {} self.best_by_targetkey = {}
self.best_by_model = {} self.best_by_model = {}
self._best_user_defined = {}
if records: if records:
self.load(records) self.load(records)
...@@ -264,17 +317,32 @@ class ApplyHistoryBest(DispatchContext): ...@@ -264,17 +317,32 @@ class ApplyHistoryBest(DispatchContext):
if opt.startswith("-model"): if opt.startswith("-model"):
model = opt[7:] model = opt[7:]
key = (model, workload) key = (model, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model: if key in self.best_by_model:
return self.best_by_model[key][0].config return self.best_by_model[key][0].config
# then try matching by target key # then try matching by target key
for k in target.keys: for k in target.keys:
key = (k, workload) key = (k, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey: if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config return self.best_by_targetkey[key][0].config
return None return None
def update(self, target, workload, cfg):
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
self._best_user_defined[key] = cfg
for k in target.keys:
key = (k, workload)
self._best_user_defined[key] = cfg
class FallbackContext(DispatchContext): class FallbackContext(DispatchContext):
""" """
...@@ -324,6 +392,10 @@ class FallbackContext(DispatchContext): ...@@ -324,6 +392,10 @@ class FallbackContext(DispatchContext):
if key in self.memory: if key in self.memory:
del self.memory[key] del self.memory[key]
def update(self, target, workload, cfg):
key = (str(target), workload)
self.memory[key] = cfg
DispatchContext.current = FallbackContext() DispatchContext.current = FallbackContext()
def clear_fallback_cache(target, workload): def clear_fallback_cache(target, workload):
...@@ -391,37 +463,14 @@ class ApplyGraphBest(DispatchContext): ...@@ -391,37 +463,14 @@ class ApplyGraphBest(DispatchContext):
cfg : ConfigSpace cfg : ConfigSpace
The specific configuration. The specific configuration.
""" """
cfg = self._records[self._counter][0].config if self._counter < len(self._records):
self._counter += 1 cfg = self._records[self._counter][0].config
return cfg self._counter += 1
self.update(target, workload, cfg)
def query_global_dict(self, key): return cfg
""" key = (str(target), workload)
Query the context to get config from global
config dictionary.
Parameters
----------
key : str
Key to query the config.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
return self._global_cfg_dict[key] return self._global_cfg_dict[key]
def update_global_dict(self, key, val): def update(self, target, workload, cfg):
""" key = (str(target), workload)
Update the global config dictionary. self._global_cfg_dict[key] = cfg
Parameters
----------
key : str
Key of config.
val : ConfigSpace
Value of config.
"""
self._global_cfg_dict[key] = val
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ # pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate # pylint: disable=consider-using-enumerate,too-many-lines
""" """
Template configuration space. Template configuration space.
...@@ -996,5 +996,17 @@ class FallbackConfigEntity(ConfigSpace): ...@@ -996,5 +996,17 @@ class FallbackConfigEntity(ConfigSpace):
if not isinstance(self.space_map[knob_name], SplitSpace): if not isinstance(self.space_map[knob_name], SplitSpace):
self._entity_map[knob_name] = best_match_cfg[knob_name] self._entity_map[knob_name] = best_match_cfg[knob_name]
def __setitem__(self, name, entity):
"""set the entity(knob) of by name
Parameters
----------
name: str
name of the entity
entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
value of the entity
"""
self._entity_map[name] = entity
def __repr__(self): def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash) return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
...@@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None): ...@@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None):
return ret return ret
def args_to_workload(x): def args_to_workload(x, topi_compute_func=None):
"""Convert argument list to hashable workload tuple. """Convert argument list to hashable workload tuple.
This function will convert list to tuple, tvm node to python value and This function will convert list to tuple, tvm node to python value and
flatten tvm.tensor.Tensor to a tuple flatten tvm.tensor.Tensor to a tuple
...@@ -191,6 +191,8 @@ def args_to_workload(x): ...@@ -191,6 +191,8 @@ def args_to_workload(x):
---------- ----------
x: primitive hashable types or tensor.Tensor x: primitive hashable types or tensor.Tensor
The original value The original value
topi_compute_func: topi compute function
The function name will be added as first element of the workload tuple
Returns Returns
------- -------
...@@ -198,18 +200,19 @@ def args_to_workload(x): ...@@ -198,18 +200,19 @@ def args_to_workload(x):
The hashable value The hashable value
""" """
if isinstance(x, tensor.Tensor): if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, ) workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)): elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x]) workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)): elif isinstance(x, (str, int, float, np.int, np.float)):
return x workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value workload = x.value
elif x is None: elif x is None:
return 0 workload = 0
else: else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use' raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types only' % type(x)) 'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
def template(func): def template(func):
""" """
......
# pylint: disable=unused-variable,invalid-name # pylint: disable=unused-variable,invalid-name,unused-argument
""" """
Decorators for registering tunable templates to TOPI. Decorators for registering tunable templates to TOPI.
...@@ -13,7 +13,6 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. ...@@ -13,7 +13,6 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
from ... import _api_internal, tensor from ... import _api_internal, tensor
from ..util import get_func_name
from .task import args_to_workload, dispatcher from .task import args_to_workload, dispatcher
...@@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
-------- --------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
""" """
fname = get_func_name(topi_compute)
def _decorator(f): def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets: for target_key in targets:
...@@ -68,7 +65,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -68,7 +65,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
def config_dispatcher(*args, **kwargs): def config_dispatcher(*args, **kwargs):
"""override topi call as a config dispatcher""" """override topi call as a config dispatcher"""
assert not kwargs, "Do not support kwargs in template function call" assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args) return args_to_workload(args, topi_compute)
_REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher _REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute] config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
...@@ -88,7 +85,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -88,7 +85,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
attrs = {} attrs = {}
for k, v in node.op.attrs.items(): for k, v in node.op.attrs.items():
attrs[k] = v attrs[k] = v
attrs['workload'] = (fname, ) + args_to_workload(args) attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp): if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp( op = _api_internal._ComputeOp(
op.name, op.tag, attrs, op.axis, op.body) op.name, op.tag, attrs, op.axis, op.body)
...@@ -153,7 +150,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None) ...@@ -153,7 +150,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)
if topi_schedule not in _REGISTED_DISPATHCER[target_key]: if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
@topi_schedule.register(target_key) @topi_schedule.register(target_key)
@dispatcher @dispatcher
def config_dispatcher(outs): def config_dispatcher(outs, *args, **kwargs):
"""override topi call as a workload dispatcher""" """override topi call as a workload dispatcher"""
def traverse(tensors): def traverse(tensors):
"""traverse all ops to find attached workload""" """traverse all ops to find attached workload"""
...@@ -179,11 +176,11 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None) ...@@ -179,11 +176,11 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule] config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]
@config_dispatcher.register(template_keys) @config_dispatcher.register(template_keys)
def template_call(cfg, outs): def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func""" """call the schedule func"""
if f == topi_schedule.fdefault: if f == topi_schedule.fdefault:
return f(outs) return f(outs, *args, **kwargs)
return f(cfg, outs) return f(cfg, outs, *args, **kwargs)
return f return f
......
...@@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs): ...@@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs):
@tvm.target.generic_func @tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, def schedule_conv2d_NCHWc(outs):
padding, layout, out_layout, outs):
"""Schedule for conv2d_NCHW[x]c """Schedule for conv2d_NCHW[x]c
Parameters Parameters
---------- ----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor outs : Array of Tensor
The computation graph description of conv2d_NCHWc The computation graph description of conv2d_NCHWc
in the format of an array of tensors. in the format of an array of tensors.
The number of filter, i.e., the output channel.
Returns Returns
------- -------
......
...@@ -73,30 +73,11 @@ def schedule_conv2d_nhwc(outs): ...@@ -73,30 +73,11 @@ def schedule_conv2d_nhwc(outs):
@generic.schedule_conv2d_NCHWc.register(["hls"]) @generic.schedule_conv2d_NCHWc.register(["hls"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, def schedule_conv2d_NCHWc(outs):
padding, layout, out_layout, outs):
"""Schedule for conv2d_NCHW[x]c """Schedule for conv2d_NCHW[x]c
Parameters Parameters
---------- ----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor outs : Array of Tensor
The computation graph description of conv2d_NCHWc The computation graph description of conv2d_NCHWc
in the format of an array of tensors. in the format of an array of tensors.
......
...@@ -61,8 +61,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -61,8 +61,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
@conv2d_NCHWc.register(["intel_graphics"]) @conv2d_NCHWc.register(["intel_graphics"])
def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,\ def _decl_conv2d(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
out_layout, out_dtype='float32'):
"""Conv2D operator for Intel Graphics backend. """Conv2D operator for Intel Graphics backend.
Parameters Parameters
...@@ -101,7 +100,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout, ...@@ -101,7 +100,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,
return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype) return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype)
@generic.schedule_conv2d_NCHWc.register(["intel_graphics"]) @generic.schedule_conv2d_NCHWc.register(["intel_graphics"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_layout, outs): def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_nchw for Intel Graphics """Schedule for conv2d_nchw for Intel Graphics
Parameters Parameters
......
...@@ -84,32 +84,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype): ...@@ -84,32 +84,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
'{} vs. {}".format(data.dtype, kernel.dtype) '{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_workload_int8(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule for conv2d_alter_layout. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
@tvm.target.generic_func @tvm.target.generic_func
def _get_schedule(wkl): def _get_schedule(wkl):
...@@ -122,28 +96,6 @@ def _get_schedule(wkl): ...@@ -122,28 +96,6 @@ def _get_schedule(wkl):
return wkl return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None): def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout. """Convolution operator in NCHW layout.
...@@ -302,8 +254,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'): ...@@ -302,8 +254,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
@tvm.target.generic_func @tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
padding, layout, out_layout, out_dtype='float32'):
"""Conv2D operator for nChw[x]c layout. """Conv2D operator for nChw[x]c layout.
Parameters Parameters
...@@ -316,12 +267,6 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, ...@@ -316,12 +267,6 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
[num_filter_chunk, in_channel_chunk, filter_height, filter_width, [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
in_channel_block, num_filter_block] in_channel_block, num_filter_block]
num_filter : int
number of filters, i.e., output channel size
kernel_size : tuple of two ints
[kernel_height, kernel_width]
stride : int or a list/tuple of two ints stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width] stride size, or [stride_height, stride_width]
......
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""1x1 Conv2D schedule on for Intel CPU""" """1x1 Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm import tvm
from tvm.autotvm.task import ConfigEntity from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
import topi
from ..nn.util import infer_pad from ..nn.util import infer_pad
from ..nn.pad import pad from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake from .check_targets import check_skylake
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor']) def _fallback_schedule(cfg, wkl, simd_width):
def _get_default_schedule(wkl, simd_width):
HPAD, WPAD = wkl.hpad, wkl.wpad HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
...@@ -37,45 +31,11 @@ def _get_default_schedule(wkl, simd_width): ...@@ -37,45 +31,11 @@ def _get_default_schedule(wkl, simd_width):
if out_width % ow_factor == 0: if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1): for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor) cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
return
def _fallback_schedule(wkl, simd_width):
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_height = (height + 2 * HPAD - hkernel) // HSTR + 1
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_oh", "ot", oh_factor],
["tile_ow", "sp", [out_width // ow_factor,
ow_factor]],],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
...@@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu ...@@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule # fetch schedule
ic_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oh"].val, oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
cfg["tile_ow"].size[-1]) _, _, _, _, ic_bn = get_const_tuple(data.shape)
# schedule data # schedule data
A = data A = data
...@@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): ...@@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
""" Declaration for int8 conv"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
# Intel performs dot product of 2 "4" Int8 values
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
# Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w)
k_shape = kernel.shape
kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
k_shape[4] * k_shape[5] * k_shape[6]))
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
""" """
Defines the schedule for INT8 for intel machines Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations Uses the Intel intrinsics to use INT8 operations
More details - https://software.intel.com/en-us/articles/ More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training lower-numerical-precision-deep-learning-inference-and-training
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
int32_lanes = -1 int32_lanes = -1
if check_skylake(target): if check_skylake(target):
...@@ -260,6 +176,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -260,6 +176,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s return s
assert int32_lanes != -1 assert int32_lanes != -1
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
# schedule data # schedule data
A = data A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp): if isinstance(s[A].op, tvm.tensor.ComputeOp):
...@@ -271,8 +191,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -271,8 +191,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global') CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor) ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block) s[C].vectorize(oc_block)
...@@ -282,17 +202,17 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -282,17 +202,17 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
s[C].parallel(parallel_axis) s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes # Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0 assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner, s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer) s[CC].fuse(oc_chunk, oh_outer)
...@@ -303,8 +223,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -303,8 +223,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O: if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer) parallel_axis = s[O].fuse(oc_chunk, oh_outer)
......
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""Conv2D schedule on for Intel CPU""" """Conv2D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm import tvm
from tvm.autotvm.task import ConfigEntity from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad from ..nn.util import infer_pad
from ..nn.pad import pad from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake from .check_targets import check_skylake
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw']) def _fallback_schedule(cfg, wkl, simd_width):
def _get_default_schedule(wkl, simd_width):
HPAD, WPAD = wkl.hpad, wkl.wpad HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
...@@ -36,42 +32,10 @@ def _get_default_schedule(wkl, simd_width): ...@@ -36,42 +32,10 @@ def _get_default_schedule(wkl, simd_width):
reg_n = n reg_n = n
break break
return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False) cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
def _fallback_schedule(wkl, simd_width): cfg["unroll_kw"] = OtherOptionEntity(False)
batch_size, in_channel, height, width, _ = wkl[1]
out_channel, _, hkernel, wkernel, _ = wkl[2]
HPAD, WPAD = wkl[4]
HSTR, WSTR = wkl[3]
out_width = (width + 2 * WPAD - wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if out_channel % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if in_channel % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]],
["tile_oc", "sp", [out_channel // oc_bn, oc_bn]],
["tile_ow", "sp", [out_width // reg_n, reg_n]],
["unroll_kw", "ot", False]],
"t": ""}
return ConfigEntity.from_json_dict(cfg_dict)
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
...@@ -147,8 +111,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu ...@@ -147,8 +111,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
# fetch schedule # fetch schedule
ic_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_ow"].size[-1], reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
cfg["unroll_kw"].val) _, _, _, _, ic_bn = get_const_tuple(data.shape)
# schedule data # schedule data
A = data A = data
...@@ -197,52 +161,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): ...@@ -197,52 +161,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
return s return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
This function sets up the compute for INT8 conv 2d
Inputs are in INT8 datatype
Output is in INT32 datatype
"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
# pack data
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
# convolution
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
""" """
Defines the schedule for INT8 for intel machines Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations Uses the Intel intrinsics to use INT8 operations
...@@ -263,6 +182,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -263,6 +182,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
return s return s
assert int32_lanes != -1 assert int32_lanes != -1
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
A = data A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp): if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis batch, ic_chunk, ih, iw, _ = s[A].op.axis
...@@ -274,7 +197,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -274,7 +197,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
CC = s.cache_write(C, 'global') CC = s.cache_write(C, 'global')
_, oc_chunk, oh, ow, oc_block = s[C].op.axis _, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n) ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh) parallel_axis = s[C].fuse(oc_chunk, oh)
s[C].vectorize(oc_block) s[C].vectorize(oc_block)
...@@ -285,14 +208,14 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -285,14 +208,14 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n) ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
# Skylake and future processors have 16 vector lanes # Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0 assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if sch.unroll_kw: if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner) ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw) s[CC].unroll(kw)
...@@ -308,7 +231,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): ...@@ -308,7 +231,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
if C != O: if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n) ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh) parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis) s[C].compute_at(s[O], parallel_axis)
......
...@@ -54,19 +54,11 @@ def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, ...@@ -54,19 +54,11 @@ def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES) data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32': if out_dtype == 'int32':
if k_h != 1: kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
NUM_VEC_LANES, 4, k_h, k_w)
elif out_dtype == 'float32': elif out_dtype == 'float32':
if k_h != 1: kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, NUM_VEC_LANES, NUM_VEC_LANES)
NUM_VEC_LANES, NUM_VEC_LANES)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
NUM_VEC_LANES, k_h, k_w)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1 out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1 out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES) o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
...@@ -103,8 +95,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f ...@@ -103,8 +95,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
with tvm.target.create(TARGET_NAME): with tvm.target.create(TARGET_NAME):
conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter, conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
kernel_size=(k_h, k_w), stride=hstride,
padding=hpad, layout='NCHWc', padding=hpad, layout='NCHWc',
out_layout='NCHWc', out_dtype=out_dtype) out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv) out = topi.nn.relu(conv)
...@@ -114,13 +105,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f ...@@ -114,13 +105,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True)) LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule # Generate and run the optimized schedule
sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter, sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
kernel_size=(k_h, k_w),
strides=hstride,
padding=hpad,
layout='NCHWc',
out_layout='NCHWc',
outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv') func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch) func(data_array, kernel_array, c_sch)
......
"""Test for NCHW[x]c convolution"""
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.transpose(data, (0, 2, 3, 1))
data = np.reshape(data, (batch_size, height, width, channel//bn, bn))
data = np.transpose(data, (0, 3, 1, 2, 4))
return data
def _transform_kernel(kernel, ic_bn, oc_bn):
# OIHW -> OIHW[x]i[x]o
out_channel, in_channel, kh, kw = kernel.shape
kernel = np.transpose(kernel, (1, 2, 3, 0))
kernel = np.reshape(kernel, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel = np.transpose(kernel, (1, 2, 3, 4, 0))
kernel = np.reshape(kernel, (kh, kw, out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn))
kernel = np.transpose(kernel, (2, 4, 0, 1, 5, 3))
return kernel
def _transform_bias(bias, bn):
# [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn]
num_filter, h, w = bias.shape
bias = np.transpose(bias, (1, 2, 0))
bias = np.reshape(bias, (h, w, num_filter//bn, bn))
bias = np.transpose(bias, (2, 0, 1, 3))
return bias
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
# for testing functionality,
# we choose arbitrary block size that can divide the channel,
# regardless of the performance.
oc_block = 1
for bn in range(16, 0, -1):
if num_filter % bn == 0:
oc_block = bn
break
ic_block = 1
for bn in range(oc_block, 0, -1):
if in_channel % bn == 0:
ic_block = bn
break
A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A')
W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel, ic_block, oc_block), name='W')
bias = tvm.placeholder((num_filter//oc_block, 1, 1, oc_block), name='bias')
@memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
_transform_bias(b_np, oc_block), _transform_data(c_np, oc_block)
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
layout='NCHW%dc'%ic_block,
out_layout="NCHW%dc"%oc_block,
out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_NCHWc([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
# test llvm only for now since conv2d_NCHWc implement is missing in other backend.
for device in ["llvm"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
if __name__ == "__main__":
# ResNet18 workloads
verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_NCHWc(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_NCHWc(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_NCHWc(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_NCHWc(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_NCHWc(1, 512, 7, 512, 3, 1, 1)
# bias, relu
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
# disable dilation test since it is not supported by NCHW[x]c conv for now.
# verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
# batch size
verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1)
# weird workloads
verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2)
verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3)
verify_conv2d_NCHWc(4, 4, 4, 4, 4, 4, 4)
verify_conv2d_NCHWc(5, 5, 5, 5, 5, 5, 5)
verify_conv2d_NCHWc(6, 6, 6, 6, 6, 6, 6)
# disable these tests due to some bugs of llvm with nvptx
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=1)
# verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=2)
# verify_conv2d_NCHWc(2, 13, 71, 59, 3, 1, 1)
# inception v3 workloads
verify_conv2d_NCHWc(1, 3, 299, 32, 3, 2, 0)
verify_conv2d_NCHWc(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_NCHWc(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_NCHWc(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_NCHWc(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc(1, 288, 35, 384, 3, 2, 0)
verify_conv2d_NCHWc(1, 96, 35, 96, 3, 2, 0)
verify_conv2d_NCHWc(1, 768, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc(1, 128, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 128, 7, 1, 3)
verify_conv2d_NCHWc(1, 128, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 768, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc(1, 160, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 160, 7, 1, 3)
verify_conv2d_NCHWc(1, 160, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc(1, 192, 17, 320, 3, 2, 0)
verify_conv2d_NCHWc(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_NCHWc(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc(1, 1024, 19, 84, 3, 1, 1)
verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
...@@ -14,7 +14,6 @@ import nnvm.compiler ...@@ -14,7 +14,6 @@ import nnvm.compiler
import tvm import tvm
from tvm import autotvm from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from topi.x86.conv2d import conv_NCHWc_arg_to_workload
import tvm.contrib.graph_runtime as runtime import tvm.contrib.graph_runtime as runtime
################################################################# #################################################################
...@@ -118,17 +117,9 @@ def tune_kernels(tasks, ...@@ -118,17 +117,9 @@ def tune_kernels(tasks,
prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# converting conv2d tasks to conv2d_NCHWc tasks # converting conv2d tasks to conv2d_NCHWc tasks
# data, kernel are tuples of ("TENSOR", shape, dtype) task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args,
data, kernel, strides, padding, layout, dtype = tsk.args target=target, template_key='direct')
kernel_size = (kernel[1][2], kernel[1][3]) task.workload = tsk.workload
data_plc = tvm.placeholder(data[1], name="data")
kernel_plc = tvm.placeholder(kernel[1], name="kernel")
args = [data_plc, kernel_plc, kernel[1][0], kernel_size, strides,
padding, layout, layout, dtype]
args = autotvm.task.nnvm_integration.serialize_args(args)
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target)
task.workload = conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size,
strides, padding, layout, layout, dtype)
# create tuner # create tuner
if tuner == 'xgb' or tuner == 'xgb-rank': if tuner == 'xgb' or tuner == 'xgb-rank':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment