Commit 0ca53640 by alex-weaver Committed by Tianqi Chen

Fix bugs with C++ TOPI flatten and relu (#869)

* Fix bugs with C++ TOPI flatten and relu

* Added regression tests. Fixed typo in CMakeLists.txt. Fixed topi cpp import removed.
parent 0ba579f5
......@@ -193,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME)
if(USE_LLVM)
find_spackage(LLVM CONFIG REQUIRED)
find_package(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
......
......@@ -47,7 +47,10 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
std::string tag = kElementWise) {
return tvm::compute(
t->shape,
[&](const tvm::Array<tvm::Var>& i) { return tvm::max(t(i), threshold); },
[&](const tvm::Array<tvm::Var>& i) {
auto threshold_const = tvm::make_const(t->dtype, threshold);
return tvm::max(t(i), threshold_const);
},
name,
tag);
}
......
......@@ -55,7 +55,7 @@ inline Tensor flatten(const Tensor& x,
index.push_back(i);
std::reverse(index.begin(), index.end());
return x(index);
});
}, name, tag);
}
} // namespace nn
......
......@@ -24,3 +24,4 @@ from . import mali
from . import testing
from . import util
from . import rocm
from . import cpp
......@@ -25,7 +25,12 @@ def test_ewise():
test_apply(topi.cpp.log, "log")
test_apply(topi.cpp.sqrt, "sqrt")
def test_flatten_tag():
A = tvm.placeholder((3, 4), name='A')
B = topi.cpp.nn.flatten(A)
assert B.op.tag == topi.tag.INJECTIVE
if __name__ == "__main__":
test_util()
test_ewise()
test_flatten_tag()
......@@ -5,9 +5,10 @@ import tvm
import topi
from topi.util import get_const_tuple
def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
def verify_relu(m, n, dtype):
A = tvm.placeholder((m, n), name='A', dtype=dtype)
B = topi.cpp.nn.relu(A)
assert B.dtype == dtype
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0)
......@@ -51,7 +52,8 @@ def verify_leaky_relu(m, alpha):
def test_relu():
verify_relu(10, 128)
for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
verify_relu(10, 128, dtype)
def test_leaky_relu():
verify_leaky_relu(100, 0.1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment