Commit 1e578a6f by Sergei Grechanik Committed by Tianqi Chen

[TVM] Fix segfault for CanonicalSimplify(x % -1) (#2194)

parent 3b91ae28
...@@ -515,7 +515,15 @@ class Canonical::Internal : public IRMutator { ...@@ -515,7 +515,15 @@ class Canonical::Internal : public IRMutator {
n->elem.push_back(e); n->elem.push_back(e);
} }
Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
return Binary(ret.as<Mod>(), ret); if (const Mod* mod = ret.as<Mod>()) {
return Binary(mod, ret);
} else {
// Sometimes the result is a constant, this may happen when value is -1
CHECK(is_const(ret)) << "CanonicalSimplify: "
<< Sum2Expr(ComExpr(n), v.type()) << " % " << v << " is " << ret
<< " which is neither Mod, nor a constant";
return ret;
}
} }
ret_entry_.sum = pair[1]; ret_entry_.sum = pair[1];
ret_entry_.max_level = stack_.back().max_level; ret_entry_.max_level = stack_.back().max_level;
......
...@@ -20,6 +20,16 @@ def test_simplify(): ...@@ -20,6 +20,16 @@ def test_simplify():
zz = zz.a zz = zz.a
assert zz.a == x and zz.b.value == 4 assert zz.a == x and zz.b.value == 4
n = tvm.var('n')
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
tvm.ir_pass.CanonicalSimplify(n / (-1))
# This is not true in the current implementation
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_mod(): def test_simplify_mod():
"""Not yet working, mock design""" """Not yet working, mock design"""
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment