Commit 75d53777 authored by Yuwei HU, committed by Tianqi Chen

add pool (#478)

parent f863bfdc
@@ -11,4 +11,4 @@ from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_global_pool
from .pooling import schedule_pool, schedule_global_pool
@@ -56,7 +56,7 @@ def schedule_global_pool(outs):
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule global_pool
elif 'global_pool' in OP.tag:
elif OP.tag.startswith('global_pool'):
Pool = OP.output(0)
_schedule(Pool)
else:
@@ -64,3 +64,57 @@ def schedule_global_pool(outs):
traverse(outs[0].op)
return s
def schedule_pool(outs):
"""Schedule for pool.
Parameters
----------
outs: Array of Tensor
The computation graph description of pool
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for pool.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Pool):
s[PaddedInput].compute_inline()
num_thread = 512
if Pool.op in s.outputs:
Out = Pool
OL = s.cache_write(Pool, "local")
else:
Out = outs[0].op.output(0)
s[Pool].set_scope("local")
fused = s[Out].fuse(*s[Out].op.axis)
bx, tx = s[Out].split(fused, factor=num_thread)
s[Out].bind(bx, tvm.thread_axis("blockIdx.x"))
s[Out].bind(tx, tvm.thread_axis("threadIdx.x"))
if Pool.op in s.outputs:
s[OL].compute_at(s[Out], tx)
else:
s[Pool].compute_at(s[Out], tx)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
PaddedInput = OP.input_tensors[0]
Pool = OP.output(0)
_schedule(PaddedInput, Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
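A minimal usage sketch (the shapes and the CUDA target here are illustrative assumptions, mirroring the test added below): compose topi.nn.pool with this schedule and build a kernel.
import tvm
import topi

# Hypothetical NCHW input; any 4-D tensor is handled the same way.
A = tvm.placeholder((1, 64, 32, 32), name='A')
B = topi.nn.pool(A, kernel=[2, 2], stride=[2, 2], padding=[0, 0], pool_type='max')
s = topi.cuda.schedule_pool(B)
f = tvm.build(s, [A, B], 'cuda')  # requires a CUDA-enabled TVM build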
@@ -6,84 +6,103 @@ from .util import get_pad_tuple
from .. import util
from .. import tag
def max_pool(data, kernel, stride, padding):
"""Perform max pooling on the data
def global_pool(data, pool_type):
"""Perform global pooling on the data
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
kernel : list/tuple of two ints
Kernel size, or [kernel_height, kernel_width]
stride : list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : list/tuple of two ints
Pad size, or [pad_height, pad_width]
pool_type : str
Pool type, 'max' or 'avg'
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, out_height, out_width]
4-D with shape [batch, channel, 1, 1]
"""
assert len(data.shape) == 4, "only support 4-dim pooling"
assert len(stride) == 2, "only support 2-dim stride"
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
batch, channel, height, width = data.shape
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp",
pad_value=tvm.min_value("float32"))
out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
dheight = tvm.reduce_axis((0, kernel_height))
dwidth = tvm.reduce_axis((0, kernel_width))
dheight = tvm.reduce_axis((0, height))
dwidth = tvm.reduce_axis((0, width))
return tvm.compute(
(batch, channel, out_height, out_width),
lambda i, c, h, w:
tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]),
tag="max_pool")
if pool_type == 'max':
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_max")
elif pool_type == 'avg':
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width), \
tag=tag.ELEMWISE)
else:
raise ValueError("Pool type should be 'avg' or 'max'.")
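For reference, a NumPy sketch of what the 'avg' branch above computes (the input shape is an illustrative assumption):
import numpy as np

x = np.random.uniform(size=(1, 16, 7, 7)).astype('float32')
# Global average pooling reduces the full spatial extent to a single value per
# channel, matching the (batch, channel, 1, 1) output shape documented above.
y = x.mean(axis=(2, 3), keepdims=True)   # shape: (1, 16, 1, 1)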
def global_pool(data, pool_type):
"""Perform global pooling on the data
def pool(data, kernel, stride, padding, pool_type):
"""Perform pooling on the data
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
kernel : list/tuple of two ints
Kernel size, [kernel_height, kernel_width]
stride : list/tuple of two ints
Stride size, [stride_height, stride_width]
padding : list/tuple of two ints
Pad size, [pad_height, pad_width]
pool_type : str
Pool type, 'max' or 'avg'
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, 1, 1]
4-D with shape [batch, channel, out_height, out_width]
"""
assert len(data.shape) == 4, "only support 4-dim pooling"
assert len(stride) == 2, "only support 2-dim stride"
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
batch, channel, height, width = data.shape
dheight = tvm.reduce_axis((0, height))
dwidth = tvm.reduce_axis((0, width))
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
dheight = tvm.reduce_axis((0, kernel_height))
dwidth = tvm.reduce_axis((0, kernel_width))
if pool_type == 'max':
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_max")
temp = pad(data, pad_before, pad_after, name="pad_temp", \
pad_value=tvm.min_value(data.dtype))
return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: \
tvm.max(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
axis=[dheight, dwidth]), \
tag="pool_max")
elif pool_type == 'avg':
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width), \
temp = pad(data, pad_before, pad_after, name="pad_temp", \
pad_value=tvm.const(0.).astype(data.dtype))
tsum = tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: \
tvm.sum(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
axis=[dheight, dwidth]), \
tag="pool_avg")
return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: \
tsum[n, c, h, w] / (kernel_height*kernel_width), \
tag=tag.ELEMWISE)
else:
raise ValueError("Pool type should be 'avg' or 'max'.")
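A quick sanity check of the output-size formula used above, plugging in the values of the second verify_pool case from the test below:
# out = (in - kernel + pad_before + pad_after) // stride + 1
ih, kh, sh = 31, 3, 3        # input height, kernel height, stride height
pad_top, pad_down = 1, 1     # symmetric padding of 1 on each side
oh = (ih - kh + pad_top + pad_down) // sh + 1
assert oh == 11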
@@ -4,6 +4,55 @@ import tvm
import topi
from topi.util import get_const_tuple
def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
iw = ih
kw = kh
sw = sh
ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
B = topi.nn.relu(B)
s = topi.cuda.schedule_pool(B)
dtype = A.dtype
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
if pool_type == 'avg':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
elif pool_type == 'max':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg')
verify_pool(1, 256, 31, 3, 3, [1, 1], 'avg')
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max')
verify_pool(1, 256, 31, 3, 3, [1, 1], 'max')
def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type)
@@ -24,7 +73,7 @@ def verify_global_pool(n, c, h, w, pool_type):
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device, name="global_avg_pool")
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -39,4 +88,5 @@ def test_global_pool():
if __name__ == "__main__":
test_pool()
test_global_pool()