Commit 9c299a90 by Pariksheet Pinjari Committed by Tianqi Chen

leaky_relu bug fix (#1218)

parent 429e5fb8
......@@ -466,7 +466,7 @@ NNVM_REGISTER_OP(leaky_relu)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
return Array<Tensor>{ topi::leaky_relu(inputs[0], param.alpha) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
......
......@@ -60,17 +60,14 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
* \brief Creates an operation that performs a leaky rectified linear unit
*
* \param t The input tensor
* \param threshold The relu threshold (default 0)
* \param alpha The slope for the small gradient when t < threshold
* \param alpha The slope for the small gradient when t < 0
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
T threshold = static_cast<T>(0),
T alpha = static_cast<T>(0.1),
double alpha = 0.1,
std::string name = "tensor",
std::string tag = kElementWise) {
return tvm::compute(
......
......@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.relu")
TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = leaky_relu<float>(args[0]);
*rv = leaky_relu(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.prelu")
......
......@@ -41,7 +41,7 @@ def verify_leaky_relu(m, alpha):
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.schedule_injective(target, [B])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0) + a_np * (a_np < 0) * alpha
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
......@@ -78,7 +78,7 @@ def test_relu():
verify_relu(10, 128, dtype)
def test_leaky_relu():
verify_leaky_relu(100, 0.1)
verify_leaky_relu(100, 0.5)
def test_prelu():
verify_prelu((1, 3, 2, 2), (3,))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment