Commit 3708b311 by masahi, committed by Tianqi Chen

Update cuda softmax schedule for spatial inputs (#2338)

parent f6c3f997
topi/python/topi/cuda/softmax.py

@@ -2,6 +2,7 @@
 """Schedule for softmax operator"""
 import tvm
 from .. import generic
+from .injective import _schedule_injective

 @generic.schedule_softmax.register(["cuda", "gpu"])
 def schedule_softmax(outs):
@@ -24,21 +25,24 @@ def schedule_softmax(outs):
     max_elem = softmax.op.input_tensors[1]
     expsum = softmax.op.input_tensors[2]

-    num_thread = 64
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-
-    s[max_elem].bind(max_elem.op.axis[0], block_x)
-    k = expsum.op.reduce_axis[0]
-    ko, ki = s[expsum].split(k, factor=num_thread)
-    EF = s.rfactor(expsum, ki)
-    s[expsum].bind(s[expsum].op.axis[0], block_x)
-    s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
-    s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
-    s[expsum].set_store_predicate(thread_x.var.equal(0))
-    tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
-    s[softmax].bind(softmax.op.axis[0], block_x)
-    s[softmax].bind(tx, thread_x)
+    if len(softmax.shape) > 2:
+        for op in [max_elem.op, expsum.op, softmax.op]:
+            s = _schedule_injective(op, s)
+    else:
+        num_thread = 64
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+
+        s[max_elem].bind(max_elem.op.axis[0], block_x)
+        k = expsum.op.reduce_axis[0]
+        ko, ki = s[expsum].split(k, factor=num_thread)
+        EF = s.rfactor(expsum, ki)
+        s[expsum].bind(s[expsum].op.axis[0], block_x)
+        s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
+        s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
+        s[expsum].set_store_predicate(thread_x.var.equal(0))
+        tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+        s[softmax].bind(softmax.op.axis[0], block_x)
+        s[softmax].bind(tx, thread_x)

     return s
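A note on the change above: inputs with rank > 2 now bypass the 2-D rfactor reduction schedule (which binds softmax.op.axis[1] to threads and therefore assumes a matrix-shaped input) and are instead scheduled element-wise via _schedule_injective. A minimal sketch of exercising the new path end to end, mirroring the test code below and assuming a CUDA-enabled TVM build of this era:

    import numpy as np
    import tvm
    import topi

    # Softmax over the channel axis of a 4-D NCHW tensor: the "spatial"
    # case this commit targets.
    A = tvm.placeholder((1, 16, 256, 256), dtype="float32", name="A")
    B = topi.nn.softmax(A, axis=1)

    # len(softmax.shape) > 2, so schedule_softmax takes the injective branch.
    with tvm.target.create("cuda"):
        s = topi.generic.schedule_softmax(B)

    ctx = tvm.context("cuda", 0)
    a_np = np.random.uniform(size=(1, 16, 256, 256)).astype("float32")
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros((1, 16, 256, 256), dtype="float32"), ctx)
    f = tvm.build(s, [A, B], "cuda", name="softmax")
    f(a, b)  # b now holds softmax(a) along axis 1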
topi/tests/python/test_topi_softmax.py

@@ -9,6 +9,21 @@
 from topi.util import get_const_tuple

 from common import get_all_backend

+def check_device(A, B, a_np, b_np, device, name):
+    ctx = tvm.context(device, 0)
+    if not ctx.exist:
+        print("Skip because %s is not enabled" % device)
+        return
+    print("Running on target: %s" % device)
+    with tvm.target.create(device):
+        s = topi.generic.schedule_softmax(B)
+
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+    f = tvm.build(s, [A, B], device, name=name)
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
 def verify_softmax(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), dtype=dtype, name='A')
     B = topi.nn.softmax(A)
@@ -19,28 +34,26 @@ def verify_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.softmax_python(a_np)

-    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            s = topi.generic.schedule_softmax(B)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        foo = tvm.build(s, [A, B], device, name="softmax")
-        foo(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "softmax")
+
+def verify_softmax_4d(shape, dtype="float32"):
+    A = tvm.placeholder(shape, dtype=dtype, name='A')
+    B = topi.nn.softmax(A, axis=1)
+
+    _, c, h, w = shape
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2)
+
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+        check_device(A, B, a_np, b_np, device, "softmax")

 def test_softmax():
     verify_softmax(32, 10)
     verify_softmax(3, 4)
     verify_softmax(32, 10, "float64")
+    verify_softmax_4d((1, 16, 256, 256))

 def verify_log_softmax(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), dtype=dtype, name='A')
@@ -51,22 +64,8 @@ def verify_log_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)

-    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            s = topi.generic.schedule_softmax(B)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        foo = tvm.build(s, [A, B], device, name="log_softmax")
-        foo(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
     for device in get_all_backend():
-        check_device(device)
+        check_device(A, B, a_np, b_np, device, "log_softmax")

 def test_log_softmax():
...
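A side note on the reference computation in verify_softmax_4d: it reuses the 2-D helper topi.testing.softmax_python by flattening the NCHW input into an (h*w, c) matrix, which is valid because softmax over axis 1 is computed independently at each spatial location. A small self-contained NumPy check of that identity (the softmax_rows helper is illustrative, not part of the commit):

    import numpy as np

    def softmax_rows(x):
        # numerically stable row-wise softmax, matching topi.testing.softmax_python
        e = np.exp(x - x.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    n, c, h, w = 1, 16, 8, 8
    a = np.random.uniform(size=(n, c, h, w)).astype("float32")

    # Direct softmax along the channel axis of the NCHW tensor...
    e = np.exp(a - a.max(axis=1, keepdims=True))
    direct = e / e.sum(axis=1, keepdims=True)

    # ...equals the transpose/reshape route used by verify_softmax_4d.
    ref = softmax_rows(a.transpose(0, 2, 3, 1).reshape(n * h * w, c))
    ref = ref.reshape(n, h, w, c).transpose(0, 3, 1, 2)

    np.testing.assert_allclose(direct, ref, rtol=1e-5)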