Commit 0ca53640 by alex-weaver Committed by Tianqi Chen

Fix bugs with C++ TOPI flatten and relu (#869)

* Fix bugs with C++ TOPI flatten and relu

* Added regression tests. Fixed typo in CMakeLists.txt. Fixed topi cpp import removed.
parent 0ba579f5
...@@ -193,7 +193,7 @@ if(USE_GRAPH_RUNTIME) ...@@ -193,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME) endif(USE_GRAPH_RUNTIME)
if(USE_LLVM) if(USE_LLVM)
find_spackage(LLVM CONFIG REQUIRED) find_package(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR}) set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
......
...@@ -47,7 +47,10 @@ inline tvm::Tensor relu(const tvm::Tensor& t, ...@@ -47,7 +47,10 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
std::string tag = kElementWise) { std::string tag = kElementWise) {
return tvm::compute( return tvm::compute(
t->shape, t->shape,
[&](const tvm::Array<tvm::Var>& i) { return tvm::max(t(i), threshold); }, [&](const tvm::Array<tvm::Var>& i) {
auto threshold_const = tvm::make_const(t->dtype, threshold);
return tvm::max(t(i), threshold_const);
},
name, name,
tag); tag);
} }
......
...@@ -55,7 +55,7 @@ inline Tensor flatten(const Tensor& x, ...@@ -55,7 +55,7 @@ inline Tensor flatten(const Tensor& x,
index.push_back(i); index.push_back(i);
std::reverse(index.begin(), index.end()); std::reverse(index.begin(), index.end());
return x(index); return x(index);
}); }, name, tag);
} }
} // namespace nn } // namespace nn
......
...@@ -24,3 +24,4 @@ from . import mali ...@@ -24,3 +24,4 @@ from . import mali
from . import testing from . import testing
from . import util from . import util
from . import rocm from . import rocm
from . import cpp
...@@ -25,7 +25,12 @@ def test_ewise(): ...@@ -25,7 +25,12 @@ def test_ewise():
test_apply(topi.cpp.log, "log") test_apply(topi.cpp.log, "log")
test_apply(topi.cpp.sqrt, "sqrt") test_apply(topi.cpp.sqrt, "sqrt")
def test_flatten_tag():
A = tvm.placeholder((3, 4), name='A')
B = topi.cpp.nn.flatten(A)
assert B.op.tag == topi.tag.INJECTIVE
if __name__ == "__main__": if __name__ == "__main__":
test_util() test_util()
test_ewise() test_ewise()
test_flatten_tag()
...@@ -5,9 +5,10 @@ import tvm ...@@ -5,9 +5,10 @@ import tvm
import topi import topi
from topi.util import get_const_tuple from topi.util import get_const_tuple
def verify_relu(m, n): def verify_relu(m, n, dtype):
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A', dtype=dtype)
B = topi.cpp.nn.relu(A) B = topi.cpp.nn.relu(A)
assert B.dtype == dtype
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0) b_np = a_np * (a_np > 0)
...@@ -51,7 +52,8 @@ def verify_leaky_relu(m, alpha): ...@@ -51,7 +52,8 @@ def verify_leaky_relu(m, alpha):
def test_relu(): def test_relu():
verify_relu(10, 128) for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
verify_relu(10, 128, dtype)
def test_leaky_relu(): def test_leaky_relu():
verify_leaky_relu(100, 0.1) verify_leaky_relu(100, 0.1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment