Commit 203b8188 by Yuwei HU Committed by Tianqi Chen

[TOPI] migrate global_avg_pool, fully_connected (#472)

* migrate global_avg_pool, fully_connected

* fix pylint

* enable fusion of pooling schedule

* rename fc->dense, enable fusion

* improve dense schedule

* unified global pool
parent cd623f43
......@@ -10,3 +10,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_global_pool
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
def schedule_dense(outs):
"""Schedule for dense operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Dense):
num_thread = 64
k = Dense.op.reduce_axis[0]
ko, kf = s[Dense].split(k, factor=num_thread)
DenseF = s.rfactor(Dense, kf)
if Dense.op in s.outputs:
Out = Dense
else:
Out = outs[0].op.output(0)
s[Dense].compute_at(s[Out], s[Out].op.axis[1])
s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
tx = s[Dense].op.reduce_axis[0]
thread_x = tvm.thread_axis("threadIdx.x")
s[Dense].bind(tx, thread_x)
s[DenseF].compute_at(s[Dense], tx)
s[Dense].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Dense = OP.output(0)
_schedule(Dense)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
# pylint: disable=invalid-name, unused-variable
"""Schedule for pooling operators"""
import tvm
from .. import tag
def schedule_global_pool(outs):
"""Schedule for global_pool.
Parameters
----------
outs: Array of Tensor
The computation graph description of global_pool
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for global_pool.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Pool):
num_thread = 8
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
if Pool.op in s.outputs:
Out = Pool
OL = s.cache_write(Pool, "local")
else:
Out = outs[0].op.output(0)
s[Pool].set_scope("local")
i, c, h, w = s[Out].op.axis
dh, dw = s[Pool].op.reduce_axis
fuse_index = s[Pool].fuse(dw, dh)
s[Pool].unroll(fuse_index)
by, ty = s[Out].split(i, factor=num_thread)
bx, tx = s[Out].split(c, factor=num_thread)
s[Out].reorder(by, bx, ty, tx)
s[Out].bind(ty, thread_y)
s[Out].bind(tx, thread_x)
s[Out].bind(by, block_y)
s[Out].bind(bx, block_x)
if Pool.op in s.outputs:
s[OL].compute_at(s[Out], tx)
else:
s[Pool].compute_at(s[Out], tx)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule global_pool
elif 'global_pool' in OP.tag:
Pool = OP.output(0)
_schedule(Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
......@@ -8,7 +8,7 @@ from .depthwise_convolution import *
from .elemwise import *
from .dilate import *
from .flatten import *
from .fully_connected import *
from .dense import *
from .mapping import *
from .pooling import *
from .softmax import *
"""TVM operator fully connected compute."""
from __future__ import absolute_import
import tvm
from .. import tag
@tvm.tag_scope(tag='fully_connected')
def fully_connected(data, weight):
"""Matrix multiplication
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim fully_connected"
batch, in_dim = data.shape
out_dim, _ = weight.shape
k = tvm.reduce_axis((0, in_dim), name='k')
return tvm.compute((batch, out_dim), lambda i, j: \
tvm.sum(data[i][k] * weight[j][k], axis=k))
@tvm.tag_scope(tag='fully_connected_with_bias')
def fully_connected_with_bias(data, weight, bias):
def dense(data, weight, bias, use_bias=True):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
Parameters
......@@ -44,19 +18,24 @@ def fully_connected_with_bias(data, weight, bias):
bias : tvm.Tensor
1-D with shape [out_dim]
use_bias : bool, optional, default=True
Whether to use bias parameter
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim fully_connected"
assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
"only support 2-dim fully_connected"
"only support 2-dim dense"
batch, in_dim = data.shape
out_dim, _ = weight.shape
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim), lambda i, j: \
tvm.sum(data[i, k] * weight[j, k], axis=k))
return tvm.compute((batch, out_dim), lambda i, j: \
matmul[i, j] + bias[j])
matmul = tvm.compute((batch, out_dim), \
lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
tag='dense')
if not use_bias:
return matmul
return tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
......@@ -4,6 +4,7 @@ import tvm
from .pad import pad
from .util import get_pad_tuple
from .. import util
from .. import tag
def max_pool(data, kernel, stride, padding):
"""Perform max pooling on the data
......@@ -51,15 +52,17 @@ def max_pool(data, kernel, stride, padding):
tag="max_pool")
@tvm.tag_scope(tag='global_avg_pool')
def global_avg_pool(data):
"""Perform global average pooling on the data
def global_pool(data, pool_type):
"""Perform global pooling on the data
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
pool_type : str
Pool type, 'max' or 'avg'
Returns
-------
output : tvm.Tensor
......@@ -71,7 +74,16 @@ def global_avg_pool(data):
dheight = tvm.reduce_axis((0, height))
dwidth = tvm.reduce_axis((0, width))
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]))
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width))
if pool_type == 'max':
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_max")
elif pool_type == 'avg':
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width), \
tag=tag.ELEMWISE)
else:
raise ValueError("Pool type should be 'avg' or 'max'.")
"""Test code for dense operator"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C, use_bias=use_bias)
D = topi.nn.relu(D)
s = topi.cuda.schedule_dense(D)
dtype = A.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)
if __name__ == "__main__":
test_dense()
"""Test code for pooling"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type)
B = topi.nn.relu(B)
s = topi.cuda.schedule_global_pool(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
if pool_type == 'avg':
b_np = np.mean(a_np, axis=(2,3), keepdims=True)
elif pool_type =='max':
b_np = np.max(a_np, axis=(2,3), keepdims=True)
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device, name="global_avg_pool")
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_global_pool():
verify_global_pool(1, 1024, 7, 7, 'avg')
verify_global_pool(4, 1024, 7, 7, 'avg')
verify_global_pool(1, 1024, 7, 7, 'max')
verify_global_pool(4, 1024, 7, 7, 'max')
if __name__ == "__main__":
test_global_pool()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment