Commit 32a55f88 by ANSHUMAN TRIPATHY, committed by Tianqi Chen

Prelu bug fix (#1358)

parent ab0d1862
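
This commit addresses two related problems in the PReLU path: `topi::prelu` was templated and instantiated as `prelu<float>`, pinning the element type at the call sites, and the `topi.nn.prelu` global registration dropped the `axis` argument entirely, so a caller-supplied axis never reached the C++ implementation. The fix removes the template parameter, forwards `axis` as `args[2]`, and extends the tests to cover a non-default axis.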
@@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
-    return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
+    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)};
   })
 .set_support_level(4);
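
For reference, PReLU keeps `x` where `x >= 0` and scales it by a learned `slope` elsewhere, with `slope` broadcast along `axis` (the `param.axis` threaded through above). A minimal NumPy sketch of that semantics; `prelu_ref` and its reshape logic are illustrative, not part of this commit:

```python
import numpy as np

def prelu_ref(x, slope, axis=1):
    # Reshape slope so it broadcasts along `axis`,
    # e.g. (3,) -> (1, 3, 1, 1) for a 4-D input with axis=1.
    shape = [1] * x.ndim
    shape[axis] = slope.shape[0]
    return np.where(x >= 0, x, x * slope.reshape(shape))
```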
@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
  *
  * \return A Tensor whose op member is the relu operation
  */
-template <typename T>
 inline tvm::Tensor prelu(const tvm::Tensor &x,
                          const tvm::Tensor &slope,
                          const int axis = 1,
@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
 TVM_REGISTER_GLOBAL("topi.nn.prelu")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = prelu<float>(args[0], args[1]);
+  *rv = prelu(args[0], args[1], args[2]);
   });
 TVM_REGISTER_GLOBAL("topi.nn.pad")
@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha):
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
 
     def _prelu_numpy(x, W):
-        return (x < 0) * (x * W.reshape(3, 1, 1)) + (x >= 0) * x
+        return (x < 0) * (x * W.reshape(weight_reshape)) + (x >= 0) * x
 
-    B = topi.nn.prelu(X, W)
+    B = topi.nn.prelu(X, W, axis)
     s = tvm.create_schedule([B.op])
     ctx = tvm.cpu(0)
@@ -79,7 +79,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_schedule_big_array()
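
The second test case is the new coverage: with input shape (1, 3, 2, 2) and axis=2, a slope of shape (2,) must be reshaped to (2, 1) to line up with dimension 2 under NumPy broadcasting. A quick check of that shape arithmetic:

```python
import numpy as np

x = np.zeros((1, 3, 2, 2))
w = np.zeros((2,))
# (2,) -> (2, 1) aligns w with axis 2 of x; the product keeps x's shape.
assert (x * w.reshape(2, 1)).shape == (1, 3, 2, 2)
```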
@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha):
     foo(a, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
 
     def _prelu_numpy(x, W):
-        return (x < 0) * (x * W.reshape(3, 1, 1)) + (x >= 0) * x
+        return (x < 0) * (x * W.reshape(weight_reshape)) + (x >= 0) * x
     out_np = _prelu_numpy(x_np, w_np)
 
-    B = topi.cpp.nn.prelu(X, W)
+    B = topi.cpp.nn.prelu(X, W, axis)
     device = "llvm"
     target = topi.cpp.TEST_create_target(device)
     s = topi.cpp.generic.schedule_injective(target, [B])
@@ -81,7 +81,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.5)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_relu()