Unverified Commit 0a1e1601 by SXM-inspur Committed by GitHub

[Topi][Cuda]Optimizations of global_ave_pool for NHWC layout (#5450)

* Optimizations of global_ave_pool for NHWC layout

* Optimize the code format to pass inspection of pylint

Co-authored-by: Shawn-Inspur <wushaohua@inspur.com>
parent 702db6f9
...@@ -58,7 +58,7 @@ def schedule_pool_grad_cuda(attrs, outs, target): ...@@ -58,7 +58,7 @@ def schedule_pool_grad_cuda(attrs, outs, target):
def schedule_adaptive_pool_cuda(attrs, outs, target): def schedule_adaptive_pool_cuda(attrs, outs, target):
"""schedule adaptive pooling ops for cuda""" """schedule adaptive pooling ops for cuda"""
with target: with target:
return topi.cuda.schedule_adaptive_pool(outs) return topi.cuda.schedule_adaptive_pool(outs, attrs.layout)
@softmax_strategy.register(["cuda", "gpu"]) @softmax_strategy.register(["cuda", "gpu"])
def softmax_strategy_cuda(attrs, inputs, out_type, target): def softmax_strategy_cuda(attrs, inputs, out_type, target):
......
...@@ -22,7 +22,7 @@ from .. import tag ...@@ -22,7 +22,7 @@ from .. import tag
from ..util import traverse_inline from ..util import traverse_inline
def schedule_adaptive_pool(outs): def schedule_adaptive_pool(outs, layout='NCHW'):
"""Schedule for adaptive_pool. """Schedule for adaptive_pool.
Parameters Parameters
...@@ -51,8 +51,12 @@ def schedule_adaptive_pool(outs): ...@@ -51,8 +51,12 @@ def schedule_adaptive_pool(outs):
else: else:
Out = outs[0].op.output(0) Out = outs[0].op.output(0)
s[Pool].set_scope("local") s[Pool].set_scope("local")
by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread) by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread) if layout == 'NHWC':
bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
else:
bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
s[Out].reorder(by, bx, ty, tx) s[Out].reorder(by, bx, ty, tx)
s[Out].bind(ty, thread_y) s[Out].bind(ty, thread_y)
s[Out].bind(tx, thread_x) s[Out].bind(tx, thread_x)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Test code for pooling""" """Test code for pooling"""
import math import math
import numpy as np import numpy as np
...@@ -44,6 +45,7 @@ _pool_grad_schedule = { ...@@ -44,6 +45,7 @@ _pool_grad_schedule = {
} }
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
"""verify function of pool"""
iw = ih iw = ih
kw = kh kw = kh
sw = sh sw = sh
...@@ -76,15 +78,17 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ ...@@ -76,15 +78,17 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
if count_include_pad: if count_include_pad:
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) b_np[:, :, i, j] = \
np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
else: else:
pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3)) pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1) b_np[:, :, i, j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) \
/ np.maximum(pad_count, 1)
elif pool_type =='max': elif pool_type == 'max':
for i in range(oh): for i in range(oh):
for j in range(ow): for j in range(ow):
b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) b_np[:, :, i, j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
b_np = np.maximum(b_np, 0.0) b_np = np.maximum(b_np, 0.0)
def check_device(device): def check_device(device):
...@@ -108,11 +112,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ ...@@ -108,11 +112,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
add_relu=False): add_relu=False):
"""verify function of pool_grad"""
iw = ih iw = ih
kw = kh kw = kh
sw = sh sw = sh
pt, pl, pb, pr = padding pt, pl, pb, pr = padding
layout = "NCHW"
A = te.placeholder((n, ic, ih, iw), name='A') A = te.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode, pool_type=pool_type, ceil_mode=ceil_mode,
...@@ -164,6 +168,7 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc ...@@ -164,6 +168,7 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
check_device(device) check_device(device)
def test_pool(): def test_pool():
"""test cases of pool"""
verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False) verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
...@@ -179,6 +184,7 @@ def test_pool(): ...@@ -179,6 +184,7 @@ def test_pool():
verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
def test_pool_grad(): def test_pool_grad():
"""test cases of pool_grad"""
verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False) verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
...@@ -200,10 +206,10 @@ def test_pool_grad(): ...@@ -200,10 +206,10 @@ def test_pool_grad():
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True) verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)
def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): def verify_global_pool(dshape, pool_type, layout='NCHW'):
"""verify function of global_pool"""
assert layout in ["NCHW", "NHWC"] assert layout in ["NCHW", "NHWC"]
A = te.placeholder((n, c, h, w), name='A') A = te.placeholder(shape=dshape, name='A')
B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout) B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
B = topi.nn.relu(B) B = topi.nn.relu(B)
...@@ -212,7 +218,7 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): ...@@ -212,7 +218,7 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
axis = (layout.find('H'), layout.find('W')) axis = (layout.find('H'), layout.find('W'))
if pool_type == 'avg': if pool_type == 'avg':
b_np = np.mean(a_np, axis=axis, keepdims=True) b_np = np.mean(a_np, axis=axis, keepdims=True)
elif pool_type =='max': elif pool_type == 'max':
b_np = np.max(a_np, axis=axis, keepdims=True) b_np = np.max(a_np, axis=axis, keepdims=True)
b_np = np.maximum(b_np, 0.0) b_np = np.maximum(b_np, 0.0)
...@@ -224,7 +230,10 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): ...@@ -224,7 +230,10 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
s = s_func(B) if device == "cuda":
s = s_func(B, layout)
else:
s = s_func(B)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device) f = tvm.build(s, [A, B], device)
...@@ -235,17 +244,19 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): ...@@ -235,17 +244,19 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
check_device(device) check_device(device)
def test_global_pool(): def test_global_pool():
verify_global_pool(1, 1024, 7, 7, 'avg') """test cases of global_pool"""
verify_global_pool(4, 1024, 7, 7, 'avg') verify_global_pool((1, 1024, 7, 7), 'avg')
verify_global_pool(1, 1024, 7, 7, 'max') verify_global_pool((4, 1024, 7, 7), 'avg')
verify_global_pool(4, 1024, 7, 7, 'max') verify_global_pool((1, 1024, 7, 7), 'max')
verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC') verify_global_pool((4, 1024, 7, 7), 'max')
verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC') verify_global_pool((1, 7, 7, 1024), 'avg', 'NHWC')
verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC') verify_global_pool((4, 7, 7, 1024), 'avg', 'NHWC')
verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC') verify_global_pool((1, 7, 7, 1024), 'max', 'NHWC')
verify_global_pool((4, 7, 7, 1024), 'max', 'NHWC')
def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
"""verify function of adaptive_pool"""
np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
oshape = np_out.shape oshape = np_out.shape
...@@ -265,7 +276,10 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa ...@@ -265,7 +276,10 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
s = s_func(out) if device == "cuda":
s = s_func(out, layout)
else:
s = s_func(out)
a = tvm.nd.array(np_data, ctx) a = tvm.nd.array(np_data, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx)
f = tvm.build(s, [data, out], device) f = tvm.build(s, [data, out], device)
...@@ -277,6 +291,7 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa ...@@ -277,6 +291,7 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
def test_adaptive_pool(): def test_adaptive_pool():
"""test cases of adaptive_pool"""
verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max") verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max")
verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg") verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg")
verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max") verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
...@@ -295,6 +310,7 @@ def test_adaptive_pool(): ...@@ -295,6 +310,7 @@ def test_adaptive_pool():
def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
ceil_mode, count_include_pad=True, layout='NCDHW'): ceil_mode, count_include_pad=True, layout='NCDHW'):
"""verify function of pool3d"""
id = iw = ih id = iw = ih
kd = kw = kh kd = kw = kh
sd = sw = sh sd = sw = sh
...@@ -334,6 +350,7 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ...@@ -334,6 +350,7 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
def test_pool3d(): def test_pool3d():
"""test cases of pool3d"""
verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True) verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True)
verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True) verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True)
verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False) verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False)
...@@ -351,6 +368,7 @@ def test_pool3d(): ...@@ -351,6 +368,7 @@ def test_pool3d():
def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type, def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
ceil_mode, count_include_pad=True, layout='NCW'): ceil_mode, count_include_pad=True, layout='NCW'):
"""verify function of pool1d"""
input_shape = (n, ic, iw) input_shape = (n, ic, iw)
kernel = [kw] kernel = [kw]
stride = [sw] stride = [sw]
...@@ -387,6 +405,7 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type, ...@@ -387,6 +405,7 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
def test_pool1d(): def test_pool1d():
"""test cases of pool1d"""
verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True) verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True) verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False) verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment