Commit 5a96c9df by Pariksheet Pinjari Committed by Tianqi Chen

[TOPI] PReLU Support (#1008)

parent 36ea5392
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <string> #include <string>
#include "topi/tags.h" #include "topi/tags.h"
#include "topi/detail/constant_utils.h"
#include "tvm/ir.h" #include "tvm/ir.h"
#include "tvm/ir_pass.h" #include "tvm/ir_pass.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
...@@ -84,6 +85,40 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t, ...@@ -84,6 +85,40 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
} }
/*! /*!
* \brief Creates an operation that performs a parametric rectified linear unit
*
* \param x The input data tensor
* \param slope The channel-wise slope tensor
* \param axis The axis where the channel data needs to be applied
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
inline tvm::Tensor prelu(const tvm::Tensor &x,
const tvm::Tensor &slope,
const int axis = 1,
std::string name = "tensor",
std::string tag = kBroadcast) {
CHECK_EQ(4, x->shape.size());
CHECK((size_t)axis < x->shape.size()) <<
"Wrong axis (" << axis << ")value. ";
CHECK(topi::detail::GetConstInt(slope->shape[0]) ==
topi::detail::GetConstInt(x->shape[axis]))
<< "Wrong slope shape received.";
return tvm::compute(x->shape,
[&](const tvm::Array<tvm::Var> &indices) {
return tvm::select(x(indices) > 0,
x(indices),
x(indices) * slope(indices[axis]));
},
name,
tag);
}
/*!
* \brief Creates an operation that performs padding * \brief Creates an operation that performs padding
* *
* \param t The input tensor * \param t The input tensor
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag from .. import tag
from ..util import get_const_int
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def relu(x): def relu(x):
...@@ -42,3 +43,36 @@ def leaky_relu(x, alpha): ...@@ -42,3 +43,36 @@ def leaky_relu(x, alpha):
calpha = tvm.const(alpha, value.dtype) calpha = tvm.const(alpha, value.dtype)
return tvm.select(value > 0, value, value * calpha) return tvm.select(value > 0, value, value * calpha)
return tvm.compute(x.shape, _compute) return tvm.compute(x.shape, _compute)
@tvm.tag_scope(tag=tag.BROADCAST)
def prelu(x, slope, axis=1):
""" PReLU.
It accepts two arguments: an input ``x`` and a weight array ``W``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : W * x`,
where :math:`*` is an elementwise multiplication for each sample in the
batch.
Arguments:
x : tvm.Tensor
Input argument.
slope : tvm.Tensor
Channelised slope tensor for prelu
axis : int
The axis where the channel data needs to be applied
Returns:
y : tvm.Tensor
The result.
Links:
[http://arxiv.org/pdf/1502.01852v1.pdf]
"""
assert len(x.shape) == 4 and len(slope.shape) == 1
assert axis < len(x.shape)
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
def _compute_channelwise(*indices):
return tvm.select(x(*indices) > 0, x(*indices), x(*indices) * slope(indices[axis]))
return tvm.compute(x.shape, _compute_channelwise)
...@@ -190,6 +190,11 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") ...@@ -190,6 +190,11 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
*rv = leaky_relu<float>(args[0]); *rv = leaky_relu<float>(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.nn.prelu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = prelu<float>(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.pad") TVM_REGISTER_GLOBAL("topi.nn.pad")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = pad(args[0], args[1], args[2], args[3]); *rv = pad(args[0], args[1], args[2], args[3]);
......
...@@ -46,13 +46,38 @@ def verify_leaky_relu(m, alpha): ...@@ -46,13 +46,38 @@ def verify_leaky_relu(m, alpha):
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_prelu(x, w):
X = tvm.placeholder((x), name='X')
W = tvm.placeholder((w), name='W')
x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
def _prelu_numpy(x, W):
return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
B = topi.nn.prelu(X, W)
s = tvm.create_schedule([B.op])
ctx = tvm.cpu(0)
x_tvm = tvm.nd.array(x_np, ctx)
w_tvm = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(X.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [X, W, B], "llvm", name="prelu")
foo(x_tvm, w_tvm, b)
out_np = _prelu_numpy(x_np, w_np)
np.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5)
def test_relu(): def test_relu():
verify_relu(10, 128) verify_relu(10, 128)
def test_leaky_relu(): def test_leaky_relu():
verify_leaky_relu(100, 0.1) verify_leaky_relu(100, 0.1)
def test_prelu():
verify_prelu((1, 3, 2, 2), (3,))
if __name__ == "__main__": if __name__ == "__main__":
test_relu() test_relu()
test_leaky_relu() test_leaky_relu()
test_prelu()
...@@ -50,6 +50,28 @@ def verify_leaky_relu(m, alpha): ...@@ -50,6 +50,28 @@ def verify_leaky_relu(m, alpha):
foo(a, b) foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_prelu(x, w):
X = tvm.placeholder((x), name='X')
W = tvm.placeholder((w), name='W')
x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
def _prelu_numpy(x, W):
return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
out_np = _prelu_numpy(x_np, w_np)
B = topi.cpp.nn.prelu(X, W)
device = "llvm"
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.schedule_injective(target, [B])
ctx = tvm.cpu(0)
x_tvm = tvm.nd.array(x_np, ctx)
w_tvm = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(X.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [X, W, B], "llvm", name="prelu")
foo(x_tvm, w_tvm, b)
np.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5)
def test_relu(): def test_relu():
for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']: for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
...@@ -58,7 +80,10 @@ def test_relu(): ...@@ -58,7 +80,10 @@ def test_relu():
def test_leaky_relu(): def test_leaky_relu():
verify_leaky_relu(100, 0.1) verify_leaky_relu(100, 0.1)
def test_prelu():
verify_prelu((1, 3, 2, 2), (3,))
if __name__ == "__main__": if __name__ == "__main__":
test_relu() test_relu()
test_leaky_relu() test_leaky_relu()
test_prelu()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment