Unverified commit f1ede9a9 by Wuwei Lin, committed by GitHub

[TOPI][CUDA] Schedule for pool_grad (#3622)

* [TOPI][CUDA] Schedule for pool_grad

* Relay test

* Fix fused op

* doc

* Remove set scope local
parent 8e0aaa29
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import topi
import topi.testing
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list
from tvm.relay.testing import run_infer_type


def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                                ceil_mode=ceil_mode)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw],
                                           pool_type='max', ceil_mode=ceil_mode)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_max_pool2d_grad():
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False)


def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                                ceil_mode=ceil_mode, count_include_pad=count_include_pad)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw],
                                           pool_type='avg', ceil_mode=ceil_mode)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False, count_include_pad=True)
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False, count_include_pad=False)


if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
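For intuition, here is a tiny NumPy sketch of the max-pool gradient semantics that topi.testing.pool_grad_nchw serves as the reference for above: the upstream gradient is routed back to each pooling window's argmax position. The 4x4 input and 2x2/stride-2 window are illustrative values, not taken from the tests.

import numpy as np

# One 4x4 plane, 2x2 max pool with stride 2: the gradient of each pooled cell
# flows back only to the position that produced the window maximum.
x = np.array([[1., 2., 5., 3.],
              [4., 0., 1., 2.],
              [7., 8., 3., 4.],
              [6., 5., 9., 2.]])
out_grad = np.ones((2, 2))      # upstream gradient, one value per pooled cell
in_grad = np.zeros_like(x)
for i in range(2):
    for j in range(2):
        win = x[2*i:2*i+2, 2*j:2*j+2]
        r, c = np.unravel_index(np.argmax(win), win.shape)
        in_grad[2*i + r, 2*j + c] += out_grad[i, j]
print(in_grad)
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 0. 1. 0.]]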
@@ -19,6 +19,7 @@
import tvm
from .. import tag
from .. import generic
from ..util import traverse_inline
@@ -150,3 +151,52 @@ def schedule_pool(outs, layout):
    traverse(outs[0].op)
    return s


@generic.schedule_pool_grad.register(['cuda', 'gpu'])
def schedule_pool_grad_cuda(outs):
    """Schedule for pool_grad on CUDA

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool_grad
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for pool_grad.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule_pool_grad(op):
        if op in s.outputs:
            out = op
        else:
            out = outs[0].op.output(0)
        fused = s[out].fuse(*s[out].op.axis)
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        bx, tx = s[out].split(fused, factor=num_thread)
        s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[out].bind(tx, tvm.thread_axis("threadIdx.x"))

        if tag.COMM_REDUCE_IDX in op.input_tensors[0].op.tag:
            max_pool_index = op.input_tensors[0]
            s[max_pool_index].compute_at(s[out], tx)

            pool_input = max_pool_index.op.input_tensors[0]
            if isinstance(pool_input.op, tvm.tensor.ComputeOp):
                # handle padding
                s[pool_input].compute_inline()
        if op not in s.outputs:
            s[op].compute_at(s[out], tx)

    def _callback(op):
        if op.tag.startswith('pool_grad'):
            _schedule_pool_grad(op)

    traverse_inline(s, outs[0].op, _callback)
    return s
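For context, a minimal sketch of how this schedule might be exercised through the generic dispatcher, in the same style as the TOPI test changes below. The tensor shapes, the 2x2 max-pool configuration, and the 'cuda' target/device choice are illustrative assumptions rather than part of the patch, and a CUDA-enabled build is assumed.

import numpy as np
import tvm
import topi

# Input data and the gradient of the pooled output (NCHW layout).
A = tvm.placeholder((1, 16, 32, 32), name='A')
OutGrad = tvm.placeholder((1, 16, 16, 16), name='OutGrad')
# Gradient of a 2x2/stride-2 max pool with respect to A; result has A's shape.
PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[2, 2], stride=[2, 2], padding=[0, 0, 0, 0],
                             pool_type='max', ceil_mode=False, layout="NCHW")

with tvm.target.create('cuda'):
    # Dispatches to schedule_pool_grad_cuda registered above.
    s = topi.generic.schedule_pool_grad(PoolGrad)

f = tvm.build(s, [A, OutGrad, PoolGrad], 'cuda')
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 16, 32, 32)).astype('float32'), ctx)
og = tvm.nd.array(np.ones((1, 16, 16, 16), dtype='float32'), ctx)
pg = tvm.nd.array(np.zeros((1, 16, 32, 32), dtype='float32'), ctx)
f(a, og, pg)  # pg now holds the gradient with respect to A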
@@ -86,7 +86,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
    for device in get_all_backend():
        check_device(device)


def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
                     add_relu=False):
    iw = ih
    kw = kh
    sw = sh
@@ -110,6 +111,8 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
    PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                                 pool_type=pool_type, ceil_mode=ceil_mode,
                                 layout="NCHW", count_include_pad=count_include_pad)
    if add_relu:
        PoolGrad = topi.nn.relu(PoolGrad)

    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
@@ -117,6 +120,8 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
                                               strides=(sh, sw), padding=padding,
                                               pool_type=pool_type, ceil_mode=ceil_mode,
                                               count_include_pad=count_include_pad)
    if add_relu:
        pool_grad_np = np.maximum(pool_grad_np, 0.)

    def check_device(device):
        ctx = tvm.context(device, 0)
@@ -134,7 +139,7 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
        f(a, out_grad, pool_grad)
        tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)


def test_pool():
@@ -152,6 +157,7 @@ def test_pool():
    verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)


def test_pool_grad():
    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
    verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
@@ -169,6 +175,9 @@ def test_pool():
    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False)
    verify_pool_grad(1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False, add_relu=True)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)


def verify_global_pool(n, c, h, w, pool_type):
    A = tvm.placeholder((n, c, h, w), name='A')
@@ -258,5 +267,6 @@ def test_adaptive_pool():
if __name__ == "__main__":
    test_pool()
    test_pool_grad()
    test_global_pool()
    test_adaptive_pool()