Commit 88662130 by Tianqi Chen Committed by GitHub

[TOPI] Support ceil_mode in pooling (#593)

parent 2f2170f4
......@@ -44,7 +44,7 @@ def global_pool(data, pool_type):
raise ValueError("Pool type should be 'avg' or 'max'.")
def pool(data, kernel, stride, padding, pool_type):
def pool(data, kernel, stride, padding, pool_type, ceil_mode=False):
"""Perform pooling on the data
Parameters
......@@ -64,6 +64,9 @@ def pool(data, kernel, stride, padding, pool_type):
pool_type : str
Pool type, 'max' or 'avg'
ceil_mode : bool
Whether to use ceil when caculate output size.
Returns
-------
output : tvm.Tensor
......@@ -77,10 +80,18 @@ def pool(data, kernel, stride, padding, pool_type):
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
if ceil_mode:
# Additional padding to ensure we do ceil instead of floor when divide stride.
pad_down += stride_height -1
pad_right += stride_width - 1
out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
dheight = tvm.reduce_axis((0, kernel_height))
dwidth = tvm.reduce_axis((0, kernel_width))
......
......@@ -2,18 +2,30 @@
import numpy as np
import tvm
import topi
import math
from topi.util import get_const_tuple
def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
iw = ih
kw = kh
sw = sh
ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode)
B = topi.nn.relu(B)
dtype = A.dtype
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
......@@ -49,10 +61,12 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
check_device(device)
def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg')
verify_pool(1, 256, 31, 3, 3, [1, 1], 'avg')
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max')
verify_pool(1, 256, 31, 3, 3, [1, 1], 'max')
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
def verify_global_pool(n, c, h, w, pool_type):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment