Commit 41959ed2 by optima2005 Committed by masahi

[TOPI] implement pool3d op (#4478)

* [TOPI] implement pool3d op

* use PoolInferCorrectLayout for both 2d and 3d pooling

* unify MakeMaxPool and MakeAvgPool
parent 8c2d4f65
@@ -406,6 +406,68 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
};
/*! \brief Attributes for 3D max pool operator */
struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the pooling.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded. "
"Padding supports both symmetric and asymmetric modes: "
"one int : same padding used on all sides; "
"three ints : back, bottom and right use the same padding as front, top and left; "
"six ints : padding widths in the order of (front, top, left, back, bottom, right).");
TVM_ATTR_FIELD(layout).set_default("NCDHW")
.describe("Dimension ordering of the input data. Can be 'NCDHW', 'NDHWC', etc. "
"'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
"dimensions respectively. Pooling is applied on the 'D', 'H' and "
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
}
};
/*! \brief Attributes for 3D avg pool operator */
struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the pooling.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded. "
"Padding supports both symmetric and asymmetric modes: "
"one int : same padding used on all sides; "
"three ints : back, bottom and right use the same padding as front, top and left; "
"six ints : padding widths in the order of (front, top, left, back, bottom, right).");
TVM_ATTR_FIELD(layout).set_default("NCDHW")
.describe("Dimension ordering of the input data. Can be 'NCDHW', 'NDHWC', etc. "
"'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
"dimensions respectively. Pooling is applied on the 'D', 'H' and "
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad).set_default(false)
.describe("When true, padding is included when computing the average.");
}
};
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
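The padding attribute of the 3D pool attrs above accepts one, three, or six integers. As a loose illustration of how each accepted form maps onto the six-value (front, top, left, back, bottom, right) order used by the kernels below (the helper name is hypothetical and this is not the library's own normalization code):

def expand_pool3d_padding(padding):
    # Hypothetical helper, for illustration only: expand 1/3/6-int padding
    # to the (front, top, left, back, bottom, right) order described above.
    padding = list(padding)
    if len(padding) == 1:
        return padding * 6            # same padding on all six sides
    if len(padding) == 3:
        return padding + padding      # back/bottom/right reuse front/top/left
    if len(padding) == 6:
        return padding
    raise ValueError("padding must have 1, 3, or 6 elements")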
@@ -24,6 +24,7 @@
#ifndef TOPI_NN_POOLING_H_
#define TOPI_NN_POOLING_H_
#include <algorithm>
#include <string>
#include <vector>
@@ -43,6 +44,7 @@ enum PoolType : int {
kMaxPool,
};
/*!
* \brief Perform pooling on height and width dimension of data.
*
@@ -325,31 +327,46 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
}
}
inline bool find_depth_height_width(const std::string& layout,
int* depth_axis,
int* height_axis,
int* width_axis) {
*depth_axis = -1, *height_axis = -1, *width_axis = -1;
int curr_idx = 0;
for (size_t i = 0; i < layout.size(); ++i) {
if ((layout[i] >= 'A' && layout[i] <= 'Z') ||
(layout[i] >= 'a' && layout[i] <= 'z')) {
if (layout[i] == 'D') {
if (*depth_axis != -1) return false;
*depth_axis = curr_idx;
} else if (layout[i] == 'H') {
if (*height_axis != -1) return false;
*height_axis = curr_idx;
} else if (layout[i] == 'W') {
if (*width_axis != -1) return false;
*width_axis = curr_idx;
} else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
// do not support split on depth, height or width, e.g., NCDHW16w
return false;
}
++curr_idx;
}
}
if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false;
return true;
}
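// For example, find_depth_height_width("NCDHW16c", ...) sets depth_axis = 2,
// height_axis = 3, width_axis = 4 and returns true. For a 2-D layout such as
// "NCHW" it returns false (no 'D' axis) while still filling height_axis and
// width_axis, which is what the find_height_width wrapper below relies on.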
inline bool find_height_width(const std::string& layout,
int* height_axis,
int* width_axis) {
int dummy;
CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
if (*height_axis != -1 && *width_axis != -1) {
return true;
}
return false;
}
/*!
* \brief Perform pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
@@ -591,6 +608,182 @@ inline Tensor global_pool(const Tensor& x,
return adaptive_pool(x, Array<Expr>{1, 1}, pool_type, layout);
}
/*!
* \brief Perform pooling on N dimensions of data.
*
* \param x The input tensor
* \param kernel_size Vector of N ints
* \param stride_size Vector of N ints
* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
* \param axis Vector of indices for the N dimensions
* \param count_include_pad Whether to include padding in the calculation
*
* \return The output tensor in same layout order
*/
inline Tensor pool_impl_nd(const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::vector<int>& axis,
bool count_include_pad) {
int k_size = kernel_size.size();
int x_size = x->shape.size();
CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have the same number of "
"elements as the kernel";
CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have twice as many "
"elements as the kernel";
CHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as the kernel";
Array<IterVar> daxis;
std::vector<Expr> kernel(k_size);
std::vector<Expr> stride(k_size);
std::vector<Expr> pad_head(k_size);
std::vector<Expr> pad_tail(k_size);
Array<Expr> pad_before(std::vector<Expr>(x_size, 0));
Array<Expr> pad_after(std::vector<Expr>(x_size, 0));
Array<Expr> out_shape = x->shape;
bool do_pad = false;
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
kernel[i] = cast(Int(32), kernel_size[i]);
stride[i] = cast(Int(32), stride_size[i]);
pad_head[i] = cast(Int(32), padding_size[i]);
pad_tail[i] = cast(Int(32), padding_size[i + k_size]);
const int64_t *padding0 = as_const_int(pad_head[i]);
const int64_t *padding1 = as_const_int(pad_tail[i]);
do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1));
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_tail[i] += stride[i] - 1;
}
daxis.push_back(tvm::reduce_axis(Range(0, kernel[i])));
pad_before.Set(ii, pad_head[i]);
pad_after.Set(ii, pad_tail[i]);
auto out_dim = tvm::ir::Simplify(
indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
out_shape.Set(ii, out_dim);
}
if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x;
return tvm::compute(out_shape, [&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
indices.Set(ii, output[ii] * stride[i] + daxis[i]);
}
return tvm::max(temp(indices), daxis);
}, "tensor", "pool_max");
} else if (pool_type == kAvgPool) {
// Pad the inputs
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
// TVM compute for summing the pooling window.
auto pool_sum = tvm::compute(out_shape,
[&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
indices.Set(ii, output[ii] * stride[i] + daxis[i]);
}
return tvm::sum(temp(indices), daxis);
}, "tensor", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
return tvm::compute(out_shape,
[&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
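// Divisor selection: with count_include_pad the window sum is divided by the
// full kernel volume; otherwise only the in-bounds portion of the window
// (clamped to at least 1) is used as the divisor.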
if (count_include_pad) {
auto kernel_size = make_const(Int(32), 1);
for (int i = 0; i < k_size; i++) {
kernel_size *= kernel[i];
}
return div(pool_sum(indices), kernel_size);
} else {
std::vector<Expr> start(k_size);
std::vector<Expr> end(k_size);
auto kernel_size = make_const(Int(32), 1);
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
start[i] = output[ii] * stride[i] - pad_head[i];
end[i] = ir::Min::make(start[i] + kernel[i], x->shape[ii]);
start[i] = ir::Max::make(start[i], make_const(Int(32), 0));
kernel_size *= (end[i] - start[i]);
}
Expr divide_factor = ir::Max::make(kernel_size, make_const(Int(32), 1));
return div(pool_sum(indices), divide_factor);
}
}, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}
/*!
* \brief Perform pooling on the depth, height and width dimensions of data.
* It decides the depth, height and width dimensions according to the layout string,
* in which 'D', 'W' and 'H' mean depth, width and height respectively.
* Depth, width and height dimensions cannot be split.
* For example, NCDHW, NCDHW16c, etc. are valid for pool,
* while NCDHW16d, NCDHW16w or NCDHW16h are not.
* See \a layout for more information of the layout string convention.
* \param x The input tensor.
* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
* tail_pad_depth, tail_pad_height, tail_pad_width}
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
* where upper case indicates a dimension and
* the corresponding lower case (with factor size) indicates the split dimension.
* For example, NCDHW16c can describe a 6-D tensor of
* [batch_size, channel, depth, height, width, channel_block].
* (in which factor size `16` will not be used in pooling but for other operators,
* it can be used to decide the output shape).
* Since pooling does not care about the factor size of dimensions
* other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
* \param count_include_pad Whether to include padding in the calculation when pool_type is 'avg'
*
*
* \return The output tensor in the same layout
*/
inline Tensor pool3d(const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCDHW",
bool count_include_pad = true) {
int depth_axis = -1, height_axis = -1, width_axis = -1;
CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
<< "Unsupported layout " << layout;
std::vector<int> axis = {depth_axis, height_axis, width_axis};
return pool_impl_nd(x, kernel_size, stride_size, padding_size,
pool_type, ceil_mode, axis, count_include_pad);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_POOLING_H_
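For reference, the per-axis output extent computed by pool_impl_nd above follows the usual pooling formula; a minimal Python sketch of the same arithmetic (assuming constant integer shapes), including the extra stride-1 tail padding applied in ceil_mode:

def pool_out_dim(in_dim, kernel, stride, pad_head, pad_tail, ceil_mode=False):
    # Mirrors the simplified indexdiv expression in pool_impl_nd.
    if ceil_mode:
        pad_tail += stride - 1   # ensures ceil instead of floor division
    return (in_dim + pad_head + pad_tail - kernel) // stride + 1

# e.g. depth 31, kernel 3, stride 3, pads (1, 2):
# floor mode -> (31 + 1 + 2 - 3) // 3 + 1 = 11; ceil mode -> 12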
@@ -216,3 +216,60 @@ def adaptive_pool(data,
n-D in the same layout
"""
return cpp.nn.adaptive_pool(data, output_size, POOL_TYPE_CODE[pool_type], layout)
def pool3d(data,
kernel,
stride,
padding,
pool_type,
ceil_mode=False,
layout="NCDHW",
count_include_pad=True):
"""Perform pooling on depth, height and width dimension of data.
It decides the depth, height and width dimensions according to the layout string,
in which 'D', 'W' and 'H' mean depth, width and height respectively.
Depth, width and height dimensions cannot be split.
For example, NCDHW, NCDHW16c, etc. are valid for pool,
while NCDHW16d, NCDHW16w, NCDHW16h are not.
See parameter `layout` for more information on the layout string convention.
Parameters
----------
data : tvm.Tensor
n-D with shape of layout
kernel : list/tuple of three ints
Kernel size, [kernel_depth, kernel_height, kernel_width]
stride : list/tuple of three ints
Stride size, [stride_depth, stride_height, stride_width]
padding : list/tuple of six ints
Pad size, [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right]
pool_type : str
Pool type, 'max' or 'avg'
ceil_mode : bool
Whether to use ceil when calculating output size.
layout: string
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and numbers,
where upper case indicates a dimension and
the corresponding lower case with factor size indicates the split dimension.
For example, NCDHW16c can describe a 6-D tensor of
[batch_size, channel, depth, height, width, channel_block],
in which channel_block=16 is a split of dimension channel.
count_include_pad: bool
Whether to include padding in the calculation when pool_type is 'avg'
Returns
-------
output : tvm.Tensor
n-D in the same layout
"""
return cpp.nn.pool3d(data, kernel, stride, padding,
POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
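A short usage sketch of the new Python wrapper, mirroring the test added below (target creation and scheduling are elided; the shapes are only illustrative):

import tvm
import topi

A = tvm.placeholder((1, 16, 32, 32, 32), name="A")   # NCDHW input
B = topi.nn.pool3d(A, kernel=[2, 2, 2], stride=[2, 2, 2],
                   padding=[0, 0, 0, 0, 0, 0],
                   pool_type="max", ceil_mode=False,
                   layout="NCDHW", count_include_pad=True)
# B has shape (1, 16, 16, 16, 16); schedule it with e.g.
# topi.generic.schedule_pool(B, "NCDHW") inside a tvm.target.create(...) scope.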
@@ -535,6 +535,13 @@ TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool3d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
/* Ops from nn/softmax.h */
TVM_REGISTER_GLOBAL("topi.nn.softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
@@ -824,7 +831,8 @@ inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder build
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(
topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
@@ -264,9 +264,96 @@ def test_adaptive_pool():
verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")
def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iz = iw = ih
kz = kw = kh
sz = sw = sh
pf, pt, pl, pk, pb, pr = padding
layout = "NCDHW"
A = tvm.placeholder((n, ic, iz, ih, iw), name='A')
B = topi.nn.pool3d(A, kernel=[kz, kh, kw], stride=[sz, sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout="NCDHW", count_include_pad=count_include_pad)
B = topi.nn.relu(B)
dtype = A.dtype
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kz + pf + pk) / sz) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kh + pt + pb) / sh) + 1)
assert bshape[4] == int(math.ceil(float(ashape[4] - kw + pl + pr) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kz + pf + pk) / sz) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kh + pt + pb) / sh) + 1)
assert bshape[4] == int(math.floor(float(ashape[4] - kw + pl + pr) / sw) + 1)
a_np = np.random.uniform(low=0.001, size=(n, ic, iz, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, iz+pf+pk, ih+pt+pb, iw+pl+pr)).astype(dtype)
no_zero = (range(n), range(ic), (range(pf, iz+pf)), (range(pt, ih+pt)), (range(pl, iw+pl)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oz, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oz, oh, ow)).astype(dtype)
if pool_type == 'avg':
for k in range(oz):
for i in range(oh):
for j in range(ow):
if count_include_pad:
b_np[:,:,k,i,j] = np.mean( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
else:
pad_count = np.sum( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3,4))
b_np[:,:,k,i,j] = np.sum(pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], \
axis=(2,3, 4)) / np.maximum(pad_count, 1)
elif pool_type == 'max':
for k in range(oz):
for i in range(oh):
for j in range(ow):
b_np[:,:,k,i,j] = np.max( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
b_np = np.maximum(b_np, 0.0)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool(B, layout)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in get_all_backend():
check_device(device)
def test_pool3d():
verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True)
verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True)
verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False)
verify_pool3d(1, 256, 31, 4, 4, [3, 3, 3, 3, 3, 3], 'avg', False, False)
verify_pool3d(1, 256, 31, 4, 4, [0, 0, 0, 0, 0, 0], 'avg', False, False)
verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'max', False)
verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', False)
verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', True)
verify_pool3d(1, 256, 31, 3, 3, [2, 1, 0, 5, 4, 3], 'avg', False, True)
verify_pool3d(1, 256, 32, 2, 2, [0, 5, 4, 3, 2, 1], 'avg', False, False)
verify_pool3d(1, 256, 31, 3, 3, [1, 0, 5, 4, 3, 2], 'max', False)
verify_pool3d(1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], 'max', True)
if __name__ == "__main__":
test_pool()
test_pool_grad()
test_global_pool()
test_adaptive_pool()
test_pool3d()