Commit 07f12239 by Tatsuya Nishiyama Committed by Tianqi Chen

Add count_include_pad support to AvgPool (#1163)

* Add count_include_pad support to AvgPool

* Fix python_cpp/test_topi_pooling.py

* Change auto to explicitly type, and fix format.
parent aede4820
...@@ -62,3 +62,4 @@ List of Contributors ...@@ -62,3 +62,4 @@ List of Contributors
- [Haolong Zhang](https://github.com/haolongzhangm) - [Haolong Zhang](https://github.com/haolongzhangm)
- [Cody Hao Yu](https://github.com/comaniac) - [Cody Hao Yu](https://github.com/comaniac)
- [Chris Nuernberger](https://github.com/cnuernber) - [Chris Nuernberger](https://github.com/cnuernber)
- [Tatsuya Nishiyama](https://github.com/nishi-t)
...@@ -36,6 +36,7 @@ enum PoolType : int { ...@@ -36,6 +36,7 @@ enum PoolType : int {
* \param ceil_mode Whether to use ceil when calculating the output size * \param ceil_mode Whether to use ceil when calculating the output size
* \param height_axis index of the height dimension * \param height_axis index of the height dimension
* \param width_axis index of the width dimension * \param width_axis index of the width dimension
* \param count_include_pad Whether include padding in the calculation
* *
* \return The output tensor in same layout order * \return The output tensor in same layout order
*/ */
...@@ -46,7 +47,8 @@ inline Tensor pool_impl(const Tensor& x, ...@@ -46,7 +47,8 @@ inline Tensor pool_impl(const Tensor& x,
PoolType pool_type, PoolType pool_type,
bool ceil_mode, bool ceil_mode,
const size_t height_axis, const size_t height_axis,
const size_t width_axis) { const size_t width_axis,
bool count_include_pad) {
CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
...@@ -120,7 +122,19 @@ inline Tensor pool_impl(const Tensor& x, ...@@ -120,7 +122,19 @@ inline Tensor pool_impl(const Tensor& x,
return tvm::compute(out_shape, return tvm::compute(out_shape,
[&](const Array<Var>& output) { [&](const Array<Var>& output) {
return tsum(output) / (kernel_height * kernel_width); if (count_include_pad) {
return tsum(output) / (kernel_height * kernel_width);
} else {
Expr h_start = output[height_axis] * stride_height - padding_height;
Expr w_start = output[width_axis] * stride_width - padding_width;
Expr h_end = ir::Min::make(h_start + kernel_height, height);
Expr w_end = ir::Min::make(w_start + kernel_width, width);
h_start = ir::Max::make(h_start, make_const(Int(32), 0));
w_start = ir::Max::make(w_start, make_const(Int(32), 0));
Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
make_const(Int(32), 1));
return tsum(output) / divide_factor;
}
}, "tensor", kElementWise); }, "tensor", kElementWise);
} else { } else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type; LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
...@@ -177,6 +191,9 @@ inline bool find_height_width(const std::string& layout, ...@@ -177,6 +191,9 @@ inline bool find_height_width(const std::string& layout,
* it can be used to decide the output shape). * it can be used to decide the output shape).
* Since pooling does not care about the factor size of dimensions * Since pooling does not care about the factor size of dimensions
* other than `H` and `W`, one can pass `NCHWc` as well. * other than `H` and `W`, one can pass `NCHWc` as well.
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
*
* \return The output tensor in the same layout * \return The output tensor in the same layout
*/ */
inline Tensor pool(const Tensor& x, inline Tensor pool(const Tensor& x,
...@@ -185,12 +202,14 @@ inline Tensor pool(const Tensor& x, ...@@ -185,12 +202,14 @@ inline Tensor pool(const Tensor& x,
const Array<Expr>& padding_size, const Array<Expr>& padding_size,
PoolType pool_type, PoolType pool_type,
bool ceil_mode, bool ceil_mode,
const std::string& layout = "NCHW") { const std::string& layout = "NCHW",
bool count_include_pad = true) {
int height_axis = -1, width_axis = -1; int height_axis = -1, width_axis = -1;
CHECK(find_height_width(layout, &height_axis, &width_axis)) CHECK(find_height_width(layout, &height_axis, &width_axis))
<< "Unsupported layout " << layout; << "Unsupported layout " << layout;
return pool_impl(x, kernel_size, stride_size, padding_size, return pool_impl(x, kernel_size, stride_size, padding_size,
pool_type, ceil_mode, height_axis, width_axis); pool_type, ceil_mode, height_axis, width_axis,
count_include_pad);
} }
/*! /*!
......
...@@ -42,7 +42,14 @@ def global_pool(data, pool_type, layout="NCHW"): ...@@ -42,7 +42,14 @@ def global_pool(data, pool_type, layout="NCHW"):
return cpp.nn.global_pool(data, POOL_TYPE_CODE[pool_type], layout) return cpp.nn.global_pool(data, POOL_TYPE_CODE[pool_type], layout)
def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW"): def pool(data,
kernel,
stride,
padding,
pool_type,
ceil_mode=False,
layout="NCHW",
count_include_pad=True):
"""Perform pooling on height and width dimension of data. """Perform pooling on height and width dimension of data.
It decides the height and width dimension according to the layout string, It decides the height and width dimension according to the layout string,
in which 'W' and 'H' means width and height respectively. in which 'W' and 'H' means width and height respectively.
...@@ -80,10 +87,13 @@ def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW ...@@ -80,10 +87,13 @@ def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW
[batch_size, channel, height, width, channel_block], [batch_size, channel, height, width, channel_block],
in which channel_block=16 is a split of dimension channel. in which channel_block=16 is a split of dimension channel.
count_include_pad: bool
Whether include padding in the calculation when pool_type is 'avg'
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
n-D in the same layout n-D in the same layout
""" """
return cpp.nn.pool(data, kernel, stride, padding, return cpp.nn.pool(data, kernel, stride, padding,
POOL_TYPE_CODE[pool_type], ceil_mode, layout) POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
...@@ -322,7 +322,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pool") ...@@ -322,7 +322,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pool")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool(args[0], args[1], args[2], args[3], *rv = nn::pool(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])), static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6]); args[5], args[6], args[7]);
}); });
TVM_REGISTER_GLOBAL("topi.nn.global_pool") TVM_REGISTER_GLOBAL("topi.nn.global_pool")
......
...@@ -5,14 +5,14 @@ import topi ...@@ -5,14 +5,14 @@ import topi
import math import math
from topi.util import get_const_tuple from topi.util import get_const_tuple
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iw = ih iw = ih
kw = kh kw = kh
sw = sh sw = sh
ph, pw = padding ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A') A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode) pool_type=pool_type, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
B = topi.nn.relu(B) B = topi.nn.relu(B)
dtype = A.dtype dtype = A.dtype
...@@ -26,7 +26,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): ...@@ -26,7 +26,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
pad_np[np.ix_(*no_zero)] = a_np pad_np[np.ix_(*no_zero)] = a_np
...@@ -36,7 +36,12 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): ...@@ -36,7 +36,12 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
if pool_type == 'avg': if pool_type == 'avg':
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) if count_include_pad:
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
else:
pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1)
elif pool_type =='max': elif pool_type =='max':
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
...@@ -62,8 +67,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): ...@@ -62,8 +67,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
check_device(device) check_device(device)
def test_pool(): def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [3, 3], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [0, 0], 'avg', False, False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
......
...@@ -9,14 +9,14 @@ pool_code = { ...@@ -9,14 +9,14 @@ pool_code = {
"avg": 0, "avg": 0,
"max": 1 "max": 1
} }
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iw = ih iw = ih
kw = kh kw = kh
sw = sh sw = sh
ph, pw = padding ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A') A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.cpp.nn.pool(A, [kh, kw], [sh, sw], padding, B = topi.cpp.nn.pool(A, [kh, kw], [sh, sw], padding,
pool_code[pool_type], ceil_mode, "NCHW") pool_code[pool_type], ceil_mode, "NCHW", count_include_pad)
B = topi.cpp.nn.relu(B) B = topi.cpp.nn.relu(B)
dtype = A.dtype dtype = A.dtype
...@@ -40,7 +40,12 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): ...@@ -40,7 +40,12 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
if pool_type == 'avg': if pool_type == 'avg':
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) if count_include_pad:
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
else:
pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1)
elif pool_type =='max': elif pool_type =='max':
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
...@@ -68,8 +73,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): ...@@ -68,8 +73,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
check_device(device) check_device(device)
def test_pool(): def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [3, 3], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [0, 0], 'avg', False, False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment