From 0c38b9161bcfad60c2022682b4ab4a7d105b3148 Mon Sep 17 00:00:00 2001
From: ANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Date: Thu, 26 Mar 2020 03:16:18 +0530
Subject: [PATCH] Duplicate likely nodes added when loop axis split unevenly (#5084)

* [TE][Schedule] Duplicate likely nodes removed

* [1] Test case added

* [2] Lint error fixed

* [3] Review comments handled

* [4] Review comments handled
---
 src/te/schedule/message_passing.cc           | 10 +++++++++-
 tests/python/unittest/test_te_build_lower.py | 13 +++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index a7b2482..fc13278 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -556,6 +556,14 @@ void PassUpBoundCheck(const Stage& s,
   }
 }
 
+bool IsRangeSame(const Range input_1, const Range input_2) {
+  arith::Analyzer analyzer;
+  if (input_1.same_as(input_2)) return true;
+
+  return (analyzer.CanProve(input_1->min == input_2->min)
+        && analyzer.CanProve(input_1->extent == input_2->extent));
+}
+
 std::vector<PrimExpr> MakeBoundCheck(
     const Stage& stage,
     const Map<IterVar, Range>& dom_map,
@@ -593,7 +601,7 @@ std::vector<PrimExpr> MakeBoundCheck(
     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
     Range dom = dom_map.at(iv);
     CHECK(iv->dom.defined());
-    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
+    if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
       PrimExpr value = value_map.at(iv) - iv->dom->min;
       IntSet s = EvalSet(value, iset_dmap);
       PrimExpr vmin = s.min();
diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py
index 736030b..3ad1747 100644
--- a/tests/python/unittest/test_te_build_lower.py
+++ b/tests/python/unittest/test_te_build_lower.py
@@ -40,6 +40,19 @@ def test_dependent_output_shape():
     s = te.create_schedule(B.op)
     mod = tvm.build(s, [A, B, x])
 
+def test_split_uneven_unique_likely():
+    a = te.placeholder((16, 16),)
+    b = te.placeholder((16, 16),)
+    c = te.compute((16, 16), lambda x, y: a[x, y] + b[x, y])
+
+    x, y = c.op.axis
+    sch = te.create_schedule(c.op)
+    xo, xi = sch[c].split(x, 5)
+    stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
+    assert isinstance(stmt.body.body.body.body, tvm.tir.stmt.IfThenElse)
+    assert str(stmt.body.body.body.body).count("likely") == 1
+
 if __name__ == "__main__":
     test_lower_rfactor()
     test_dependent_output_shape()
+    test_split_uneven_unique_likely()
--
libgit2 0.26.0