Commit 74ea8e5f by Sergei Grechanik Committed by Tianqi Chen

[TVM] Fix negating undefined in DetectLinearEquation (#1816)

parent 836cf13a
......@@ -111,8 +111,9 @@ class LinearEqDetector
return ComputeExpr<Add>(a, b);
}
Expr SubCombine(Expr a, Expr b) {
if (!a.defined()) return -b;
// Check b first in case they are both undefined
if (!b.defined()) return a;
if (!a.defined()) return -b;
return ComputeExpr<Sub>(a, b);
}
Expr MulCombine(Expr a, Expr b) {
......
......@@ -38,6 +38,10 @@ def test_multivariate():
assert(m[2].value == 2)
assert(m[len(m)-1].value == 2)
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]])
assert(m[0].value == 0)
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
if __name__ == "__main__":
test_basic()
test_multivariate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment