Commit c6b1020b authored by Yizhi Liu, committed by Tianqi Chen

Generalize pooling to support arbitrary layout (#1103)

* generalize pool2d to arbitrary layout

* explain more the layout support for pool

* allow missing factor size for pooling

* explain what factor size is used for

* fix typo

* name idx -> axis
parent 154104b3
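As a quick illustration of the generalized interface, the change can be exercised roughly as follows (a minimal sketch, not part of the commit; the shapes are made up and the `topi.nn.pool` call is an assumption based on the Python wrapper updated below):

    import tvm
    import topi

    # 5-D input in NCHW16c layout: [batch, channel_outer, height, width, channel_block]
    data = tvm.placeholder((1, 4, 32, 32, 16), name="data")

    # The 'H'/'W' positions are taken from the layout string, so no
    # NCHW-specific handling is needed for the split channel dimension.
    out = topi.nn.pool(data, kernel=(2, 2), stride=(2, 2), padding=(0, 0),
                       pool_type="max", layout="NCHW16c")
    # out has shape (1, 4, 16, 16, 16): only H and W are pooled.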
...
@@ -33,7 +33,9 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
   auto s = create_schedule(out_ops);
   auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
+    if (padded_input->op->is_type<ComputeOpNode>()) {
       s[padded_input].compute_inline();
+    }
     auto num_thread = target->max_num_threads;
     Tensor out;
     Tensor OL;
...
@@ -7,6 +7,7 @@
 #define TOPI_NN_POOLING_H_
 
 #include <string>
+#include <vector>
 
 #include "tvm/tvm.h"
 #include "tvm/ir_pass.h"
...
@@ -25,25 +26,28 @@ enum PoolType : int {
 };
 
 /*!
- * \brief Perform pooling on data in NCHW order
+ * \brief Perform pooling on height and width dimension of data.
  *
- * \param x The input tensor in NCHW order
+ * \param x The input tensor
  * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of two ints: {padding_height, padding_width}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param height_axis index of the height dimension
+ * \param width_axis index of the width dimension
 *
- * \return The output tensor in NCHW order
+ * \return The output tensor in same layout order
 */
-inline Tensor pool_nchw(const Tensor& x,
+inline Tensor pool_impl(const Tensor& x,
                         const Array<Expr>& kernel_size,
                         const Array<Expr>& stride_size,
                         const Array<Expr>& padding_size,
                         PoolType pool_type,
-                        bool ceil_mode) {
-  CHECK_EQ(x->shape.size(), 4) << "Pooling input must be 4-D";
+                        bool ceil_mode,
+                        const size_t height_axis,
+                        const size_t width_axis) {
+  CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
   CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
   CHECK_EQ(padding_size.size(), 2) << "Pooling padding_size must have 2 elements";
...
@@ -55,10 +59,8 @@ inline Tensor pool_nchw(const Tensor& x,
   auto padding_height = padding_size[0];
   auto padding_width = padding_size[1];
 
-  auto batch = x->shape[0];
-  auto channel = x->shape[1];
-  auto height = x->shape[2];
-  auto width = x->shape[3];
+  auto height = x->shape[height_axis];
+  auto width = x->shape[width_axis];
 
   auto pad_tuple = detail::GetPadTuple(padding_height, padding_width);
   auto pad_top = pad_tuple[0];
...
@@ -73,8 +75,13 @@ inline Tensor pool_nchw(const Tensor& x,
     pad_right += stride_width - 1;
   }
 
-  Array<Expr> pad_before{ 0, 0, pad_top, pad_left };
-  Array<Expr> pad_after{ 0, 0, pad_down, pad_right };
+  Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
+  pad_before.Set(height_axis, pad_top);
+  pad_before.Set(width_axis, pad_left);
+
+  Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
+  pad_after.Set(height_axis, pad_down);
+  pad_after.Set(width_axis, pad_right);
 
   auto out_height = tvm::ir::Simplify(
     (height - kernel_height + pad_top + pad_down) / stride_height + 1);
...
@@ -84,28 +91,36 @@ inline Tensor pool_nchw(const Tensor& x,
   auto dheight = tvm::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
 
+  Array<Expr> out_shape = x->shape;
+  out_shape.Set(height_axis, out_height);
+  out_shape.Set(width_axis, out_width);
+
+  const int64_t *padding_h = HalideIR::Internal::as_const_int(padding_height);
+  const int64_t *padding_w = HalideIR::Internal::as_const_int(padding_width);
+  const bool do_pad = ((padding_h && *padding_h) || (padding_w && *padding_w));
+
   if (pool_type == kMaxPool) {
-    auto temp = pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp");
-    return tvm::compute(
-      { batch, channel, out_height, out_width },
-      [&](Var n, Var c, Var h, Var w) {
-        return tvm::max(temp(n, c, h * stride_height + dheight, w * stride_width + dwidth),
-                        { dheight, dwidth });
+    auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x;
+    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+      Array<Expr> indices;
+      for (const Var& var : output) indices.push_back(var);
+      indices.Set(height_axis, output[height_axis] * stride_height + dheight);
+      indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
+      return tvm::max(temp(indices), { dheight, dwidth });
     }, "tensor", "pool_max");
   } else if (pool_type == kAvgPool) {
-    auto temp = pad(x, pad_before, pad_after, 0, "pad_temp");
-    auto tsum = tvm::compute(
-      { batch, channel, out_height, out_width },
-      [&](Var n, Var c, Var h, Var w) {
-        return tvm::sum(temp(n, c, h * stride_height + dheight, w * stride_width + dwidth),
-                        { dheight, dwidth });
+    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
+    auto tsum = tvm::compute(out_shape, [&](const Array<Var>& output) {
+      Array<Expr> indices;
+      for (const Var& var : output) indices.push_back(var);
+      indices.Set(height_axis, output[height_axis] * stride_height + dheight);
+      indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
+      return tvm::sum(temp(indices), { dheight, dwidth });
     }, "tensor", "pool_avg");
-    return tvm::compute(
-      { batch, channel, out_height, out_width },
-      [&](Var n, Var c, Var h, Var w) {
-        return tsum(n, c, h, w) / (kernel_height * kernel_width);
+
+    return tvm::compute(out_shape,
+      [&](const Array<Var>& output) {
+        return tsum(output) / (kernel_height * kernel_width);
       }, "tensor", kElementWise);
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
...
@@ -113,109 +128,57 @@ inline Tensor pool_nchw(const Tensor& x,
   }
 }
 
-/*!
- * \brief Perform pooling on data in NHWC order
- *
- * \param x The input tensor in NHWC order
- * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
- * \param stride_size Vector of two ints: {stride_height, stride_width}
- * \param padding_size Vector of two ints: {padding_height, padding_width}
- * \param pool_type The type of pooling operator
- * \param ceil_mode Whether to use ceil when calculating the output size
- *
- * \return The output tensor in NCHW order
- */
-inline Tensor pool_nhwc(const Tensor& x,
-                        const Array<Expr>& kernel_size,
-                        const Array<Expr>& stride_size,
-                        const Array<Expr>& padding_size,
-                        PoolType pool_type,
-                        bool ceil_mode) {
-  CHECK_EQ(x->shape.size(), 4) << "Pooling input must be 4-D";
-  CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
-  CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
-  CHECK_EQ(padding_size.size(), 2) << "Pooling padding_size must have 2 elements";
-
-  auto kernel_height = kernel_size[0];
-  auto kernel_width = kernel_size[1];
-  auto stride_height = stride_size[0];
-  auto stride_width = stride_size[1];
-  auto padding_height = padding_size[0];
-  auto padding_width = padding_size[1];
-
-  auto batch = x->shape[0];
-  auto height = x->shape[1];
-  auto width = x->shape[2];
-  auto channel = x->shape[3];
-
-  auto pad_tuple = detail::GetPadTuple(padding_height, padding_width);
-  auto pad_top = pad_tuple[0];
-  auto pad_left = pad_tuple[1];
-  auto pad_down = pad_tuple[2];
-  auto pad_right = pad_tuple[3];
-
-  if (ceil_mode) {
-    // Additional padding to ensure we do ceil instead of floor when
-    // dividing by stride.
-    pad_down += stride_height - 1;
-    pad_right += stride_width - 1;
-  }
-
-  Array<Expr> pad_before{ 0, pad_top, pad_left, 0};
-  Array<Expr> pad_after{ 0, pad_down, pad_right, 0};
-
-  auto out_height = tvm::ir::Simplify(
-    (height - kernel_height + pad_top + pad_down) / stride_height + 1);
-  auto out_width = tvm::ir::Simplify(
-    (width - kernel_width + pad_left + pad_right) / stride_width + 1);
-
-  auto dheight = tvm::reduce_axis(Range(0, kernel_height));
-  auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
-
-  if (pool_type == kMaxPool) {
-    auto temp = pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp");
-    return tvm::compute(
-      { batch, out_height, out_width, channel },
-      [&](Var n, Var h, Var w, Var c) {
-        return tvm::max(temp(n, h * stride_height + dheight, w * stride_width + dwidth, c),
-                        { dheight, dwidth });
-      }, "tensor", "pool_max");
-  } else if (pool_type == kAvgPool) {
-    auto temp = pad(x, pad_before, pad_after, 0, "pad_temp");
-    auto tsum = tvm::compute(
-      { batch, out_height, out_width, channel },
-      [&](Var n, Var h, Var w, Var c) {
-        return tvm::sum(temp(n, h * stride_height + dheight, w * stride_width + dwidth, c),
-                        { dheight, dwidth });
-      }, "tensor", "pool_avg");
-    return tvm::compute(
-      { batch, out_height, out_width, channel },
-      [&](Var n, Var h, Var w, Var c) {
-        return tsum(n, h, w, c) / (kernel_height * kernel_width);
-      }, "tensor", kElementWise);
-  } else {
-    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
-    return x;
-  }
-}
+inline bool find_height_width(const std::string& layout,
+                              int* height_axis,
+                              int* width_axis) {
+  *height_axis = -1, *width_axis = -1;
+  int curr_idx = 0;
+  for (size_t i = 0; i < layout.size(); ++i) {
+    if ((layout[i] >= 'A' && layout[i] <= 'Z') ||
+        (layout[i] >= 'a' && layout[i] <= 'z')) {
+      if (layout[i] == 'H') {
+        if (*height_axis != -1) return false;
+        *height_axis = curr_idx;
+      } else if (layout[i] == 'W') {
+        if (*width_axis != -1) return false;
+        *width_axis = curr_idx;
+      } else if (layout[i] == 'h' || layout[i] == 'w') {
+        // do not support split on height or width, e.g., NCHW16w
+        return false;
+      }
+      ++curr_idx;
+    }
+  }
+  if (*height_axis == -1 || *width_axis == -1) return false;
+  return true;
+}
 
 /*!
- * \brief Perform pooling on data
- *
- * \param x The input tensor in NCHW or NHWC order
+ * \brief Perform pooling on height and width dimension of data.
+ *        It decides the height and width dimension according to the layout string,
+ *        in which 'W' and 'H' means width and height respectively.
+ *        Width and height dimension cannot be split.
+ *        For example, NCHW, NCHW16c, etc. are valid for pool,
+ *        while NCHW16w, NCHW16h are not.
+ *        See \a layout for more information of the layout string convention.
+ * \param x The input tensor.
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of two ints: {padding_height, padding_width}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
- * \param layout The input layout
- *
- * \return The output tensor in NCHW order
+ * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
+ *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ *        where upper case indicates a dimension and
+ *        the corresponding lower case (with factor size) indicates the split dimension.
+ *        For example, NCHW16c can describe a 5-D tensor of
+ *        [batch_size, channel, height, width, channel_block].
+ *        (in which factor size `16` will not be used in pooling but for other operators,
+ *        it can be used to decide the output shape).
+ *        Since pooling does not care about the factor size of dimensions
+ *        other than `H` and `W`, one can pass `NCHWc` as well.
+ * \return The output tensor in the same layout
 */
 inline Tensor pool(const Tensor& x,
                    const Array<Expr>& kernel_size,
                    const Array<Expr>& stride_size,
...
@@ -223,50 +186,79 @@ inline Tensor pool(const Tensor& x,
                    PoolType pool_type,
                    bool ceil_mode,
                    const std::string& layout = "NCHW") {
-  CHECK(layout == "NCHW" || layout == "NHWC") << "Unsupported layout.";
-  if (layout == "NCHW")
-    return pool_nchw(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode);
-  else
-    return pool_nhwc(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode);
+  int height_axis = -1, width_axis = -1;
+  CHECK(find_height_width(layout, &height_axis, &width_axis))
+    << "Unsupported layout " << layout;
+  return pool_impl(x, kernel_size, stride_size, padding_size,
+                   pool_type, ceil_mode, height_axis, width_axis);
 }
 
 /*!
- * \brief Perform global pooling on data in NCHW order
+ * \brief Perform global pooling on height and width dimension of data.
+ *        It decides the height and width dimension according to the layout string,
+ *        in which 'W' and 'H' means width and height respectively.
+ *        Width and height dimension cannot be split.
+ *        For example, NCHW, NCHW16c, ... are valid for global_pool,
+ *        while NCHW16w, NCHW16h are not.
+ *        See \a layout for more information of the layout string convention.
 *
- * \param x The input tensor in NCHW order
+ * \param x The input tensor represent as layout
 * \param pool_type The type of pooling operator
+ * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
+ *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ *        where upper case indicates a dimension and
+ *        the corresponding lower case (with factor size) indicates the sub-dimension.
+ *        For example, `NCHW16c` can describe a 5-D tensor of
+ *        [batch_size, channel, height, width, channel_block].
+ *        (in which factor size `16` will not be used in pooling but for other operators,
+ *        it can be used to decide the output shape).
+ *        Since pooling does not care about the factor size of
+ *        dimensions other than `H` and `W`, one can pass `NCHWc` as well.
 *
- * \return The output tensor with shape [batch, channel, 1, 1]
+ * \return The output tensor in same layout with height and width dimension size of 1.
+ *         e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
 */
 inline Tensor global_pool(const Tensor& x,
-                          PoolType pool_type) {
-  CHECK_EQ(x->shape.size(), 4) << "Pooling input must be 4-D";
+                          PoolType pool_type,
+                          const std::string& layout = "NCHW") {
+  CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
+
+  int height_axis = -1, width_axis = -1;
+  CHECK(find_height_width(layout, &height_axis, &width_axis))
+    << "Unsupported layout " << layout;
+
+  Array<Expr> out_shape = x->shape;
+  out_shape.Set(height_axis, 1);
+  out_shape.Set(width_axis, 1);
 
-  auto batch = x->shape[0];
-  auto channel = x->shape[1];
-  auto height = x->shape[2];
-  auto width = x->shape[3];
+  auto height = x->shape[height_axis];
+  auto width = x->shape[width_axis];
 
   auto dheight = tvm::reduce_axis(Range(0, height));
   auto dwidth = tvm::reduce_axis(Range(0, width));
 
   if (pool_type == kMaxPool) {
-    return tvm::compute(
-      { batch, channel, 1, 1 },
-      [&](Var n, Var c, Var h, Var w) {
-        return tvm::max(x(n, c, dheight, dwidth), { dheight, dwidth });  // NOLINT(*)
+    return tvm::compute(out_shape,
+      [&](const Array<Var>& output) {
+        Array<Expr> indices;
+        for (const Var& var : output) indices.push_back(var);
+        indices.Set(height_axis, dheight);
+        indices.Set(width_axis, dwidth);
+        return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
      }, "tensor", "global_pool_max");
   } else if (pool_type == kAvgPool) {
-    auto tsum = tvm::compute(
-      { batch, channel, 1, 1 },
-      [&](Var n, Var c, Var h, Var w) {
-        return tvm::sum(x(n, c, dheight, dwidth), { dheight, dwidth });
+    auto tsum = tvm::compute(out_shape,
+      [&](const Array<Var>& output) {
+        Array<Expr> indices;
+        for (const Var& var : output) indices.push_back(var);
+        indices.Set(height_axis, dheight);
+        indices.Set(width_axis, dwidth);
+        return tvm::sum(x(indices), { dheight, dwidth });
      }, "tensor", "global_pool_sum");
-    return tvm::compute(
-      { batch, channel, 1, 1 },
-      [&](Var n, Var c, Var h, Var w) {
-        return tsum(n, c, h, w) / tvm::cast(x->dtype, height * width);
+
+    return tvm::compute(out_shape,
+      [&](const Array<Var>& output) {
+        return tsum(output) / tvm::cast(x->dtype, height * width);
      }, "tensor", kElementWise);
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
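The layout handling above hinges on find_height_width. A rough Python mirror of that helper (for illustration only, not part of the commit) shows which layout strings are accepted and which axes pooling will reduce over:

    def find_height_width(layout):
        # Return (height_axis, width_axis) for a layout string, or None if
        # 'H'/'W' is missing or split (e.g. NCHW16w), mirroring the C++ helper.
        height_axis = width_axis = -1
        curr_idx = 0
        for c in layout:
            if c.isalpha():
                if c == 'H':
                    if height_axis != -1:
                        return None
                    height_axis = curr_idx
                elif c == 'W':
                    if width_axis != -1:
                        return None
                    width_axis = curr_idx
                elif c in ('h', 'w'):
                    # splitting height or width (e.g. NCHW16w) is not supported
                    return None
                curr_idx += 1
        if height_axis == -1 or width_axis == -1:
            return None
        return height_axis, width_axis

    assert find_height_width("NCHW") == (2, 3)
    assert find_height_width("NHWC") == (1, 2)
    assert find_height_width("NCHW16c") == (2, 3)   # the factor size 16 is ignored here
    assert find_height_width("NCHW16w") is None     # width dimension cannot be split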
...
@@ -84,6 +84,7 @@ def schedule_pool(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     def _schedule(PaddedInput, Pool):
+        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].compute_inline()
         num_thread = tvm.target.current_target(allow_none=False).max_num_threads
         if Pool.op in s.outputs:
"""TVM operator pooling compute.""" """TVM operator pooling compute."""
from __future__ import absolute_import from __future__ import absolute_import
import tvm from .. import cpp
from .pad import pad
from .util import get_pad_tuple POOL_TYPE_CODE = {
from .. import util "avg": 0,
from .. import tag "max": 1
}
def global_pool(data, pool_type): def global_pool(data, pool_type, layout="NCHW"):
"""Perform global pooling on the data """Perform global pooling on height and width dimension of data.
It decides the height and width dimension according to the layout string,
in which 'W' and 'H' means width and height respectively.
Width and height dimension cannot be split.
For example, NCHW, NCHW16c, etc. are valid for pool,
while NCHW16w, NCHW16h are not.
See parameter `layout` for more information of the layout string convention.
Parameters Parameters
---------- ----------
data : tvm.Tensor data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width] n-D with shape of layout
pool_type : str pool_type : str
Pool type, 'max' or 'avg' Pool type, 'max' or 'avg'
layout : str
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and numbers,
where upper case indicates a dimension and
the corresponding lower case with factor size indicates the split dimension.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block],
in which channel_block=16 is a split of dimension channel.
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, channel, 1, 1] n-D in same layout with height and width dimension size of 1.
e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
""" """
assert len(data.shape) == 4, "only support 4-dim pooling" return cpp.nn.global_pool(data, POOL_TYPE_CODE[pool_type], layout)
batch, channel, height, width = data.shape
dheight = tvm.reduce_axis((0, height))
dwidth = tvm.reduce_axis((0, width))
if pool_type == 'max':
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_max")
elif pool_type == 'avg':
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width).astype(tsum.dtype), \
tag=tag.ELEMWISE)
else:
raise ValueError("Pool type should be 'avg' or 'max'.")
def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW"): def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW"):
"""Perform pooling on the data """Perform pooling on height and width dimension of data.
It decides the height and width dimension according to the layout string,
in which 'W' and 'H' means width and height respectively.
Width and height dimension cannot be split.
For example, NCHW, NCHW16c, etc. are valid for pool,
while NCHW16w, NCHW16h are not.
See parameter `layout` for more information of the layout string convention.
Parameters Parameters
---------- ----------
data : tvm.Tensor data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width] n-D with shape of layout
or [batch, in_height, in_width, channel]
kernel : list/tuple of two ints kernel : list/tuple of two ints
Kernel size, [kernel_height, kernel_width] Kernel size, [kernel_height, kernel_width]
...
@@ -69,167 +72,18 @@ def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW
         Whether to use ceil when caculate output size.
 
     layout: string
-        either "NCHW" or "NHWC"
+        Layout of the input data.
+        The layout is supposed to be composed of upper cases, lower cases and numbers,
+        where upper case indicates a dimension and
+        the corresponding lower case with factor size indicates the split dimension.
+        For example, NCHW16c can describe a 5-D tensor of
+        [batch_size, channel, height, width, channel_block],
+        in which channel_block=16 is a split of dimension channel.
 
     Returns
     -------
     output : tvm.Tensor
-        4-D with shape [batch, channel, out_height, out_width]
-        or [batch, out_height, out_width, channel]
+        n-D in the same layout
     """
-    if layout == "NCHW":
-        return pool_nchw(data, kernel, stride, padding, pool_type, ceil_mode=ceil_mode)
-    elif layout == "NHWC":
-        return pool_nhwc(data, kernel, stride, padding, pool_type, ceil_mode=ceil_mode)
-    else:
-        raise ValueError("not support this layout {} yet".format(layout))
-
-
-def pool_nchw(data, kernel, stride, padding, pool_type, ceil_mode=False):
-    """Perform pooling on the data in NCHW layout
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        4-D with shape [batch, channel, in_height, in_width]
-
-    kernel : list/tuple of two ints
-        Kernel size, [kernel_height, kernel_width]
-
-    stride : list/tuple of two ints
-        Stride size, [stride_height, stride_width]
-
-    paddding : list/tuple of two ints
-        Pad size, [pad_height, pad_width]
-
-    pool_type : str
-        Pool type, 'max' or 'avg'
-
-    ceil_mode : bool
-        Whether to use ceil when caculate output size.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, channel, out_height, out_width]
-    """
-    assert len(data.shape) == 4, "only support 4-dim pooling"
-    assert len(stride) == 2, "only support 2-dim stride"
-    kernel_height, kernel_width = kernel
-    stride_height, stride_width = stride
-    batch, channel, height, width = data.shape
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (kernel_height, kernel_width))
-
-    if ceil_mode:
-        # Additional padding to ensure we do ceil instead of floor when divide stride.
-        pad_down += stride_height -1
-        pad_right += stride_width - 1
-
-    pad_before = [0, 0, pad_top, pad_left]
-    pad_after = [0, 0, pad_down, pad_right]
-
-    out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
-    out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
-
-    dheight = tvm.reduce_axis((0, kernel_height))
-    dwidth = tvm.reduce_axis((0, kernel_width))
-
-    if pool_type == 'max':
-        temp = pad(data, pad_before, pad_after, name="pad_temp", \
-            pad_value=tvm.min_value(data.dtype))
-        return tvm.compute((batch, channel, out_height, out_width), \
-            lambda n, c, h, w: \
-            tvm.max(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
-                axis=[dheight, dwidth]), \
-            tag="pool_max")
-    elif pool_type == 'avg':
-        temp = pad(data, pad_before, pad_after, name="pad_temp", \
-            pad_value=tvm.const(0.).astype(data.dtype))
-        tsum = tvm.compute((batch, channel, out_height, out_width), \
-            lambda n, c, h, w: \
-            tvm.sum(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
-                axis=[dheight, dwidth]), \
-            tag="pool_avg")
-        return tvm.compute((batch, channel, out_height, out_width), \
-            lambda n, c, h, w: \
-            tsum[n, c, h, w] / (kernel_height*kernel_width), \
-            tag=tag.ELEMWISE)
-    else:
-        raise ValueError("Pool type should be 'avg' or 'max'.")
-
-
-def pool_nhwc(data, kernel, stride, padding, pool_type, ceil_mode=False):
-    """Perform pooling on the data in NHWC layout
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        4-D with shape [batch, in_height, in_width, channel]
-
-    kernel : list/tuple of two ints
-        Kernel size, [kernel_height, kernel_width]
-
-    stride : list/tuple of two ints
-        Stride size, [stride_height, stride_width]
-
-    paddding : list/tuple of two ints
-        Pad size, [pad_height, pad_width]
-
-    pool_type : str
-        Pool type, 'max' or 'avg'
-
-    ceil_mode : bool
-        Whether to use ceil when caculate output size.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, channel, out_height, out_width]
-    """
-    assert len(data.shape) == 4, "only support 4-dim pooling"
-    assert len(stride) == 2, "only support 2-dim stride"
-    kernel_height, kernel_width = kernel
-    stride_height, stride_width = stride
-    batch, height, width, channel = data.shape
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (kernel_height, kernel_width))
-
-    if ceil_mode:
-        # Additional padding to ensure we do ceil instead of floor when divide stride.
-        pad_down += stride_height -1
-        pad_right += stride_width - 1
-
-    pad_before = [0, pad_top, pad_left, 0]
-    pad_after = [0, pad_down, pad_right, 0]
-
-    out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
-    out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
-
-    dheight = tvm.reduce_axis((0, kernel_height))
-    dwidth = tvm.reduce_axis((0, kernel_width))
-
-    if pool_type == 'max':
-        temp = pad(data, pad_before, pad_after, name="pad_temp", \
-            pad_value=tvm.min_value(data.dtype))
-        return tvm.compute((batch, out_height, out_width, channel), \
-            lambda n, h, w, c: \
-            tvm.max(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c], \
-                axis=[dheight, dwidth]), \
-            tag="pool_max")
-    elif pool_type == 'avg':
-        temp = pad(data, pad_before, pad_after, name="pad_temp", \
-            pad_value=tvm.const(0.).astype(data.dtype))
-        tsum = tvm.compute((batch, out_height, out_width, channel, ), \
-            lambda n, h, w, c: \
-            tvm.sum(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c], \
-                axis=[dheight, dwidth]), \
-            tag="pool_avg")
-        return tvm.compute((batch, out_height, out_width, channel), \
-            lambda n, h, w, c: \
-            tsum[n, h, w, c] / (kernel_height*kernel_width), \
-            tag=tag.ELEMWISE)
-    else:
-        raise ValueError("Pool type should be 'avg' or 'max'.")
+    return cpp.nn.pool(data, kernel, stride, padding,
+                       POOL_TYPE_CODE[pool_type], ceil_mode, layout)
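The removed Python implementations and the new C++ pool_impl share the same ceil_mode convention: when ceil_mode is set, stride - 1 is added to the trailing padding so that the floor division used for the output size behaves like a ceiling division. A small worked check (illustrative numbers only, not part of the commit):

    # out = (in - kernel + pad_top + pad_down) // stride + 1
    in_size, kernel, stride = 7, 2, 2                               # no padding
    floor_out = (in_size - kernel) // stride + 1                    # (7-2)//2 + 1 = 3
    ceil_out = (in_size - kernel + (stride - 1)) // stride + 1      # (7-2+1)//2 + 1 = 4
    print(floor_out, ceil_out)  # 3 4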
...
@@ -67,6 +67,7 @@ def schedule_pool(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     def _schedule(PaddedInput, Pool):
+        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].opengl()
         if Pool.op in s.outputs:
             Out = Pool
...
@@ -27,6 +27,10 @@ def schedule_injective(outs):
             n, c, _, _ = s[x].op.axis
             fused = s[x].fuse(n, c)  # for nhwc layout, fuse n and h
             s[x].parallel(fused)
+        elif len(s[x].op.axis) == 5:
+            n, C, h, _, _ = s[x].op.axis
+            fused = s[x].fuse(n, C, h)
+            s[x].parallel(fused)
         else:
             s[x].parallel(s[x].op.axis[0])
     return s
...