Commit 5af51280 by Yizhi Liu, committed by Tianqi Chen

[TOPI] parallel schedule improve for x86 & layout_transform support (#1130)

* add layout_transform. add schedule support for ndim>=5 for x86

* fix lint
parent 0cbec2b2
@@ -482,5 +482,27 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag);
}
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;

/*!
 * \brief Transform the layout according to the mapping function \p to_src_indices.
 *
 * \param src the source input.
 * \param dst_shape the output shape.
 * \param to_src_indices the mapping function from an output (destination) index
 *        to the corresponding source index.
 * \param name output tensor name.
 * \param tag output tensor tag.
 *
 * \return A tensor with shape \p dst_shape.
 */
inline Tensor layout_transform(const Tensor& src,
                               const Array<Expr>& dst_shape,
                               const FLayoutIndicesTransform& to_src_indices,
                               const std::string name = "layout_transform",
                               const std::string tag = kInjective) {
  auto src_shape = src->shape;
  return compute(
    dst_shape, [&](const Array<Var>& dst_indices) {
      return src(to_src_indices(dst_indices));
    }, name, tag);
}
} // namespace topi
#endif // TOPI_NN_H_
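layout_transform maps every destination index back to a source index, so the caller only supplies the mapping function. A minimal sketch of the same idea in Python (not part of this commit), using plain tvm.compute with an assumed NCHW to NCHW16c packing where output index (n, C, h, w, c) reads source element (n, C*16 + c, h, w):

```python
import tvm

# assumed shapes, for illustration only
src = tvm.placeholder((1, 32, 8, 8), name="src")   # NCHW source
dst_shape = (1, 2, 8, 8, 16)                       # NCHW16c destination

# every destination index is mapped back to the source element it copies,
# mirroring what the C++ helper does with to_src_indices
packed = tvm.compute(
    dst_shape,
    lambda n, C, h, w, c: src(n, C * 16 + c, h, w),
    name="layout_transform")
```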
# pylint: disable=invalid-name,unused-argument
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
@@ -54,6 +55,32 @@ def schedule_conv2d_nhwc(outs):
@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding, outs):
    """Schedule for conv2d_NCHW[x]c

    Parameters
    ----------
    num_filter: int
        The number of filters, i.e., the number of output channels.
    kernel_size: tuple of int
        (kernel_height, kernel_width)
    strides: tuple of int
        (stride_of_height, stride_of_width)
    padding: tuple of int
        (pad_of_height, pad_of_width)
    outs: Array of Tensor
        The computation graph description of conv2d_NCHWc
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
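Targets opt in by registering their own implementation against this generic function; a hypothetical sketch of such an override (the decorator pattern is the same one the cpu registrations elsewhere in this diff use; the function name and body here are placeholders, not part of the commit):

```python
import tvm
from topi import generic

# hypothetical cpu override; a real schedule would add tiling/vectorization here.
# override=True is passed in case topi's x86 backend already registered a cpu schedule.
@generic.schedule_conv2d_NCHWc.register(["cpu"], override=True)
def _schedule_conv2d_NCHWc_cpu(num_filter, kernel_size, strides, padding, outs):
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    return s
```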
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
@@ -7,3 +7,4 @@ from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
from .nn import *
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
@@ -23,13 +23,11 @@ def schedule_injective(outs):
    x = outs[0]
    s = tvm.create_schedule([x.op for x in outs])
    tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) == 4:
-        n, c, _, _ = s[x].op.axis
-        fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
+    if len(s[x].op.axis) >= 5:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
        s[x].parallel(fused)
-    elif len(s[x].op.axis) == 5:
-        n, C, h, _, _ = s[x].op.axis
-        fused = s[x].fuse(n, C, h)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
        s[x].parallel(fused)
    else:
        s[x].parallel(s[x].op.axis[0])
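With this change, any injective op whose output has five or more axes (e.g. an NCHWc-packed tensor) gets its three outermost axes fused into one parallel loop. A rough usage sketch, assuming the llvm/x86 target (not taken from the diff):

```python
import tvm
import topi

# NCHW16c-packed activation and an elementwise (injective) op on top of it
x = tvm.placeholder((1, 4, 56, 56, 16), name="x")
y = topi.nn.relu(x)

with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective([y])

# the resulting IR has the fused (n, C, h) loop marked parallel
print(tvm.lower(s, [x, y], simple_mode=True))
```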
@@ -5,23 +5,6 @@ import tvm
from .. import generic
from .. import tag
-def _default_schedule(outs, auto_inline):
-    """Default schedule for x86."""
-    x = outs[0]
-    s = tvm.create_schedule([x.op for x in outs])
-    if auto_inline:
-        tvm.schedule.AutoInlineInjective(s)
-        s[x].fuse(s[x].op.axis)
-        return s
-    if len(s[x].op.axis) == 4:
-        n, c, _, _ = s[x].op.axis
-        fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
-        s[x].parallel(fused)
-    else:
-        s[x].parallel(s[x].op.axis[0])
-    return s
@generic.schedule_softmax.register(["cpu"])
def schedule_softmax(outs):
"""Schedule for softmax
@@ -37,25 +20,19 @@ def schedule_softmax(outs):
    sch: Schedule
        The computation schedule for the op.
    """
-    return _default_schedule(outs, False)
-@generic.schedule_pool.register(["cpu"])
-def schedule_pool(outs):
-    """Schedule for pool
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of pool
-        in the format of an array of tensors.
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    return _default_schedule(outs, False)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    x = outs[0]
+    s = tvm.create_schedule([x.op for x in outs])
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) >= 5:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+        s[x].parallel(fused)
+    else:
+        s[x].parallel(s[x].op.axis[0])
+    return s
@generic.schedule_dense.register(["cpu"])
# pylint: disable=invalid-name, unused-variable
"""Schedule for pooling operators"""
import tvm
from .. import generic
from .. import tag

def _parallel_sch(sch):
    if len(sch.op.axis) >= 5:
        fused = sch.fuse(sch.op.axis[0], sch.op.axis[1], sch.op.axis[2])
        sch.parallel(fused)
    elif len(sch.op.axis) >= 3:
        fused = sch.fuse(sch.op.axis[0], sch.op.axis[1]) # e.g. for nhwc layout, fuse n and h
        sch.parallel(fused)
    else:
        sch.parallel(sch.op.axis[0])
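_parallel_sch applies the same fusion pattern as the injective and softmax schedules above. A self-contained sketch of the mechanics on an assumed 5-D NCHW16c compute (shapes and names are illustrative only, not from the diff):

```python
import tvm

data = tvm.placeholder((1, 4, 56, 56, 16), name="data")   # NCHW16c
out = tvm.compute(data.shape,
                  lambda n, C, h, w, c: data(n, C, h, w, c) * 2.0,
                  name="out")

s = tvm.create_schedule(out.op)
# fuse the three outermost axes (n, C, h) and run the fused loop in parallel,
# which is what _parallel_sch does when the output has ndim >= 5
fused = s[out].fuse(s[out].op.axis[0], s[out].op.axis[1], s[out].op.axis[2])
s[out].parallel(fused)
print(tvm.lower(s, [data, out], simple_mode=True))
```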
@generic.schedule_pool.register(["cpu"])
def schedule_pool(outs):
    """Schedule for pool

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(PaddedInput, Pool):
        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
            s[PaddedInput].compute_inline()
        _parallel_sch(s[Pool])

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        # schedule pool
        elif OP.tag.startswith('pool'):
            PaddedInput = OP.input_tensors[0]
            Pool = OP.output(0)
            _schedule(PaddedInput, Pool)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

    traverse(outs[0].op)
    return s
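A rough usage sketch (not from the diff; the exact topi.nn.pool argument names and padding layout are an assumption, so check its docstring): the schedule inlines the padding stage and hands the pool stage to _parallel_sch:

```python
import tvm
import topi

data = tvm.placeholder((1, 64, 56, 56), name="data")   # NCHW input
# assumed signature: kernel/stride as (h, w), padding as four ints
pool = topi.nn.pool(data, kernel=(2, 2), stride=(2, 2),
                    padding=(0, 0, 0, 0), pool_type="max")

with tvm.target.create("llvm"):
    s = topi.generic.schedule_pool([pool])
```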
@generic.schedule_global_pool.register(["cpu"])
def schedule_global_pool(outs):
    """Schedule for global pool

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        # schedule pool
        elif OP.tag.startswith('global_pool'):
            Pool = OP.output(0)
            _parallel_sch(s[Pool])
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

    traverse(outs[0].op)
    return s