Commit f277da76 by Yong Wu, committed by masahi

[Relay] add max_pool3d in relay and TF converter (#4551)

* [Relay] add max_pool3d in relay and TF converter

* fix comments
parent e6ff3f70
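For orientation, a minimal sketch (not part of the commit) of how the new op can be exercised from Relay once this patch is applied; the input shape and the llvm target are illustrative assumptions:

import numpy as np
import tvm
from tvm import relay

# NCDHW input: (batch, channels, depth, height, width)
x = relay.var("x", shape=(1, 3, 32, 32, 32), dtype="float32")
y = relay.nn.max_pool3d(x, pool_size=(2, 2, 2), strides=(2, 2, 2),
                        padding=(0, 0, 0))
func = relay.Function([x], y)

data = np.random.uniform(size=(1, 3, 32, 32, 32)).astype("float32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data)
print(out.asnumpy().shape)  # expected: (1, 3, 16, 16, 16)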
...@@ -71,7 +71,9 @@ This level enables typical convnet models.
   tvm.relay.nn.conv2d_transpose
   tvm.relay.nn.dense
   tvm.relay.nn.max_pool2d
   tvm.relay.nn.max_pool3d
   tvm.relay.nn.avg_pool2d
   tvm.relay.nn.avg_pool3d
   tvm.relay.nn.global_max_pool2d
   tvm.relay.nn.global_avg_pool2d
   tvm.relay.nn.upsampling
...@@ -246,7 +248,9 @@ Level 2 Definitions
.. autofunction:: tvm.relay.nn.conv2d_transpose
.. autofunction:: tvm.relay.nn.dense
.. autofunction:: tvm.relay.nn.max_pool2d
.. autofunction:: tvm.relay.nn.max_pool3d
.. autofunction:: tvm.relay.nn.avg_pool2d
.. autofunction:: tvm.relay.nn.avg_pool3d
.. autofunction:: tvm.relay.nn.global_max_pool2d
.. autofunction:: tvm.relay.nn.global_avg_pool2d
.. autofunction:: tvm.relay.nn.upsampling
......
...@@ -135,8 +135,10 @@ FUNC_OPS = {
    "nn.dense": op.nn.dense,
    "nn.bias_add": op.nn.bias_add,
    "nn.max_pool2d": op.nn.max_pool2d,
    "nn.max_pool3d": op.nn.max_pool3d,
    "nn.global_max_pool2d": op.nn.global_max_pool2d,
    "nn.avg_pool2d": op.nn.avg_pool2d,
    "nn.avg_pool3d": op.nn.avg_pool3d,
    "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
    "nn.softmax": op.nn.softmax,
    "reshape": op.reshape,
......
...@@ -122,6 +122,70 @@ def _elemwise(name):
        return get_relay_op(name)(*inputs)
    return _impl

def _pool3d(name):
    def _impl(inputs, attr, params):
        attr['data_format'] = attr['data_format'].decode("utf-8")
        flip_layout = False

        input_shape = attr['_input_shapes'][inputs[0]]

        if attr['data_format'] == 'NDHWC':
            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
            attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
        elif attr['data_format'] == 'NCDHW':
            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4])
            attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
        else:
            msg = 'Value {} of attribute "data_format" of operator Pooling ' \
                  'is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

        if attr['data_format'] == "NDHWC":
            input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
            inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
            attr['data_format'] = "NCDHW"
            attr['_input_shapes'][inputs[0]] = input_shape
            flip_layout = True

        attr['padding'] = attr['padding'].decode("utf-8")

        if attr['padding'] == 'VALID':
            attr['padding'] = [0, 0, 0, 0, 0, 0]
        elif attr['padding'] == 'SAME':
            stride_d, stride_h, stride_w = attr['strides']
            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
            if attr['data_format'] == 'NDHWC':
                in_d = input_shape[1]
                in_h = input_shape[2]
                in_w = input_shape[3]
            else:
                in_d = input_shape[2]
                in_h = input_shape[3]
                in_w = input_shape[4]
            pad_d = _get_pad_pair(in_d, kernel_d, stride_d)
            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
        else:
            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
                  'not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

        if name == "avg_pool":
            attr['count_include_pad'] = False
        attr['ceil_mode'] = False

        out = AttrCvt(
            op_name=name,
            transforms={
                'kernel_shape': 'pool_size',
                'data_format': 'layout'},
            ignores=['ksize'])(inputs, attr)

        if flip_layout:
            out = _op.transpose(out, axes=(0, 2, 3, 4, 1))

        return out
    return _impl
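The 'SAME' branch above relies on the frontend helper _get_pad_pair; a self-contained sketch of the per-axis arithmetic, reimplemented here for illustration under the assumption that it follows TensorFlow's SAME rule:

def get_pad_pair(in_size, kernel, stride):
    # Total pad is whatever makes ceil(in_size / stride) windows fit,
    # split between before and after (extra element goes after).
    if in_size % stride == 0:
        pad = max(kernel - stride, 0)
    else:
        pad = max(kernel - (in_size % stride), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

assert get_pad_pair(32, 2, 2) == [0, 0]  # kernel 2, stride 2: no padding needed
assert get_pad_pair(32, 3, 2) == [0, 1]  # kernel 3, stride 2: one padded element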
def _pooling(name):
    def _impl(inputs, attr, params):
...@@ -1409,6 +1473,7 @@ _convert_map = {
    'ArgMin' : _argx(_op.argmin, 'argmin'),
    'Assert' : _assert(),
    'AvgPool' : _pooling('avg_pool'),
    'AvgPool3D' : _pool3d('avg_pool3d'),
    'BatchMatMul' : _batch_matmul(),
    'BatchMatMulV2' : _batch_matmul(),
    'BatchNormWithGlobalNormalization' : _batch_norm(),
...@@ -1460,6 +1525,7 @@ _convert_map = {
    'MatMul' : _matmul(),
    'Max' : _reduce('max'),
    'MaxPool' : _pooling('max_pool'),
    'MaxPool3D' : _pool3d('max_pool3d'),
    'Maximum' : _elemwise('maximum'),
    'Mean' : _mean(),
    'Min' : _reduce('min'),
......
...@@ -396,6 +396,18 @@ def schedule_max_pool2d(attrs, outs, target):
reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool3d
@reg.register_schedule("nn.max_pool3d")
def schedule_max_pool3d(attrs, outs, target):
    """Schedule definition of max_pool3d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d
@reg.register_schedule("nn.avg_pool2d")
def schedule_avg_pool2d(attrs, outs, target):
...@@ -404,10 +416,21 @@ def schedule_avg_pool2d(attrs, outs, target):
    with target:
        return topi.generic.schedule_pool(outs, layout)

reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool3d
@reg.register_schedule("nn.avg_pool3d")
def schedule_avg_pool3d(attrs, outs, target):
    """Schedule definition of avg_pool3d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
......
...@@ -425,6 +425,51 @@ def max_pool2d(data,
    return _make.max_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode)
def max_pool3d(data,
               pool_size=(1, 1, 1),
               strides=(1, 1, 1),
               padding=(0, 0, 0),
               layout="NCDHW",
               ceil_mode=False):
    r"""3D maximum pooling operator.

    This operator takes data as input and does 3D max value calculation
    within a pool_size-sized window, strided according to strides.

    In the default case, where the data layout is `NCDHW`,
    a data Tensor with shape `(batch_size, channels, depth, height, width)`
    is pooled to produce an output Tensor.

    The ceil_mode flag selects between ceil and floor when computing the
    output shape. This operator accepts a data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of the pooling window.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.max_pool3d(data, pool_size, strides, padding,
                            layout, ceil_mode)
def avg_pool2d(data,
               pool_size=(1, 1),
               strides=(1, 1),
...@@ -482,6 +527,55 @@ def avg_pool2d(data,
    return _make.avg_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)
def avg_pool3d(data,
               pool_size=(1, 1, 1),
               strides=(1, 1, 1),
               padding=(0, 0, 0),
               layout="NCDHW",
               ceil_mode=False,
               count_include_pad=False):
    r"""3D average pooling operator.

    This operator takes data as input and does 3D average value calculation
    within a pool_size-sized window, strided according to strides.

    In the default case, where the data layout is `NCDHW`,
    a data Tensor with shape `(batch_size, channels, depth, height, width)`
    is pooled to produce an output Tensor.

    The ceil_mode flag selects between ceil and floor when computing the
    output shape, and count_include_pad indicates whether padded input values
    are included in the average. This operator accepts a data layout
    specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of the pooling window.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.avg_pool3d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)
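To make count_include_pad concrete, a small NumPy illustration (not from the commit) of one pooling window that overlaps a single padded zero:

import numpy as np

window = np.array([0.0, 2.0, 4.0])  # leading 0.0 is a padded element
np.mean(window)                     # count_include_pad=True:  (0+2+4)/3 = 2.0
window.sum() / 2                    # count_include_pad=False: (2+4)/2   = 3.0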
def max_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
......
...@@ -272,6 +272,16 @@ class AvgPool2DAttrs(Attrs):

@register_relay_attr_node
class MaxPool3DAttrs(Attrs):
    """Attributes used in max_pool3d operators"""


@register_relay_attr_node
class AvgPool3DAttrs(Attrs):
    """Attributes used in avg_pool3d operators"""


@register_relay_attr_node
class BitPackAttrs(Attrs):
    """Attributes used in bitpack operator"""
......
...@@ -237,7 +237,8 @@ def _test_pooling_iteration(input_shape, **kwargs):
def _test_pooling(input_shape, **kwargs):
    _test_pooling_iteration(input_shape, **kwargs)

    if is_gpu_available():
        if len(input_shape) == 4:
            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
            kwargs['data_format'] = 'NCHW'
            _test_pooling_iteration(input_shape, **kwargs)
...@@ -245,8 +246,49 @@ def _test_pooling(input_shape, **kwargs):
def test_forward_pooling():
    """ Pooling """
    # TensorFlow only supports NDHWC for max_pool3d on CPU
    for pool_type in ['AVG', 'MAX']:
        # NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
        _test_pooling(input_shape=[1, 3, 32, 32, 32],
                      window_shape=[2, 2, 2],
                      padding='VALID',
                      pooling_type=pool_type,
                      dilation_rate=[1, 1, 1],
                      strides=[2, 2, 2])

        _test_pooling(input_shape=[1, 3, 32, 32, 32],
                      window_shape=[1, 1, 1],
                      padding='SAME',
                      pooling_type=pool_type,
                      dilation_rate=[1, 1, 1],
                      strides=[1, 1, 1])

        _test_pooling(input_shape=[1, 3, 32, 32, 32],
                      window_shape=[2, 2, 2],
                      padding='SAME',
                      pooling_type=pool_type,
                      dilation_rate=[1, 1, 1],
                      strides=[2, 2, 2])

        # test cases for max_pool3d & avg_pool3d with layout NCDHW
        # TensorFlow pool3d doesn't support NCDHW on CPU
        if is_gpu_available():
            _test_pooling(input_shape=[1, 3, 32, 32, 32],
                          window_shape=[1, 1, 1],
                          padding='SAME',
                          pooling_type=pool_type,
                          dilation_rate=[1, 1, 1],
                          strides=[1, 1, 1],
                          data_format='NCDHW')

            _test_pooling(input_shape=[1, 3, 32, 32, 32],
                          window_shape=[2, 2, 2],
                          padding='VALID',
                          pooling_type=pool_type,
                          dilation_rate=[1, 1, 1],
                          strides=[2, 2, 2],
                          data_format='NCDHW')

        _test_pooling(input_shape=[2, 9, 10, 2],
                      window_shape=[1, 1],
                      padding='SAME',
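For reference, a sketch of what these cases drive in raw TensorFlow 1.x; the _test_pooling helper ultimately wraps tf.nn.pool, and the shape below is an illustrative assumption:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(1, 32, 32, 32, 3))  # NDHWC
out = tf.nn.pool(x, window_shape=[2, 2, 2], pooling_type='MAX',
                 padding='VALID', dilation_rate=[1, 1, 1],
                 strides=[2, 2, 2])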
...@@ -2855,7 +2897,6 @@ if __name__ == '__main__':
    test_forward_sin()
    test_forward_negative()
    test_forward_divide()
    test_forward_floordiv()
    test_forward_abs()
    test_forward_softplus()
    test_forward_sqrt()
...@@ -2916,5 +2957,3 @@ if __name__ == '__main__':
    test_forward_where()
    test_forward_matmul()
    test_forward_batch_matmul()
    # TODO missing tests: rank
...@@ -471,7 +471,7 @@ def _test_pool2d(opfunc, reffunc):
    y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
    func = relay.Function([x], y)
    data = np.random.uniform(size=dshape).astype(dtype)
    ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5))
    for target, ctx in ctx_list():
        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
        op_res1 = intrp1.evaluate(func)(data)
...@@ -532,6 +532,34 @@ def test_pool2d():
    _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
def test_pool3d():
    def _test_pool3d(opfunc):
        n, c, d, h, w = tvm.var("n"), 10, 5, 224, 224
        x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
        y = opfunc(x, pool_size=(1, 1, 1))
        assert "pool_size=" in y.astext()
        yy = run_infer_type(y)
        assert yy.checked_type == relay.TensorType((n, 10, 5, 224, 224), "float32")

        # test execution
        dtype = "float32"
        dshape = (1, 3, 32, 32, 32)
        x = relay.var("x", shape=dshape)
        pool_type = 'max' if 'max' in str(opfunc) else 'avg'
        y = opfunc(x, pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 0, 0, 0, 0, 0))
        func = relay.Function([x], y)
        data = np.random.uniform(size=dshape).astype(dtype)
        ref_res = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
                                                   (0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16),
                                                   pool_type, False)
        for target, ctx in ctx_list():
            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
            op_res1 = intrp1.evaluate(func)(data)
            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

    _test_pool3d(relay.nn.max_pool3d)
    _test_pool3d(relay.nn.avg_pool3d)
def test_avg_pool2d_no_count_pad():
    kh, kw = (4, 4)
    sh, sw = (2, 2)
...@@ -900,6 +928,7 @@ def test_bitpack_infer_type():
if __name__ == "__main__":
    test_pool2d()
    test_pool3d()
    test_avg_pool2d_no_count_pad()
    test_lrn()
    test_l2_normalize()
......
...@@ -43,5 +43,6 @@ from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool3d_python import pool3d_ncdhw_python
from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, unused-variable
"""max_pool3d and avg_pool3d in python"""
import math
import numpy as np
def pool3d_ncdhw_python(np_data, kernel,
                        strides, padding,
                        out_shape, pool_type,
                        count_include_pad=True,
                        ceil_mode=False, dtype="float32"):
    """Baseline implementation of max_pool3d and avg_pool3d; layout is NCDHW."""
    in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape
    k_d, k_h, k_w = kernel
    s_d, s_h, s_w = strides
    pf, pt, pl, pk, pb, pr = padding

    if ceil_mode:
        assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
        assert out_shape[3] == int(math.ceil(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
        assert out_shape[4] == int(math.ceil(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
    else:
        assert out_shape[2] == int(math.floor(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
        assert out_shape[3] == int(math.floor(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
        assert out_shape[4] == int(math.floor(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)

    pad_np = np.zeros(shape=(in_n, in_c,
                             in_d + pf + pk,
                             in_h + pt + pb,
                             in_w + pl + pr)).astype(dtype)

    no_zero = (range(in_n),
               range(in_c),
               (range(pf, in_d + pf)),
               (range(pt, in_h + pt)),
               (range(pl, in_w + pl)))
    pad_np[np.ix_(*no_zero)] = np_data
    ret_np = np.zeros(shape=out_shape).astype(dtype)

    if pool_type == 'avg':
        for k in range(out_shape[2]):
            for i in range(out_shape[3]):
                for j in range(out_shape[4]):
                    if count_include_pad:
                        ret_np[:, :, k, i, j] = \
                            np.mean(pad_np[:, :, k * s_d: k * s_d + k_d,
                                           i * s_h: i * s_h + k_h,
                                           j * s_w: j * s_w + k_w], axis=(2, 3, 4))
                    else:
                        pad_count = np.sum(pad_np[:, :,
                                           k * s_d: k * s_d + k_d,
                                           i * s_h: i * s_h + k_h,
                                           j * s_w: j * s_w + k_w] > 0, axis=(2, 3, 4))
                        ret_np[:, :, k, i, j] = np.sum(pad_np[:, :,
                                                       k * s_d: k * s_d + k_d,
                                                       i * s_h: i * s_h + k_h,
                                                       j * s_w: j * s_w + k_w],
                                                       axis=(2, 3, 4)) / np.maximum(pad_count, 1)
    elif pool_type == 'max':
        for k in range(out_shape[2]):
            for i in range(out_shape[3]):
                for j in range(out_shape[4]):
                    ret_np[:, :, k, i, j] = np.max(
                        pad_np[:, :, k * s_d: k * s_d + k_d,
                               i * s_h: i * s_h + k_h,
                               j * s_w: j * s_w + k_w], axis=(2, 3, 4))
    else:
        raise ValueError("pool type {} is not supported".format(pool_type))

    ret_np = np.maximum(ret_np, 0.0)
    return ret_np
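A quick check of the output-shape relation the assertions above encode (floor mode), using the dimensions from the Relay test:

import math

in_size, kernel, stride, pad_before, pad_after = 32, 2, 2, 0, 0
out_size = int(math.floor(float(in_size - kernel + pad_before + pad_after) / stride) + 1)
assert out_size == 16  # matches the expected (1, 3, 16, 16, 16) output shape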
...@@ -15,13 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Test code for pooling"""
import math
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from common import get_all_backend
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
...@@ -264,57 +263,25 @@ def test_adaptive_pool():
    verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
    verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")
def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                  ceil_mode, count_include_pad=True, layout='NCDHW'):
    id = iw = ih
    kd = kw = kh
    sd = sw = sh
    input_shape = (n, ic, id, ih, iw)
    kernel = [kd, kh, kw]
    stride = [sd, sh, sw]
    A = tvm.placeholder(input_shape, name='A')
    B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding,
                       pool_type=pool_type, ceil_mode=ceil_mode,
                       layout=layout, count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype
    output_shape = [int(i) for i in B.shape]

    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
    ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding,
                                              output_shape, pool_type,
                                              count_include_pad, ceil_mode)
    def check_device(device):
        ctx = tvm.context(device, 0)
...@@ -325,11 +292,11 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_includ
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B, layout)

        a = tvm.nd.array(input_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)
...@@ -353,7 +320,7 @@ def test_pool3d():
if __name__ == "__main__":
    test_pool()
    test_pool3d()
    test_pool_grad()
    test_global_pool()
    test_adaptive_pool()