Commit cb70da1b by Wei Chen Committed by Tianqi Chen

Improve CanonicalSimplify to handle Min, Max(#2248) (#2261)

Also enable Mul caching for more cases
parent 1cb602f1
......@@ -236,6 +236,24 @@ class Canonical::Internal : public IRMutator {
bool EnableOpt(Type t) const {
return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
}
// Max
Expr Mutate_(const Max* op, const Expr& e) final {
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return Binary(op, e);
}
// Min
Expr Mutate_(const Min* op, const Expr& e) final {
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return Binary(op, e);
}
// Add
Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
......@@ -277,7 +295,7 @@ class Canonical::Internal : public IRMutator {
} else if (is_const(b.value)) {
return SumMulConst(a.AsSum(), b.value);
} else {
return Binary_(op, e, a.value, b.value);
return Binary(op, e);
}
}
// Variable
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
#include <tvm/tvm.h>
#include <arithmetic/Simplify.h>
......@@ -8,6 +9,24 @@ TEST(IRSIMPLIFY, Basic) {
simplify_test();
}
TEST(IRSIMPLIFY, MinMax) {
auto x = tvm::var("x");
auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ;
auto e1s = tvm::ir::CanonicalSimplify(e1);
CHECK(is_zero(e1s));
auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1));
auto e2s = tvm::ir::CanonicalSimplify(e2);
CHECK(is_zero(e2s));
}
TEST(IRSIMPLIFY, Mul) {
auto x = tvm::var("x");
auto e = (x * x) - (x * x) ;
auto es = tvm::ir::CanonicalSimplify(e);
CHECK(is_zero(es));
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
......@@ -46,6 +46,21 @@ def test_simplify_mod():
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index == j
def test_simplify_minmax():
x = tvm.var('x')
e1 = tvm.max(x, 1) - tvm.max(x, 1)
e1s = tvm.ir_pass.CanonicalSimplify(e1)
assert e1s.value == 0
e2 = tvm.min(x, 1) - tvm.min(x, 1)
e2s = tvm.ir_pass.CanonicalSimplify(e2)
assert e2s.value == 0
def test_mul():
x = tvm.var('x')
e = x * x - x * x
es = tvm.ir_pass.CanonicalSimplify(e)
assert es.value == 0
def test_modular():
rx = tvm.var("rx")
......@@ -62,11 +77,9 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
if __name__ == "__main__":
test_simplify_mod()
test_modular()
test_simplify()
test_mul()
test_simplify_minmax()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment