Commit 859bda8a by Salem Derisavi Committed by Tianqi Chen

1) fixed constant folding for mod operation in CanonicalSimplify 2) added a unit test (#2487)

parent 12aca82e
...@@ -363,7 +363,7 @@ class Canonical::Internal : public IRMutator { ...@@ -363,7 +363,7 @@ class Canonical::Internal : public IRMutator {
return Binary_(op, e, a.value, b.value); return Binary_(op, e, a.value, b.value);
} }
if (is_const(a.value) && is_const(b.value)) { if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Mul>(a.value, b.value); return ComputeExpr<Mod>(a.value, b.value);
} else if (is_const(b.value)) { } else if (is_const(b.value)) {
return SumModConst(a.AsSum(), b.value); return SumModConst(a.AsSum(), b.value);
} else { } else {
......
...@@ -27,6 +27,16 @@ TEST(IRSIMPLIFY, Mul) { ...@@ -27,6 +27,16 @@ TEST(IRSIMPLIFY, Mul) {
CHECK(is_zero(es)); CHECK(is_zero(es));
} }
TEST(IRSIMPLIFY, Mod) {
auto x = tvm::Integer(10);
auto y = tvm::Integer(12);
// Mod::make is used instead of % to avoid constant folding during
// calling operator%(x,y). Mod::make doesn't try constant folding,
// and therefore, the constant folding will be attempted in CanonicalSimplify
auto mod = tvm::ir::CanonicalSimplify(tvm::ir::Mod::make(x, y));
auto es = tvm::ir::CanonicalSimplify(mod - x);
CHECK(is_zero(es));
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment