Commit bbc5d153 by Sergei Grechanik Committed by Tianqi Chen

[ARITH] Simplify casts of constants 0 and 1 (#3758)

* [ARITH] Simplify casts of constants 0 and 1

* [EXPR] is_const_value to check whether non-ints are consts

* Revert "[EXPR] is_const_value to check whether non-ints are consts"

This reverts commit 7e1b3462e3f74fd0afb1541d72978107cfa23c30.

* Use tvm::cast
parent 8bd9d4d5
...@@ -1757,6 +1757,13 @@ Mutate_(const Variable* op, const Expr& self) { ...@@ -1757,6 +1757,13 @@ Mutate_(const Variable* op, const Expr& self) {
return self; return self;
} }
Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Cast>();
return cast(op->type, op->value);
}
Expr RewriteSimplifier::operator()(const Expr& expr) { Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order // Run simplification in post order
Expr res = expr; Expr res = expr;
......
...@@ -70,6 +70,7 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -70,6 +70,7 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Call* op, const Expr& self) override; Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override; Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override; Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
protected: protected:
/*! \brief internal structure for comparison. */ /*! \brief internal structure for comparison. */
......
...@@ -105,12 +105,15 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { ...@@ -105,12 +105,15 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
Expr cast(const Type& t, Expr value) { Expr cast(const Type& t, Expr value) {
using ir::IntImm; using ir::IntImm;
using ir::UIntImm;
using ir::FloatImm; using ir::FloatImm;
if (value.type() == t) return value; if (value.type() == t) return value;
// const fold IntImm as they are used in index computations // const fold IntImm as they are used in index computations
if (t.lanes() == 1) { if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) { if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value); return make_const(t, op->value);
} else if (const UIntImm* op = value.as<UIntImm>()) {
return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) { } else if (const FloatImm* op = value.as<FloatImm>()) {
return make_const(t, op->value); return make_const(t, op->value);
} }
...@@ -122,6 +125,8 @@ Expr cast(const Type& t, Expr value) { ...@@ -122,6 +125,8 @@ Expr cast(const Type& t, Expr value) {
if (value.type() != vtype) { if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) { if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value); value = make_const(vtype, op->value);
} else if (const UIntImm* op = value.as<UIntImm>()) {
return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) { } else if (const FloatImm* op = value.as<FloatImm>()) {
value = make_const(vtype, op->value); value = make_const(vtype, op->value);
} else { } else {
......
...@@ -804,6 +804,18 @@ def test_let_simplify(): ...@@ -804,6 +804,18 @@ def test_let_simplify():
z = tvm.expr.Let(x, 1, x + 1) z = tvm.expr.Let(x, 1, x + 1)
ck.verify(z + z, 4) ck.verify(z + z, 4)
def test_cast_simplify():
ck = RewriteChecker()
x = tvm.var("x")
dtypes = ["float32", "float16", "int32", "int8", "bool"]
for dtype1 in dtypes:
ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
if __name__ == "__main__": if __name__ == "__main__":
test_floordiv_index_simplify() test_floordiv_index_simplify()
test_floormod_index_simplify() test_floormod_index_simplify()
...@@ -819,3 +831,4 @@ if __name__ == "__main__": ...@@ -819,3 +831,4 @@ if __name__ == "__main__":
test_select_simplify() test_select_simplify()
test_logical_simplify() test_logical_simplify()
test_let_simplify() test_let_simplify()
test_cast_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment