Commit 96b2c082 by Pariksheet Pinjari Committed by Tianqi Chen

Added equality check and upgraded concatenate op (#1172)

parent 05e806e0
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <string> #include <string>
#include "topi/detail/broadcast.h" #include "topi/detail/broadcast.h"
#include "topi/detail/constant_utils.h"
#include "topi/tags.h" #include "topi/tags.h"
namespace topi { namespace topi {
...@@ -34,7 +35,7 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t, ...@@ -34,7 +35,7 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
auto bh = detail::BroadcastShape(output_shape, t->shape); auto bh = detail::BroadcastShape(output_shape, t->shape);
CHECK_EQ(output_shape.size(), bh.common_shape.size()); CHECK_EQ(output_shape.size(), bh.common_shape.size());
for (size_t i = 0; i < output_shape.size(); ++i) { for (size_t i = 0; i < output_shape.size(); ++i) {
CHECK(tvm::ir::Equal(output_shape[i], bh.common_shape[i])); CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
} }
auto l = [&](tvm::Array<tvm::Var> ovars) { auto l = [&](tvm::Array<tvm::Var> ovars) {
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "tvm/ir_pass.h" #include "tvm/ir_pass.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
#include "topi/detail/constant_utils.h"
namespace topi { namespace topi {
namespace detail { namespace detail {
...@@ -32,15 +33,15 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1, ...@@ -32,15 +33,15 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
int i; int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) { for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
bh.all_vars.push_front(tvm::Var()); bh.all_vars.push_front(tvm::Var());
if (tvm::ir::Equal(shape1[s1_size - i], shape2[s2_size - i])) { if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]); bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]); bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]);
} else if (tvm::ir::Equal(one, shape1[s1_size - i])) { } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
CHECK(!tvm::ir::Equal(one, shape2[s2_size - i])); CHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
bh.common_shape.push_front(shape2[s2_size - i]); bh.common_shape.push_front(shape2[s2_size - i]);
bh.vars2.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]);
} else if (tvm::ir::Equal(one, shape2[s2_size - i])) { } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]); bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]); bh.vars1.push_front(bh.all_vars[0]);
} else { } else {
......
...@@ -65,6 +65,24 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& ...@@ -65,6 +65,24 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
return result; return result;
} }
/*!
* \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
* \note This is stronger equality check than tvm::ir::Equal
*
* \param lhs First expreesion
* \param rhs Second expreesion
*
* \return result True if both expressions are equal, else false
*/
inline bool EqualCheck(Expr lhs, Expr rhs) {
bool result = tvm::ir::Equal(lhs, rhs);
if (!result) {
Expr zero(0);
result = tvm::ir::Equal(tvm::ir::CanonicalSimplify(lhs-rhs), zero);
}
return result;
}
} // namespace detail } // namespace detail
} // namespace topi } // namespace topi
#endif // TOPI_DETAIL_CONSTANT_UTILS_H_ #endif // TOPI_DETAIL_CONSTANT_UTILS_H_
...@@ -186,13 +186,13 @@ inline tvm::Tensor pad(const tvm::Tensor& t, ...@@ -186,13 +186,13 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
indices.push_back(ovars[i]); indices.push_back(ovars[i]);
continue; continue;
} }
if (!tvm::ir::Equal(pad_before[i], 0)) { if (!topi::detail::EqualCheck(pad_before[i], 0)) {
sel.push_back(ovars[i] >= pad_before[i]); sel.push_back(ovars[i] >= pad_before[i]);
indices.push_back(ovars[i] - pad_before[i]); indices.push_back(ovars[i] - pad_before[i]);
} else { } else {
indices.push_back(ovars[i]); indices.push_back(ovars[i]);
} }
if (!tvm::ir::Equal(pad_after[i], 0)) { if (!topi::detail::EqualCheck(pad_after[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i])); sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
} }
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
#include "tvm/ir_pass.h"
namespace topi { namespace topi {
using namespace tvm; using namespace tvm;
...@@ -260,6 +261,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs, ...@@ -260,6 +261,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
for (size_t i = 1; i < axis_sizes.size(); ++i) { for (size_t i = 1; i < axis_sizes.size(); ++i) {
join_size += axis_sizes[i]; join_size += axis_sizes[i];
} }
join_size = tvm::ir::Simplify(join_size);
Array<Expr> out_shape; Array<Expr> out_shape;
for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]); out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
......
...@@ -226,6 +226,7 @@ def concatenate(a_tuple, axis=0): ...@@ -226,6 +226,7 @@ def concatenate(a_tuple, axis=0):
axis_sizes = [a_tuple[i].shape[axis] for i in range(len(a_tuple))] axis_sizes = [a_tuple[i].shape[axis] for i in range(len(a_tuple))]
out_shape = [a_tuple[0].shape[i] for i in range(0, axis)] + [sum(axis_sizes)]\ out_shape = [a_tuple[0].shape[i] for i in range(0, axis)] + [sum(axis_sizes)]\
+ [a_tuple[0].shape[i] for i in range(axis + 1, len(a_tuple[0].shape))] + [a_tuple[0].shape[i] for i in range(axis + 1, len(a_tuple[0].shape))]
out_shape[axis] = tvm.ir_pass.Simplify(out_shape[axis])
def _compute(*indices): def _compute(*indices):
ret = a_tuple[0](*indices) ret = a_tuple[0](*indices)
......
...@@ -206,6 +206,70 @@ def verify_take(src_shape, indices_src, axis=None): ...@@ -206,6 +206,70 @@ def verify_take(src_shape, indices_src, axis=None):
for device in ["llvm", "opencl"]: for device in ["llvm", "opencl"]:
check_device(device) check_device(device)
def verify_concatenate_split(shapes, axis, indices_or_sections):
tensor_l_concatenate = []
for i, shape in enumerate(shapes):
tensor_l_concatenate.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.cpp.concatenate(tensor_l_concatenate, axis)
tensor_l = topi.cpp.split(out_tensor, indices_or_sections, axis)
tensor_l = list(tensor_l)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, tensor_l)
else:
s = topi.cpp.cuda.schedule_injective(target, tensor_l)
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l_concatenate + tensor_l, device, name="concatenate_split")
data_npys = [np.random.normal(size=shape).astype(tensor_l_concatenate[0].dtype) for shape in shapes]
out_npy_conc = np.concatenate(data_npys, axis=axis)
out_npys_split = np.split(out_npy_conc, indices_or_sections, axis=axis)
data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
out_nds = [tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys_split]
foo(*(data_nds + out_nds))
for out_nd, out_npy in zip(out_nds, out_npys_split):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_concatenate_broadcast(shapes, axis, rhs_shape):
B = tvm.placeholder(shape=rhs_shape, name="B")
tensor_l = []
for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.cpp.concatenate(tensor_l, axis)
C = topi.cpp.broadcast_add(out_tensor, B)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [C])
else:
s = topi.cpp.cuda.schedule_injective(target, [C])
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [B, C], device, name="broadcast_binary_add")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
lhs_npy = np.concatenate(data_npys, axis=axis)
rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
out_npy = lhs_npy + rhs_npy
data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
rhs_nd = tvm.nd.array(rhs_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
for _ in range(1):
foo(*(data_nds + [rhs_nd] + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
for device in ["llvm", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def test_expand_dims(): def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
...@@ -258,6 +322,14 @@ def test_take(): ...@@ -258,6 +322,14 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2) verify_take((4,3,5,6), [[2,1,0,0]], -2)
def test_regression_1():
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
def test_regression_2():
verify_concatenate_broadcast([(5, 1, 3), (5, 1, 3)], 1, [2, 1])
verify_concatenate_broadcast([(5, 1, 2), (5, 1, 3)], 2, [1, 5])
if __name__ == "__main__": if __name__ == "__main__":
test_concatenate() test_concatenate()
test_tranpose() test_tranpose()
...@@ -266,3 +338,5 @@ if __name__ == "__main__": ...@@ -266,3 +338,5 @@ if __name__ == "__main__":
test_squeeze() test_squeeze()
test_split() test_split()
test_take() test_take()
test_regression_1()
test_regression_2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment