Commit b3bb8126 by Leyuan Wang, committed by Tianqi Chen

Softmax operator migrated to topi (#366)

* softmax migrated and test added

* pylint error fixed

* pylint error fixed
parent 2a87020c
@@ -7,3 +7,4 @@ from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
from .softmax import schedule_softmax
@@ -115,7 +115,7 @@ def schedule_conv2d_small_batch(outs):
    return s


def schedule_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw and any element-wise operations.
+    """Schedule for conv2d_nchw.

    Parameters
    ----------
......
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
def schedule_softmax(outs):
"""Schedule for softmax op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
s[max_elem].bind(max_elem.op.axis[0], block_x)
k = expsum.op.reduce_axis[0]
ko, ki = s[expsum].split(k, factor=num_thread)
EF = s.rfactor(expsum, ki)
s[expsum].bind(s[expsum].op.axis[0], block_x)
s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
s[softmax].bind(softmax.op.axis[0], block_x)
s[softmax].bind(tx, thread_x)
return s
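As a side note (not part of this diff), the schedule returned above can be inspected without a GPU by lowering it. A minimal sketch using the TVM API of this era, with an arbitrary placeholder shape:

# Hypothetical usage sketch: lower the softmax schedule and print the loop nest.
import tvm
import topi

A = tvm.placeholder((32, 10), name='A')
B = topi.nn.softmax(A)
s = topi.cuda.schedule_softmax(B)
# simple_mode=True prints a compact statement instead of the full lowered module.
print(tvm.lower(s, [A, B], simple_mode=True))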
# pylint: disable=invalid-name
"""TVM operator softmax compute."""
from __future__ import absolute_import
import tvm
-@tvm.tag_scope(tag='softmax')
+@tvm.tag_scope(tag='softmax_output')
def softmax(x):
    """Perform softmax activation on the data
......
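The body of the compute definition is elided above. For orientation only, a rough sketch consistent with the CUDA schedule earlier in this diff (whose output op is expected to take the input, a row-wise max tensor, and an exp-sum tensor as inputs) could look like the following; names and details are illustrative, not the committed code.

# Illustrative sketch only -- not the code from this commit.
import tvm

def softmax_sketch(x):
    """Numerically stable 2-D softmax built from three compute stages."""
    assert len(x.shape) == 2, "only support 2-dim softmax"
    m, n = x.shape
    k = tvm.reduce_axis((0, n), name='k')
    # Row-wise maximum, subtracted before exponentiation for stability.
    max_elem = tvm.compute((m,), lambda i: tvm.max(x[i, k], axis=k))
    k = tvm.reduce_axis((0, n), name='k')
    # Row-wise sum of exponentials.
    expsum = tvm.compute(
        (m,), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
    # Final stage; its op.input_tensors are (x, max_elem, expsum),
    # which is what schedule_softmax above indexes into.
    return tvm.compute(
        x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])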
@@ -8,3 +8,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python
# pylint: disable=invalid-name, trailing-whitespace
"""Softmax operation in python"""
import numpy as np
def softmax_python(a_np):
"""Softmax operator.
Parameters
----------
a_np : numpy.ndarray
2-D input data
Returns
-------
output_np : numpy.ndarray
2-D output with same shape
"""
assert len(a_np.shape) == 2, "only support 2-dim softmax"
max_elem = np.amax(a_np, axis=1)
max_elem = max_elem.reshape(max_elem.shape[0], 1)
e = np.exp(a_np-max_elem)
expsum = np.sum(e, axis=1)
out_np = e / expsum[:, None]
return out_np
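A quick sanity check of the reference implementation (illustrative usage, not part of the commit): every row of the output should sum to one.

# Hypothetical usage of softmax_python defined above.
import numpy as np

x = np.random.randn(4, 10).astype("float32")
y = softmax_python(x)
assert y.shape == x.shape
assert np.allclose(y.sum(axis=1), 1.0, atol=1e-5)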
"""Test code for softmax"""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_softmax(m, n):
    A = tvm.placeholder((m, n), name='A')
    B = topi.nn.softmax(A)
    s = topi.cuda.schedule_softmax(B)
    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.softmax_python(a_np)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        # Pick the device context matching the target backend.
        if device == "cuda":
            ctx = tvm.gpu(0)
        elif device == "metal":
            ctx = tvm.metal(0)
        else:
            ctx = tvm.cl(0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        foo = tvm.build(s, [A, B], device, name="softmax")
        foo(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in ['cuda', 'opencl', 'metal']:
        check_device(device)


def test_softmax():
    verify_softmax(32, 10)


if __name__ == "__main__":
    test_softmax()