Commit b590c4f2 by Yutetsu TAKATSUKASA Committed by Tianqi Chen

Consistent result of DetectLinearEquation() when an empy vars is passed (#2860)

parent c162e7d6
...@@ -127,25 +127,21 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { ...@@ -127,25 +127,21 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
Expr base = e; Expr base = e;
Array<Expr> coeff; Array<Expr> coeff;
if (0 == vars.size()) { for (Var v : vars) {
coeff.push_back(make_const(Int(32), 1)); LinearEqEntry ret;
} else { if (!LinearEqDetector(v).Detect(base, &ret)) {
for (Var v : vars) { return Array<Expr>();
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
} }
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}
std::unordered_set<const Variable*> vset; std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) { for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get()); vset.insert(vars[i - 1].get());
// The previous coeff contains the variable // The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) { if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>(); return Array<Expr>();
}
} }
} }
coeff.push_back(base); coeff.push_back(base);
......
...@@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator { ...@@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator {
bool MatchCopyPattern(Stmt stmt, Stmt *out) { bool MatchCopyPattern(Stmt stmt, Stmt *out) {
using namespace arith; using namespace arith;
Stmt body = stmt; Stmt body = stmt;
bool is_single_point_copy = false;
// strip the loops // strip the loops
std::vector<const For*> loops; std::vector<const For*> loops;
...@@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator { ...@@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator {
const Cast* cast = store->value.as<Cast>(); const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>(); const Load* load = store->value.as<Load>();
if (0 == loops.size()) { if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(!has_cond); CHECK(!has_cond);
} }
// for now only support true condition matching // for now only support true condition matching
...@@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator { ...@@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator {
arith::DetectLinearEquation(load->index, loop_vars); arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false; if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape; Array<Expr> dst_shape;
auto loop_var_size = loop_vars.size(); const size_t loop_var_size = loop_vars.size();
if (is_single_point_copy) { if (loop_var_size == 0) {
loop_var_size = 1;
dst_shape.push_back(make_const(Int(32), 1)); dst_shape.push_back(make_const(Int(32), 1));
} else { } else {
for (const For* op : loops) { for (const For* op : loops) {
...@@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator { ...@@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator {
CHECK_EQ(load_strides.size(), loop_var_size + 1); CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
if (loop_var_size == 0) {
src_strides.push_back(make_const(Int(32), 1));
dst_strides.push_back(make_const(Int(32), 1));
}
Buffer dst = BufferNode::make( Buffer dst = BufferNode::make(
Var(store->buffer_var.node_), Var(store->buffer_var.node_),
store->value.type(), store->value.type(),
......
...@@ -20,6 +20,10 @@ def test_basic(): ...@@ -20,6 +20,10 @@ def test_basic():
m = tvm.arith.DetectLinearEquation(b * 7, [a]) m = tvm.arith.DetectLinearEquation(b * 7, [a])
assert m[0].value == 0 assert m[0].value == 0
m = tvm.arith.DetectLinearEquation(b * 7, [])
assert len(m) == 1
assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
def test_multivariate(): def test_multivariate():
v = [tvm.var("v%d" % i) for i in range(4)] v = [tvm.var("v%d" % i) for i in range(4)]
b = tvm.var("b") b = tvm.var("b")
...@@ -42,6 +46,10 @@ def test_multivariate(): ...@@ -42,6 +46,10 @@ def test_multivariate():
assert(m[0].value == 0) assert(m[0].value == 0)
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
assert(len(m) == 1)
assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_multivariate() test_multivariate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment