Commit 3708b311 authored by masahi, committed by Tianqi Chen

Update cuda softmax schedule for spatial inputs (#2338)

parent f6c3f997
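
For context, this is roughly how the updated schedule is exercised end to end: softmax over the channel axis of a spatial NCHW tensor, scheduled and built for CUDA. A minimal sketch, assuming the 0.5-era TVM/TOPI API used elsewhere in this diff (tvm.placeholder, topi.nn.softmax, topi.generic.schedule_softmax, tvm.build); the shape and variable names are illustrative only.

import numpy as np
import tvm
import topi

# Softmax over axis 1 of an NCHW input; before this change the CUDA
# schedule only handled 2-D inputs.
A = tvm.placeholder((1, 16, 64, 64), name="A")
B = topi.nn.softmax(A, axis=1)

with tvm.target.create("cuda"):
    s = topi.generic.schedule_softmax(B)   # dispatches to the updated cuda/gpu schedule

ctx = tvm.context("cuda", 0)
a = tvm.nd.array(np.random.uniform(size=(1, 16, 64, 64)).astype("float32"), ctx)
b = tvm.nd.array(np.zeros((1, 16, 64, 64), dtype="float32"), ctx)
f = tvm.build(s, [A, B], "cuda", name="softmax")
f(a, b)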
@@ -2,6 +2,7 @@
 """Schedule for softmax operator"""
 import tvm
 from .. import generic
+from .injective import _schedule_injective
 
 @generic.schedule_softmax.register(["cuda", "gpu"])
 def schedule_softmax(outs):
@@ -24,21 +25,24 @@ def schedule_softmax(outs):
     max_elem = softmax.op.input_tensors[1]
     expsum = softmax.op.input_tensors[2]
 
-    num_thread = 64
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+    if len(softmax.shape) > 2:
+        for op in [max_elem.op, expsum.op, softmax.op]:
+            s = _schedule_injective(op, s)
+    else:
+        num_thread = 64
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        s[max_elem].bind(max_elem.op.axis[0], block_x)
+        k = expsum.op.reduce_axis[0]
+        ko, ki = s[expsum].split(k, factor=num_thread)
+        EF = s.rfactor(expsum, ki)
+        s[expsum].bind(s[expsum].op.axis[0], block_x)
+        s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
+        s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
+        s[expsum].set_store_predicate(thread_x.var.equal(0))
+        tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+        s[softmax].bind(softmax.op.axis[0], block_x)
+        s[softmax].bind(tx, thread_x)
-    s[max_elem].bind(max_elem.op.axis[0], block_x)
-    k = expsum.op.reduce_axis[0]
-    ko, ki = s[expsum].split(k, factor=num_thread)
-    EF = s.rfactor(expsum, ki)
-    s[expsum].bind(s[expsum].op.axis[0], block_x)
-    s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
-    s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
-    s[expsum].set_store_predicate(thread_x.var.equal(0))
-    tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
-    s[softmax].bind(softmax.op.axis[0], block_x)
-    s[softmax].bind(tx, thread_x)
 
     return s
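
In the 2-D branch above, the row-wise sum in expsum is computed with a split/rfactor cross-thread reduction: the reduction axis is split by num_thread, rfactor creates a per-thread partial-sum stage, the outer reduction is bound to threadIdx.x, and set_store_predicate makes only thread 0 write the final value. Inputs with more than two dimensions now simply fall back to the injective (elementwise) schedule for every stage. A standalone sketch of the same split/rfactor pattern on a plain row sum, using the same 0.5-era TVM API (shape and names are illustrative only):

import tvm

n = tvm.var("n")
m = 512
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

s = tvm.create_schedule(B.op)
num_thread = 64
ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
BF = s.rfactor(B, ki)                                   # per-thread partial sums
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
s[B].bind(s[B].op.axis[0], block_x)                     # one block per row
s[B].bind(s[B].op.reduce_axis[0], thread_x)             # cross-thread reduction
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
s[B].set_store_predicate(thread_x.var.equal(0))         # only thread 0 stores the result
print(tvm.lower(s, [A, B], simple_mode=True))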
@@ -9,6 +9,21 @@ from topi.util import get_const_tuple
 
 from common import get_all_backend
 
+def check_device(A, B, a_np, b_np, device, name):
+    ctx = tvm.context(device, 0)
+    if not ctx.exist:
+        print("Skip because %s is not enabled" % device)
+        return
+    print("Running on target: %s" % device)
+    with tvm.target.create(device):
+        s = topi.generic.schedule_softmax(B)
+
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+    f = tvm.build(s, [A, B], device, name=name)
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
 def verify_softmax(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), dtype=dtype, name='A')
     B = topi.nn.softmax(A)
@@ -19,28 +34,26 @@ def verify_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.softmax_python(a_np)
 
-    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            s = topi.generic.schedule_softmax(B)
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+        check_device(A, B, a_np, b_np, device, "softmax")
+
+def verify_softmax_4d(shape, dtype="float32"):
+    A = tvm.placeholder(shape, dtype=dtype, name='A')
+    B = topi.nn.softmax(A, axis=1)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        foo = tvm.build(s, [A, B], device, name="softmax")
-        foo(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    _, c, h, w = shape
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2)
+
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "softmax")
 
 def test_softmax():
     verify_softmax(32, 10)
     verify_softmax(3, 4)
     verify_softmax(32, 10, "float64")
+    verify_softmax_4d((1, 16, 256, 256))
 
 def verify_log_softmax(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), dtype=dtype, name='A')
@@ -51,22 +64,8 @@ def verify_log_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)
 
-    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            s = topi.generic.schedule_softmax(B)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        foo = tvm.build(s, [A, B], device, name="log_softmax")
-        foo(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
     for device in get_all_backend():
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "log_softmax")
 
 def test_log_softmax():
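
The 4-D reference values in verify_softmax_4d come from flattening the spatial positions into rows and reusing the 2-D numpy reference. A minimal numpy sketch of that transpose/reshape round trip (the helper name softmax_nchw_reference is hypothetical, not part of this change):

import numpy as np

def softmax_nchw_reference(a_np):
    # a_np has layout NCHW with N == 1; softmax is taken over the channel axis,
    # mirroring the transpose/reshape trick used in verify_softmax_4d.
    _, c, h, w = a_np.shape
    x = a_np.transpose(0, 2, 3, 1).reshape(h * w, c)     # one row per spatial position
    e = np.exp(x - x.max(axis=1, keepdims=True))          # numerically stable exponent
    y = e / e.sum(axis=1, keepdims=True)                  # row-wise softmax
    return y.reshape(1, h, w, c).transpose(0, 3, 1, 2)    # back to NCHW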