Commit a961b29c by Tianqi Chen Committed by GitHub

[TOPI] Fix softmax bug (#437)

parent 0c9adc5b
......@@ -21,6 +21,7 @@ def softmax(x):
m, n = x.shape
k = tvm.reduce_axis((0, n), name='k')
max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
k = tvm.reduce_axis((0, n), name='k')
expsum = tvm.compute(
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(
......
......@@ -6,9 +6,12 @@ import topi
from topi.util import get_const_tuple
def verify_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
......@@ -26,10 +29,11 @@ def verify_softmax(m, n):
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
check_device(device)
def test_softmax():
verify_softmax(32, 10)
verify_softmax(3, 4)
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment