Commit 8d83da6b by 戚海涛 Committed by Tianqi Chen

fix topi.nn.global_pool layout="NHWC" (#4656)

* Update topi.cc

fix topi.nn.global_pool layout="NHWC"

* add topi.nn.global_pool layout=NHWC test
parent 07b45d95
...@@ -527,7 +527,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pool_grad") ...@@ -527,7 +527,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
TVM_REGISTER_GLOBAL("topi.nn.global_pool") TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0], *rv = nn::global_pool(args[0],
static_cast<nn::PoolType>(static_cast<int>(args[1]))); static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
}); });
TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
......
...@@ -178,16 +178,20 @@ def test_pool_grad(): ...@@ -178,16 +178,20 @@ def test_pool_grad():
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True) verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)
def verify_global_pool(n, c, h, w, pool_type): def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
assert layout in ["NCHW", "NHWC"]
A = tvm.placeholder((n, c, h, w), name='A') A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type) B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
B = topi.nn.relu(B) B = topi.nn.relu(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
axis = (layout.find('H'), layout.find('W'))
if pool_type == 'avg': if pool_type == 'avg':
b_np = np.mean(a_np, axis=(2,3), keepdims=True) b_np = np.mean(a_np, axis=axis, keepdims=True)
elif pool_type =='max': elif pool_type =='max':
b_np = np.max(a_np, axis=(2,3), keepdims=True) b_np = np.max(a_np, axis=axis, keepdims=True)
b_np = np.maximum(b_np, 0.0) b_np = np.maximum(b_np, 0.0)
def check_device(device): def check_device(device):
...@@ -212,6 +216,10 @@ def test_global_pool(): ...@@ -212,6 +216,10 @@ def test_global_pool():
verify_global_pool(4, 1024, 7, 7, 'avg') verify_global_pool(4, 1024, 7, 7, 'avg')
verify_global_pool(1, 1024, 7, 7, 'max') verify_global_pool(1, 1024, 7, 7, 'max')
verify_global_pool(4, 1024, 7, 7, 'max') verify_global_pool(4, 1024, 7, 7, 'max')
verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC')
verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC')
verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC')
verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC')
def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
def start_index(index, odim, idim): def start_index(index, odim, idim):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment