Commit a31f6158 by Salem Derisavi Committed by Tianqi Chen

removed non-determinism from CanonicalSimplify (#704)

* 1) removed non-determinism from CanonicalSimplify
2) added couple of testcases for CanonicalSimplify

* Use IRDeepCompare instead of comparison of string representation

* Give a warning (instead of fatal error) when two "ComExprEntry"s are equal
parent bb97938d
...@@ -29,9 +29,17 @@ struct ComExprEntry { ...@@ -29,9 +29,17 @@ struct ComExprEntry {
inline bool operator<(const ComExprEntry& other) const { inline bool operator<(const ComExprEntry& other) const {
if (level < other.level) return true; if (level < other.level) return true;
if (level > other.level) return false; if (level > other.level) return false;
// compare top operator of entries and sort on that if possible (fast check)
if (value.type_index() < other.value.type_index()) return true; if (value.type_index() < other.value.type_index()) return true;
if (value.type_index() > other.value.type_index()) return false; if (value.type_index() > other.value.type_index()) return false;
return value.get() < other.value.get(); // if none of the above distinguishes the terms, compare the expression tree of the entries.
// This is a slower check.
int compare_result = Compare(value, other.value);
if (compare_result < 0) return true;
if (compare_result > 0) return false;
// it's a problem if we see identical entries at this point. They should've been merged earlier.
LOG(WARNING) << "we should not have identical entries at this point";
return false;
} }
}; };
......
...@@ -43,6 +43,16 @@ def test_canonical(): ...@@ -43,6 +43,16 @@ def test_canonical():
ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z)) ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z))
assert(tvm.ir_pass.Equal(ret, 0)) assert(tvm.ir_pass.Equal(ret, 0))
#make sure terms are ordered based on their top operators (e.g., / always precedes %)
ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3)
assert(tvm.ir_pass.Equal(ret1, ret2))
#when top operators match, compare string representation of terms
ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
assert (tvm.ir_pass.Equal(ret1, ret2))
if __name__ == "__main__": if __name__ == "__main__":
test_bound() test_bound()
test_basic() test_basic()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment