Commit 32a55f88 by ANSHUMAN TRIPATHY Committed by Tianqi Chen

Prelu bug fix (#1358)

parent ab0d1862
...@@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the ...@@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed); const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)}; return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)};
}) })
.set_support_level(4); .set_support_level(4);
......
...@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t, ...@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
* *
* \return A Tensor whose op member is the relu operation * \return A Tensor whose op member is the relu operation
*/ */
template <typename T>
inline tvm::Tensor prelu(const tvm::Tensor &x, inline tvm::Tensor prelu(const tvm::Tensor &x,
const tvm::Tensor &slope, const tvm::Tensor &slope,
const int axis = 1, const int axis = 1,
......
...@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") ...@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
TVM_REGISTER_GLOBAL("topi.nn.prelu") TVM_REGISTER_GLOBAL("topi.nn.prelu")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = prelu<float>(args[0], args[1]); *rv = prelu(args[0], args[1], args[2]);
}); });
TVM_REGISTER_GLOBAL("topi.nn.pad") TVM_REGISTER_GLOBAL("topi.nn.pad")
......
...@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha): ...@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha):
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_prelu(x, w): def verify_prelu(x, w, axis, weight_reshape):
X = tvm.placeholder((x), name='X') X = tvm.placeholder((x), name='X')
W = tvm.placeholder((w), name='W') W = tvm.placeholder((w), name='W')
x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype) x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype) w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
def _prelu_numpy(x, W): def _prelu_numpy(x, W):
return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
B = topi.nn.prelu(X, W) B = topi.nn.prelu(X, W, axis)
s = tvm.create_schedule([B.op]) s = tvm.create_schedule([B.op])
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
...@@ -79,7 +79,8 @@ def test_leaky_relu(): ...@@ -79,7 +79,8 @@ def test_leaky_relu():
verify_leaky_relu(100, 0.1) verify_leaky_relu(100, 0.1)
def test_prelu(): def test_prelu():
verify_prelu((1, 3, 2, 2), (3,)) verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_big_array() test_schedule_big_array()
......
...@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha): ...@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha):
foo(a, b) foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_prelu(x, w): def verify_prelu(x, w, axis, weight_reshape):
X = tvm.placeholder((x), name='X') X = tvm.placeholder((x), name='X')
W = tvm.placeholder((w), name='W') W = tvm.placeholder((w), name='W')
x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype) x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype) w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
def _prelu_numpy(x, W): def _prelu_numpy(x, W):
return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
out_np = _prelu_numpy(x_np, w_np) out_np = _prelu_numpy(x_np, w_np)
B = topi.cpp.nn.prelu(X, W) B = topi.cpp.nn.prelu(X, W, axis)
device = "llvm" device = "llvm"
target = topi.cpp.TEST_create_target(device) target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.schedule_injective(target, [B]) s = topi.cpp.generic.schedule_injective(target, [B])
...@@ -81,7 +81,8 @@ def test_leaky_relu(): ...@@ -81,7 +81,8 @@ def test_leaky_relu():
verify_leaky_relu(100, 0.5) verify_leaky_relu(100, 0.5)
def test_prelu(): def test_prelu():
verify_prelu((1, 3, 2, 2), (3,)) verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
if __name__ == "__main__": if __name__ == "__main__":
test_relu() test_relu()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment