Commit 8a98a2e7 by Josh Fromm Committed by masahi

[Relay/Topi][Op] 1D Pooling (#4663)

* Added 1D pooling to Topi

* Added 1D pooling relay op and tests.

* Added onnx parsing and tests for maxpool1d and averagepool1d

* formatting

* moved partial import.

* Fixed typo.
parent baae28b2
......@@ -481,6 +481,66 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
};
/*! \brief Attributes for 1D max pool operator */
struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom, right)");
TVM_ATTR_FIELD(layout).set_default("NCW")
.describe("Dimension ordering of data and weight. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
}
};
/*! \brief Attributes for 1D avg pool operator */
struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool1DAttrs, "relay.attrs.AvgPool1DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom, right)");
TVM_ATTR_FIELD(layout).set_default("NCW")
.describe("Dimension ordering of data and weight. Can be 'NCW', 'NHC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimension.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad).set_default(false)
.describe("When true, will include padding to compute the average");
}
};
/*! \brief Attributes for 3D max pool operator */
struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
Array<IndexExpr> pool_size;
......
......@@ -18,6 +18,7 @@
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs
from functools import partial
import numpy as np
import tvm
from ... import nd as _nd
......@@ -43,12 +44,15 @@ def get_numpy(tensor_proto):
def dimension_picker(prefix, surfix=''):
"""Check that dimensions are supported."""
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 1:
return prefix + '1d' + surfix
if len(kernel) == 2:
return prefix + '2d' + surfix
msg = 'Only 2D kernels are supported for operator {}.'
op_name = prefix + '2d'
msg = 'Only 1D and 2D kernels are supported for operator {}.'
op_name = prefix + '1d/2d'
raise tvm.error.OpAttributeInvalid(msg.format(op_name))
return _impl
......@@ -77,21 +81,27 @@ def get_pad_pair(input1d, kernel1d, stride1d):
return [pad_before, pad_after]
def onnx_storage_order2layout(storage_order):
def onnx_storage_order2layout(storage_order, dims=2):
"""converter of onnx storage order parameter to tvm storage order format"""
if storage_order not in (0, 1):
raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')
return 'NCHW' if storage_order == 0 else 'NHWC'
if dims == 1:
return 'NCW' if storage_order == 0 else 'NWC'
elif dims == 2:
return 'NCHW' if storage_order == 0 else 'NHWC'
else:
msg = "Only 1d and 2d layouts are currently supported"
raise tvm.error.OpAttributeInvalid(msg.format(op_name))
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
if len(attrs['kernel_shape']) == 2 or len(attrs['kernel_shape']) == 1:
return True
return False
return _dim_check, "Only 2d kernel supported."
return _dim_check, "Only 1d and 2d kernel supported."
class OnnxOpConverter(object):
......@@ -394,17 +404,33 @@ class MaxPool(Pool):
@classmethod
def _impl_v10(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
'ceil_mode': 'ceil_mode'
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
custom_check=dimension_constraint())(inputs, attr, params)
input_shape = infer_shape(inputs[0])
# 1D Convolution
if len(input_shape) == 3:
return AttrCvt(
op_name="max_pool1d",
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0)),
'storage_order': ('layout', 'NCW', partial(onnx_storage_order2layout, dims=1)),
'ceil_mode': 'ceil_mode'
},
ignores=['dilations', 'auto_pad'])(inputs, attr, params)
#2D Convolution
if len(input_shape) == 4:
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
'ceil_mode': 'ceil_mode'
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
custom_check=dimension_constraint())(inputs, attr, params)
raise tvm.error.OpAttributeInvalid("Only 1D and 2D maxpooling are currently supported.")
class Mul(Elemwise):
""" Operator converter for Multiply.
......
......@@ -428,6 +428,18 @@ reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
# max_pool1d
@reg.register_schedule("nn.max_pool1d")
def schedule_max_pool1d(attrs, outs, target):
"""Schedule definition of max_pool1d"""
layout = attrs.layout
with target:
return topi.generic.schedule_pool(outs, layout)
reg.register_pattern("nn.max_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d
@reg.register_schedule("nn.max_pool2d")
def schedule_max_pool2d(attrs, outs, target):
......@@ -452,6 +464,18 @@ def schedule_max_pool3d(attrs, outs, target):
reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool1d
@reg.register_schedule("nn.avg_pool1d")
def schedule_avg_pool1d(attrs, outs, target):
"""Schedule definition of avg_pool1d"""
layout = attrs.layout
with target:
return topi.generic.schedule_pool(outs, layout)
reg.register_pattern("nn.avg_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d
@reg.register_schedule("nn.avg_pool2d")
def schedule_avg_pool2d(attrs, outs, target):
......
......@@ -373,6 +373,55 @@ def log_softmax(data, axis=-1):
return _make.log_softmax(data, axis)
def max_pool1d(data,
pool_size=(1,),
strides=(1,),
padding=(0,),
layout="NCW",
ceil_mode=False):
r"""1D maximum pooling operator.
This operator takes data as input and does 1D max value calculation
with in pool_size sized window by striding defined by stride.
In the default case, where the data_layout is `NCW`
a data Tensor with shape `(batch_size, channels, width)`,
to produce an output Tensor.
The ceil_mode is used to take ceil or floor while computing out shape.
count_include_pad indicates including or excluding padded input values in computation.
This operator accepts data layout specification.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
strides : int or tuple of int, optional
The strides of pooling.
padding : int or tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
if isinstance(strides, int):
strides = (strides,)
if isinstance(padding, int):
padding = (padding,)
return _make.max_pool1d(data, pool_size, strides, padding,
layout, ceil_mode)
def max_pool2d(data,
pool_size=(1, 1),
strides=(1, 1),
......@@ -470,6 +519,60 @@ def max_pool3d(data,
return _make.max_pool3d(data, pool_size, strides, padding,
layout, ceil_mode)
def avg_pool1d(data,
pool_size=(1,),
strides=(1,),
padding=(0,),
layout="NCW",
ceil_mode=False,
count_include_pad=False):
r"""1D average pooling operator.
This operator takes data as input and does 1D average value calculation
with in pool_size sized window by striding defined by stride
In the default case, where the data_layout is `NCW`
a data Tensor with shape `(batch_size, channels, width)`,
to produce an output Tensor.
The ceil_mode is used to take ceil or floor while computing out shape.
count_include_pad indicates including or excluding padded input values in computation.
This operator accepts data layout specification.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
strides : int or tuple of int, optional
The strides of pooling.
padding : int or tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
count_include_pad : bool, optional
To include padding to compute the average.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
if isinstance(strides, int):
strides = (strides,)
if isinstance(padding, int):
padding = (padding,)
return _make.avg_pool1d(data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)
def avg_pool2d(data,
pool_size=(1, 1),
strides=(1, 1),
......@@ -541,7 +644,7 @@ def avg_pool3d(data,
In the default case, where the data_layout is `NCDHW`
a data Tensor with shape `(batch_size, channels, depthm height, width)`,
a data Tensor with shape `(batch_size, channels, depth, height, width)`,
to produce an output Tensor.
The ceil_mode is used to take ceil or floor while computing out shape.
......
......@@ -276,6 +276,16 @@ class AvgPool2DAttrs(Attrs):
@register_relay_attr_node
class MaxPool1DAttrs(Attrs):
"""Attributes used in max_pool1d operators"""
@register_relay_attr_node
class AvgPool1DAttrs(Attrs):
"""Attributes used in avg_pool1d operators"""
@register_relay_attr_node
class MaxPool3DAttrs(Attrs):
"""Attributes used in max_pool3d operators"""
......
......@@ -738,6 +738,184 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad")
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
// relay.nn.max_pool1d & relay.nn.avg_pool1d
TVM_REGISTER_NODE_TYPE(MaxPool1DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool1DAttrs);
template <typename AttrType>
bool Pool1DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const auto dshape = data->shape;
CHECK_GE(dshape.size(), 1U)
<< "Pool1D only support input >= 1-D: input must have width";
const auto param = attrs.as<AttrType>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout
<< ". Pool1D layout must have W, which cannot be split";
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
IndexExpr pad_w;
if (param->padding.size() == 1) {
pad_w = param->padding[0] * 2;
} else if (param->padding.size() == 2) {
// (left, right)
pad_w = param->padding[0] + param->padding[1];
} else {
return false;
}
std::vector<IndexExpr> oshape;
for (const auto& e : dshape) {
oshape.push_back(e);
}
if (dshape[widx].as<ir::AnyNode>()) {
oshape[widx] = dshape[widx];
} else {
if (param->ceil_mode) {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] +
param->strides[0] - 1) / param->strides[0]) + 1;
} else {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1;
}
}
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
template<typename AttrType, topi::nn::PoolType mode>
Array<Tensor> Pool1DCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCW("NCW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
auto pool_size = param->pool_size;
auto strides = param->strides;
auto padding = param->padding;
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCW).defined())
<< "max_pool1d currently only supports layouts that are convertible from NCW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool1d does not support input split on width";
CHECK(inputs[0].ndim() == 3U ||
inputs[0].ndim() == 4U ||
inputs[0].ndim() == 5U)
<< "Pool1D only support 3-D input (e.g., NCW)"
<< " or 4-D input (e.g. NCWc on for vector instructions)"
<< " or 5-D input (e.g. NCWnc for tensor accelerators)";
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
return Array<Tensor>{
topi::nn::pool1d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
return Array<Tensor>{
topi::nn::pool1d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d")
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
return MakeMaxPool<MaxPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
"nn.max_pool1d");
});
RELAY_REGISTER_OP("nn.max_pool1d")
.describe(R"code(Max pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
(batch_size, channels, width) if `layout` is `NCW`.
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
(batch_size, channels, , out_width) if `layout` is `NCW`.
out_width is calculated as::
out_width = floor((width+padding[0]+padding[1]-pool_size[0])/strides[0])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int: padding width in the order of (left, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
)code" TVM_ADD_FILELINE)
.set_attrs_type<MaxPool1DAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool1D", Pool1DRel<MaxPool1DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool1DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<MaxPool1DAttrs, topi::nn::kMaxPool>);
// AvgPool1D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d")
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
return MakeAvgPool<AvgPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
count_include_pad, "nn.avg_pool1d");
});
RELAY_REGISTER_OP("nn.avg_pool1d")
.describe(R"code(
Average pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
(batch_size, channels, width) if `layout` is `NCW`.
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
(batch_size, channels, out_width) if `layout` is `NCW`.
out_width is calculated as::
out_width = floor((width+padding[0]+padding[1]-pool_size[0])/strides[0])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int: padding width in the order of (left, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
)code" TVM_ADD_FILELINE)
.set_attrs_type<AvgPool1DAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("AvgPool1D", Pool1DRel<AvgPool1DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool1DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<AvgPool1DAttrs, topi::nn::kAvgPool>);
// relay.nn.max_pool3d & relay.nn.avg_pool3d
TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs);
......
......@@ -1827,6 +1827,100 @@ def test_unsqueeze_constant():
relay.frontend.from_onnx(onnx_model, {'0': input_size})
def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode):
x_np = np.random.uniform(size=x_shape).astype('float32')
if mode == 'max':
node_type = "MaxPool"
elif mode == 'average':
node_type = "AveragePool"
else:
raise ValueError("Pool method {} is not supported.".format(mode))
pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
pads=pads,
strides=strides)
graph = helper.make_graph([pool_node],
"pooling_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(x_shape))],
outputs=[helper.make_tensor_value_info("y",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='pooling_test')
for target, ctx in ctx_list():
onnx_out = get_onnxruntime_output(model, x_np, 'float32')
tvm_out = get_tvm_output(
model, [x_np], target, ctx, out_shape)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_pooling():
# MaxPool1D
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[1],
pads=[1, 1],
out_shape=[1, 1, 32],
mode='max')
# MaxPool2D
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
mode='max')
#AveragePool1D
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[1],
pads=[1, 1],
out_shape=[1, 1, 32],
mode='average')
#AveragePool2D
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
mode='average')
# MaxPool1D with stride
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=[1, 1],
out_shape=[1, 1, 16],
mode='max')
# MaxPool2D with stride
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 16, 16],
mode='max')
#AveragePool1D with stride
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=[1, 1],
out_shape=[1, 1, 16],
mode='average')
#AveragePool2D with stride
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 16, 16],
mode='average')
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -1884,3 +1978,4 @@ if __name__ == '__main__':
test_conv()
test_convtranspose()
test_unsqueeze_constant()
test_pooling()
......@@ -642,6 +642,34 @@ def test_pool2d():
_test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
def test_pool1d():
def _test_pool1d(opfunc):
n, c, w = tvm.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "float32"))
y = opfunc(x, pool_size=(1,))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
# test execution
dtype = "float32"
dshape = (1, 3, 32)
x = relay.var("x", shape=dshape)
pool_type = 'max' if 'max' in str(opfunc) else 'avg'
y = opfunc(x, pool_size=(2,), strides=(2,), padding=(0, 0))
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = topi.testing.pool1d_ncw_python(data, (2,), (2,),
(0, 0), (1, 3, 16), pool_type, False)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_pool1d(relay.nn.max_pool1d)
_test_pool1d(relay.nn.avg_pool1d)
def test_pool3d():
def _test_pool3d(opfunc):
......@@ -1081,6 +1109,7 @@ def test_bitpack_infer_type():
if __name__ == "__main__":
test_pool1d()
test_pool2d()
test_pool3d()
test_avg_pool2d_no_count_pad()
......
......@@ -372,6 +372,16 @@ inline bool find_height_width(const std::string& layout,
return false;
}
inline bool find_width(const std::string& layout,
int* width_axis) {
int dummy;
CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
if (*width_axis != -1) {
return true;
}
return false;
}
/*!
* \brief Perform pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
......@@ -745,6 +755,51 @@ inline Tensor pool_impl_nd(const Tensor& x,
}
/*!
* \brief Perform pooling on the width dimension of data.
* Width axis is determined by the layout string
* in which 'W' means width.
* Width dimension cannot be split.
* For example, NCW, NCW16c, etc. are valid for pool,
* while NCW16w is not.
* See \a layout for more information of the layout string convention.
* \param x The input tensor.
* \param kernel_size Vector of three ints: {kernel_width}
* \param stride_size Vector of three ints: {stride_width}
* \param padding_size Vector of six ints: {head_pad_width, tail_pad_width}
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
* \param layout The input layout. Pooling supports any layout as long as 'W' appears.
* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
* where upper case indicates a dimension and
* the corresponding lower case (with factor size) indicates the split dimension.
* For example, NCW16c can describe a 4-D tensor of
* [batch_size, channel, width, channel_block].
* (in which factor size `16` will not be used in pooling but for other operators,
* it can be used to decide the output shape).
* Since pooling does not care about the factor size of dimensions
* other than `W`, one can pass `NCWc` as well.
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
*
* \return The output tensor in the same layout
*/
inline Tensor pool1d(const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCW",
bool count_include_pad = true) {
int width_axis = -1;
CHECK(find_width(layout, &width_axis))
<< "Unsupported layout " << layout;
std::vector<int> axis = {width_axis};
return pool_impl_nd(x, kernel_size, stride_size, padding_size,
pool_type, ceil_mode, axis, count_include_pad);
}
/*!
* \brief Perform pooling on depth, height and width dimension of data.
* It decides the depth, height and width dimension according to the layout string,
* in which 'D', 'W' and 'H' means depth, width and height respectively.
......
......@@ -218,6 +218,67 @@ def adaptive_pool(data,
return cpp.nn.adaptive_pool(data, output_size, POOL_TYPE_CODE[pool_type], layout)
def pool1d(data,
kernel,
stride,
padding,
pool_type,
ceil_mode=False,
layout="NCW",
count_include_pad=True):
"""Perform pooling on width dimension of data.
Width axis is determined according to the layout string.
in which 'w' means width.
Width dimension cannot be split.
For example, NCW, NCW16c, etc. are valid for pool,
while NCW16w is not.
See parameter `layout` for more information of the layout string convention.
Parameters
----------
data : tvm.Tensor
n-D with shape of layout
kernel : list/tuple of one int or int
Kernel size, [kernel_width]
stride : list/tuple of one int or int
Stride size, [stride_width]
padding : list/tuple of two ints
Pad size, [pad_left, pad_right]
pool_type : str
Pool type, 'max' or 'avg'
ceil_mode : bool
Whether to use ceil when calculating output size.
layout: string
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and numbers,
where upper case indicates a dimension and
the corresponding lower case with factor size indicates the split dimension.
For example, NCW16c can describe a 4-D tensor of
[batch_size, channel, width, channel_block],
in which channel_block=16 is a split of dimension channel.
count_include_pad: bool
Whether include padding in the calculation when pool_type is 'avg'
Returns
-------
output : tvm.Tensor
n-D in the same layout
"""
if isinstance(kernel, int):
kernel = [kernel, ]
if isinstance(stride, int):
stride = [stride, ]
return cpp.nn.pool1d(data, kernel, stride, padding,
POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
def pool3d(data,
kernel,
stride,
......
......@@ -45,6 +45,7 @@ from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool1d_python import pool1d_ncw_python
from .pool3d_python import pool3d_ncdhw_python
from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, unused-variable
"""max_pool1d and avg_pool1d in python"""
import math
import numpy as np
def pool1d_ncw_python(np_data, kernel,
strides, padding,
out_shape, pool_type,
count_include_pad=True,
ceil_mode=False, dtype="float32"):
"""Baseline for max_pool1d and avg_pool1d, default layout is NCW"""
in_n, in_c, in_w = in_shape = np_data.shape
k_w = kernel[0]
s_w = strides[0]
pl, pr = padding
if ceil_mode:
assert out_shape[2] == int(
math.ceil(float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
else:
assert out_shape[2] == int(math.floor(
float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
pad_np = np.zeros(shape=(in_n, in_c, in_w + pl + pr)).astype(dtype)
no_zero = (range(in_n), range(in_c), range(pl, in_w + pl))
pad_np[np.ix_(*no_zero)] = np_data
ret_np = np.zeros(shape=out_shape).astype(dtype)
if pool_type == 'avg':
for k in range(out_shape[2]):
if count_include_pad:
ret_np[:, :, k] = np.mean(
pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,))
else:
pad_count = np.sum(
pad_np[:, :, k * s_w: k * s_w + k_w] > 0, axis=(2,))
ret_np[:, :, k] = np.sum(
pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,)) / np.maximum(pad_count, 1)
elif pool_type == 'max':
for k in range(out_shape[2]):
ret_np[:, :, k] = np.max(pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,))
else:
raise ValueError("Pool type {} is not supported".format(pool_type))
ret_np = np.maximum(ret_np, 0.0)
return ret_np
......@@ -537,6 +537,13 @@ TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool1d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool1d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool3d(args[0], args[1], args[2], args[3],
......
......@@ -318,8 +318,61 @@ def test_pool3d():
verify_pool3d(1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], 'max', True)
def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
ceil_mode, count_include_pad=True, layout='NCW'):
input_shape = (n, ic, iw)
kernel = [kw]
stride = [sw]
A = tvm.placeholder(input_shape, name='A')
B = topi.nn.pool1d(A, kernel=kernel, stride=stride, padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout=layout, count_include_pad=count_include_pad)
B = topi.nn.relu(B)
dtype = A.dtype
output_shape = [int(i) for i in B.shape]
input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
ref_np = topi.testing.pool1d_ncw_python(input_np, kernel, stride, padding,
output_shape, pool_type, count_include_pad, ceil_mode)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool(B, layout)
a = tvm.nd.array(input_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)
for device in get_all_backend():
check_device(device)
def test_pool1d():
verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
verify_pool1d(1, 256, 31, 4, 4, [3, 3], 'avg', False, False)
verify_pool1d(1, 256, 31, 4, 4, [0, 0], 'avg', False, False)
verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool1d(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool1d(1, 256, 31, 3, 3, [2, 1], 'max', True)
verify_pool1d(1, 256, 31, 3, 3, [2, 5], 'avg', False, True)
verify_pool1d(1, 256, 32, 2, 2, [0, 3], 'avg', False, False)
verify_pool1d(1, 256, 31, 3, 3, [1, 4], 'max', False)
verify_pool1d(1, 256, 31, 3, 3, [3, 0], 'max', True)
if __name__ == "__main__":
test_pool()
test_pool1d()
test_pool3d()
test_pool_grad()
test_global_pool()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment