Commit cd623f43 by Tianqi Chen, committed by GitHub

[TEST] rfactor+ewise, cite rfactor paper (#474)

* [TEST] rfactor+ewise, cite rfactor paper

* include all authors via abbrv

* [TOPI] Add transpose

* fix lint
parent 31fb14e4
@@ -14,16 +14,23 @@ Index
topi.log
topi.sqrt
topi.sigmoid
topi.broadcast_to
topi.max
topi.sum
topi.min
topi.transpose
topi.expand_dims
topi.nn.relu
topi.nn.leaky_relu
topi.nn.dilate
topi.nn.conv2d_nchw
topi.nn.conv2d_hwcn
topi.nn.depthwise_conv2d_nchw
topi.nn.depthwise_conv2d_nhwc
topi.max
topi.sum
topi.min
topi.broadcast_to
topi.broadcast_add
topi.broadcast_sub
topi.broadcast_mul
topi.broadcast_div
**List of schedules**
@@ -35,9 +42,8 @@ Index
topi.cuda.schedule_depthwise_conv2d_nchw
topi.cuda.schedule_depthwise_conv2d_nhwc
topi.cuda.schedule_reduce
topi.cuda.schedule_elemwise
topi.cuda.schedule_broadcast
topi.cuda.schedule_injective
topi
~~~~
@@ -46,14 +52,22 @@ topi
.. autofunction:: topi.log
.. autofunction:: topi.sqrt
.. autofunction:: topi.sigmoid
.. autofunction:: topi.broadcast_to
.. autofunction:: topi.transpose
.. autofunction:: topi.expand_dims
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
.. autofunction:: topi.broadcast_to
.. autofunction:: topi.broadcast_add
.. autofunction:: topi.broadcast_sub
.. autofunction:: topi.broadcast_mul
.. autofunction:: topi.broadcast_div
topi.nn
~~~~~~~
.. autofunction:: topi.nn.relu
.. autofunction:: topi.nn.leaky_relu
.. autofunction:: topi.nn.dilate
.. autofunction:: topi.nn.conv2d_nchw
.. autofunction:: topi.nn.conv2d_hwcn
@@ -71,4 +85,4 @@ topi.cuda
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc
.. autofunction:: topi.cuda.schedule_reduce
.. autofunction:: topi.cuda.schedule_broadcast
.. autofunction:: topi.cuda.schedule_elemwise
.. autofunction:: topi.cuda.schedule_injective
@@ -304,6 +304,8 @@ class Schedule : public NodeRef {
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* P. Suriana, A. Adams, and S. Kamil. Parallel Associative Reductions in Halide. CGO'17.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensors.
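For reference, a minimal sketch of how this primitive is driven from the Python schedule API, distilled from the tests in this change (the shapes and the split factor of 16 are illustrative):

    n = 1024
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
    s = tvm.create_schedule(B.op)
    ko, kf = s[B].split(k, factor=16)
    # rfactor lifts the kf sub-reduction into a new stage BF with kf as its
    # first dimension; B is rewritten as a reduction over BF, so the partial
    # sums over kf can be computed in parallel.
    BF = s.rfactor(B, kf)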
@@ -127,6 +127,57 @@ def test_rfactor_threads():
check_target("metal")
check_target("opencl")
def test_rfactor_elemwise_threads():
n = 1025
m = 10
A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n))
nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
BB = tvm.compute((m,), lambda i: B[i] + 1, name='BB')
C = tvm.compute((m,), lambda i: BB[i] + 1, name='C')
# schedule
s = tvm.create_schedule(C.op)
s[BB].compute_inline()
bx, ty = s[C].split(s[C].op.axis[0], factor=nthread)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
s[B].compute_at(s[C], ty)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
tx = s[B].op.reduce_axis[0]
thread_x = tvm.thread_axis("threadIdx.x")
s[B].bind(tx, thread_x)
s[BF].compute_at(s[B], tx)
# Since thread_x is shared across reductions
# only one of them need to do write back
s[B].set_store_predicate(thread_x.var.equal(0))
s[C].set_store_predicate(thread_x.var.equal(0))
# one line to build the function.
def check_target(device, host="stackvm"):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, C])
fsum = tvm.build(fapi,
target=device,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1) + 2
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target("cuda")
check_target("metal")
check_target("opencl")


def test_argmax():
    def fcombine(x, y):
        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
@@ -234,6 +285,7 @@ def test_rfactor_argmax():
check_target("cuda")
if __name__ == "__main__":
test_rfactor_elemwise_threads()
test_rfactor_threads()
test_rfactor()
test_reduce_prims()
@@ -18,3 +18,27 @@ def relu(x):
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0))


@tvm.tag_scope(tag=tag.ELEMWISE)
def leaky_relu(x, alpha):
    """Take leaky relu of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    alpha : float
        The slope for the small gradient when x < 0.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    def _compute(*indices):
        value = x(*indices)
        calpha = tvm.const(alpha, value.dtype)
        return tvm.select(value > 0, value, value * calpha)
    return tvm.compute(x.shape, _compute)
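A short usage sketch for reference, mirroring verify_leaky_relu further down in this change (the shape, the alpha value, and the llvm target are illustrative assumptions):

    A = tvm.placeholder((100,), name='A')
    B = topi.nn.leaky_relu(A, 0.1)       # elementwise: x if x > 0 else 0.1 * x
    s = tvm.create_schedule([B.op])
    f = tvm.build(s, [A, B], "llvm", name="leaky_relu")  # assumes llvm is enabled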
@@ -25,3 +25,30 @@ def expand_dims(a, axis, num_newaxis=1):
        idx = indices[:axis] + indices[axis + num_newaxis:]
        return a(*idx)
    return tvm.compute(new_shape, _compute)


@tvm.tag_scope(tag=tag.INJECTIVE)
def transpose(a, axes=None):
    """Permute the dimensions of an array.

    Parameters
    ----------
    a : tvm.Tensor
        The tensor to be transposed.

    axes : tuple of ints, optional
        By default, reverse the dimensions.

    Returns
    -------
    ret : tvm.Tensor
        The transposed result.
    """
    ndim = len(a.shape)
    axes = axes if axes else tuple(reversed(range(ndim)))
    new_shape = [a.shape[x] for x in axes]
    def _compute(*indices):
        idx = [1] * len(axes)
        for i, k in enumerate(axes):
            idx[k] = indices[i]
        return a(*idx)
    return tvm.compute(new_shape, _compute)
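To make the index mapping above concrete: with axes=(2, 0, 1), output element (i, j, k) reads a[j, k, i], so a (3, 10, 5) input yields a (5, 3, 10) result, matching numpy.transpose. A minimal sketch mirroring the test added below (the shape is illustrative):

    A = tvm.placeholder((3, 10, 5), name='A')
    B = topi.transpose(A, (2, 0, 1))  # B.shape == (5, 3, 10); B[i, j, k] reads A[j, k, i]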
@@ -27,9 +27,29 @@ def verify_relu(m, n):
    for device in ['cuda', 'opencl', 'metal']:
        check_device(device)


def verify_leaky_relu(m, alpha):
    A = tvm.placeholder((m,), name='A')
    B = topi.nn.leaky_relu(A, alpha)
    s = tvm.create_schedule([B.op])
    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = a_np * (a_np > 0) + a_np * (a_np < 0) * alpha
    ctx = tvm.cpu(0)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    foo = tvm.build(s, [A, B], "llvm", name="leaky_relu")
    foo(a, b)
    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)


def test_relu():
    verify_relu(10, 128)


def test_leaky_relu():
    verify_leaky_relu(100, 0.1)


if __name__ == "__main__":
    test_relu()
    test_leaky_relu()
@@ -15,7 +15,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = data_npy.reshape(out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
@@ -26,9 +25,39 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
check_device("metal")
def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="tranpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda")
check_device("opencl")
check_device("metal")
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
def test_tranpose():
verify_tranpose((3, 10, 2), (1, 0, 2))
verify_tranpose((3, 10, 5), (2, 0, 1))
verify_tranpose((3, 10), None)
if __name__ == "__main__":
test_tranpose()
test_expand_dims()