Commit 74ea8e5f by Sergei Grechanik Committed by Tianqi Chen

[TVM] Fix negating undefined in DetectLinearEquation (#1816)

parent 836cf13a
...@@ -111,8 +111,9 @@ class LinearEqDetector ...@@ -111,8 +111,9 @@ class LinearEqDetector
return ComputeExpr<Add>(a, b); return ComputeExpr<Add>(a, b);
} }
Expr SubCombine(Expr a, Expr b) { Expr SubCombine(Expr a, Expr b) {
if (!a.defined()) return -b; // Check b first in case they are both undefined
if (!b.defined()) return a; if (!b.defined()) return a;
if (!a.defined()) return -b;
return ComputeExpr<Sub>(a, b); return ComputeExpr<Sub>(a, b);
} }
Expr MulCombine(Expr a, Expr b) { Expr MulCombine(Expr a, Expr b) {
......
...@@ -38,6 +38,10 @@ def test_multivariate(): ...@@ -38,6 +38,10 @@ def test_multivariate():
assert(m[2].value == 2) assert(m[2].value == 2)
assert(m[len(m)-1].value == 2) assert(m[len(m)-1].value == 2)
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]])
assert(m[0].value == 0)
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_multivariate() test_multivariate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment