Commit 3708b311 by masahi, committed by Tianqi Chen

Update cuda softmax schedule for spatial inputs (#2338)

parent f6c3f997
@@ -2,6 +2,7 @@
 """Schedule for softmax operator"""
 import tvm
 from .. import generic
+from .injective import _schedule_injective
 
 @generic.schedule_softmax.register(["cuda", "gpu"])
 def schedule_softmax(outs):
@@ -24,12 +25,15 @@ def schedule_softmax(outs):
     max_elem = softmax.op.input_tensors[1]
     expsum = softmax.op.input_tensors[2]
 
-    num_thread = 64
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    s[max_elem].bind(max_elem.op.axis[0], block_x)
-    k = expsum.op.reduce_axis[0]
-    ko, ki = s[expsum].split(k, factor=num_thread)
-    EF = s.rfactor(expsum, ki)
+    if len(softmax.shape) > 2:
+        for op in [max_elem.op, expsum.op, softmax.op]:
+            s = _schedule_injective(op, s)
+    else:
+        num_thread = 64
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        s[max_elem].bind(max_elem.op.axis[0], block_x)
+        k = expsum.op.reduce_axis[0]
+        ko, ki = s[expsum].split(k, factor=num_thread)
+        EF = s.rfactor(expsum, ki)
...
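With this change, softmax outputs with more than two dimensions are handed to the elementwise (injective) schedule, while the original blockIdx/threadIdx reduction schedule is kept for the 2D case. A minimal usage sketch of the new path (hypothetical, mirroring the 4D test added below):

import tvm
import topi

# 4D NCHW input: softmax over the channel axis previously assumed a 2D
# shape in the cuda schedule; it now takes the _schedule_injective branch.
A = tvm.placeholder((1, 16, 256, 256), name='A')
B = topi.nn.softmax(A, axis=1)
with tvm.target.create("cuda"):
    s = topi.generic.schedule_softmax(B)  # len(B.shape) > 2 -> injective path
f = tvm.build(s, [A, B], "cuda", name="softmax")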
@@ -9,17 +9,7 @@
 from topi.util import get_const_tuple
 from common import get_all_backend
 
-def verify_softmax(m, n, dtype="float32"):
-    A = tvm.placeholder((m, n), dtype=dtype, name='A')
-    B = topi.nn.softmax(A)
-    # confirm lower works
-    s = tvm.create_schedule([B.op])
-    tvm.lower(s, [A, B], simple_mode=True)
-
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    b_np = topi.testing.softmax_python(a_np)
-
-    def check_device(device):
+def check_device(A, B, a_np, b_np, device, name):
     ctx = tvm.context(device, 0)
     if not ctx.exist:
         print("Skip because %s is not enabled" % device)
@@ -30,17 +20,40 @@
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    foo = tvm.build(s, [A, B], device, name="softmax")
-    foo(a, b)
+    f = tvm.build(s, [A, B], device, name=name)
+    f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+def verify_softmax(m, n, dtype="float32"):
+    A = tvm.placeholder((m, n), dtype=dtype, name='A')
+    B = topi.nn.softmax(A)
+    # confirm lower works
+    s = tvm.create_schedule([B.op])
+    tvm.lower(s, [A, B], simple_mode=True)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np)
+
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "softmax")
+
+def verify_softmax_4d(shape, dtype="float32"):
+    A = tvm.placeholder(shape, dtype=dtype, name='A')
+    B = topi.nn.softmax(A, axis=1)
+
+    _, c, h, w = shape
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2)
+
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+        check_device(A, B, a_np, b_np, device, "softmax")
 
 def test_softmax():
     verify_softmax(32, 10)
     verify_softmax(3, 4)
     verify_softmax(32, 10, "float64")
+    verify_softmax_4d((1, 16, 256, 256))
 
 def verify_log_softmax(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), dtype=dtype, name='A')
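The reference helper topi.testing.softmax_python apparently operates on the rows of a 2D array, which is why verify_softmax_4d above flattens the NCHW input to (h*w, c) rows before calling it and then undoes the layout change. A numpy-only sanity check of that round trip (a hedged sketch, not part of the commit; assumes batch size 1 as in the test):

import numpy as np

n, c, h, w = 1, 16, 8, 8
a = np.random.uniform(size=(n, c, h, w)).astype("float32")

# Direct softmax over the channel axis of the NCHW tensor.
e = np.exp(a - a.max(axis=1, keepdims=True))
direct = e / e.sum(axis=1, keepdims=True)

# The test's route: (n, c, h, w) -> (h*w, c) rows, softmax per row, then undo.
rows = a.transpose(0, 2, 3, 1).reshape(h * w, c)
er = np.exp(rows - rows.max(axis=1, keepdims=True))
rows_sm = er / er.sum(axis=1, keepdims=True)
roundtrip = rows_sm.reshape(n, h, w, c).transpose(0, 3, 1, 2)

np.testing.assert_allclose(direct, roundtrip, rtol=1e-5)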
@@ -51,22 +64,8 @@ def verify_log_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)
-    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            s = topi.generic.schedule_softmax(B)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        foo = tvm.build(s, [A, B], device, name="log_softmax")
-        foo(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
     for device in get_all_backend():
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "log_softmax")
 
 def test_log_softmax():
...