Commit 65f87264 by Leyuan Wang, committed by Tianqi Chen

log_softmax added to topi (#483)

parent 489ec872
# pylint: disable=invalid-name
"""TVM operator for softmax and log_softmax compute."""
from __future__ import absolute_import
import tvm
...@@ -26,3 +26,28 @@ def softmax(x):
        (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
    return tvm.compute(
        x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
@tvm.tag_scope(tag='log_softmax_output')
def log_softmax(x):
    """Perform log softmax activation on the data

    Parameters
    ----------
    x : tvm.Tensor
        2-D input data

    Returns
    -------
    output : tvm.Tensor
        2-D output with same shape
    """
    assert len(x.shape) == 2, "only support 2-dim log softmax"
    m, n = x.shape
    k = tvm.reduce_axis((0, n), name='k')
    max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
    k = tvm.reduce_axis((0, n), name='k')
    expsum = tvm.compute(
        (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
    return tvm.compute(
        x.shape, lambda i, j: x[i, j] - max_elem[i] - tvm.log(expsum[i]))
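For context (not part of this commit), a minimal sketch of how the new compute could be exercised end to end on the CPU with a default schedule; the shape, target, and tolerance below are illustrative assumptions:

    import numpy as np
    import tvm
    import topi

    # Illustration only: build log_softmax for the CPU with a default schedule.
    A = tvm.placeholder((4, 10), name='A')
    B = topi.nn.log_softmax(A)
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm", name="log_softmax")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(4, 10)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros((4, 10), dtype=B.dtype), ctx)
    f(a, b)
    # exp(log_softmax) is softmax, so every row should sum to one.
    np.testing.assert_allclose(np.exp(b.asnumpy()).sum(axis=1), np.ones(4), rtol=1e-5)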
...@@ -8,4 +8,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
# pylint: disable=invalid-name, trailing-whitespace
"""Softmax and log_softmax operation in python"""
import numpy as np
def softmax_python(a_np):
...@@ -21,3 +21,23 @@ def softmax_python(a_np):
    expsum = np.sum(e, axis=1)
    out_np = e / expsum[:, None]
    return out_np
def log_softmax_python(a_np):
    """Log_softmax operator.

    Parameters
    ----------
    a_np : numpy.ndarray
        2-D input data

    Returns
    -------
    output_np : numpy.ndarray
        2-D output with same shape
    """
    assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
    max_elem = np.amax(a_np, axis=1)
    max_elem = max_elem.reshape(max_elem.shape[0], 1)
    e = np.exp(a_np - max_elem)
    expsum = np.sum(e, axis=1)
    out_np = a_np - max_elem - np.log(expsum[:, None])
    return out_np
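As a quick illustration (again not part of the commit, and the shape is an arbitrary assumption), exponentiating the reference output should recover softmax rows that sum to one:

    import numpy as np

    a = np.random.uniform(size=(3, 5)).astype("float32")
    out = log_softmax_python(a)
    # exp(log_softmax) is softmax, so each row should sum to 1.
    np.testing.assert_allclose(np.exp(out).sum(axis=1), np.ones(3), rtol=1e-5)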
...@@ -36,5 +36,36 @@ def test_softmax():
    verify_softmax(3, 4)
def verify_log_softmax(m, n):
    A = tvm.placeholder((m, n), name='A')
    B = topi.nn.log_softmax(A)
    # confirm lower works
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)
    s = topi.cuda.schedule_softmax(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.log_softmax_python(a_np)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        foo = tvm.build(s, [A, B], device, name="log_softmax")
        foo(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in ['cuda', 'opencl', 'metal']:
        check_device(device)


def test_log_softmax():
    verify_log_softmax(32, 10)
    verify_log_softmax(3, 4)


if __name__ == "__main__":
    test_softmax()
    test_log_softmax()