Commit f5d158b7 by tqchen Committed by Tianqi Chen

[TOP] Support ceil_mode

parent 67820feb
......@@ -136,8 +136,8 @@ def compute_max_pool2d(attrs, inputs, _):
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max')
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='max', ceil_mode=ceil_mode)
@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
......@@ -158,8 +158,8 @@ def compute_avg_pool2d(attrs, inputs, _):
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg')
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='avg', ceil_mode=ceil_mode)
@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
......
......@@ -56,7 +56,8 @@ def test_grouped_conv2d():
def test_max_pool2d():
x = sym.Variable("x")
y = sym.max_pool2d(x, pool_size=(2,2), strides=(2,2), padding=(0,0), name="y")
y = sym.max_pool2d(x, pool_size=(2,2), strides=(2,2),
padding=(0,0), name="y", ceil_mode=True)
dtype = "float32"
dshape = (1, 3, 28, 28)
oshape = (1, 3, 14, 14)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment