Commit 2bb1d8e4 by Tianqi Chen Committed by GitHub

[ARITH] Upgrade CanonicalSimplify to Simplify Mod (#676)

parent 2e3f8e74
...@@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); ...@@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*! /*!
* \brief Simplify by applying canonical form. * \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed. * \param stmt The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement. * \return Canonicalized statement.
*/ */
Stmt CanonicalSimplify(Stmt stmt); Stmt CanonicalSimplify(Stmt stmt,
Map<Var, Range> vrange = Map<Var, Range>());
/*! /*!
* \brief Simplify by applying canonical form. * \brief Simplify by applying canonical form.
* \param expr The statement to be canonically simplifed. * \param expr The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized expression. * \return Canonicalized expression.
*/ */
Expr CanonicalSimplify(Expr expr); Expr CanonicalSimplify(Expr expr,
Map<Var, Range> vrange = Map<Var, Range>());
/*! /*!
* \brief Deep compare lhs and rhs * \brief Deep compare lhs and rhs
......
...@@ -33,10 +33,18 @@ TVM_REGISTER_API("ir_pass.Simplify") ...@@ -33,10 +33,18 @@ TVM_REGISTER_API("ir_pass.Simplify")
TVM_REGISTER_API("ir_pass.CanonicalSimplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) { if (args[0].IsNodeType<Stmt>()) {
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
} else {
*ret = CanonicalSimplify(args[0].operator Stmt()); *ret = CanonicalSimplify(args[0].operator Stmt());
}
} else {
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
} else { } else {
*ret = CanonicalSimplify(args[0].operator Expr()); *ret = CanonicalSimplify(args[0].operator Expr());
} }
}
}); });
TVM_REGISTER_API("ir_pass.Equal") TVM_REGISTER_API("ir_pass.Equal")
......
...@@ -129,6 +129,11 @@ inline Expr Binary_(const T* op, ...@@ -129,6 +129,11 @@ inline Expr Binary_(const T* op,
// internal of canonical engine. // internal of canonical engine.
class Canonical::Internal : public IRMutator { class Canonical::Internal : public IRMutator {
public: public:
explicit Internal(Map<Var, Range> vrange) {
for (auto kv : vrange) {
SetRange(kv.first, kv.second, 0);
}
}
// stack entry. // stack entry.
struct StackEntry { struct StackEntry {
int max_level{0}; int max_level{0};
...@@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator { ...@@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator {
Expr Mutate_(const Div* op, const Expr& e) final { Expr Mutate_(const Div* op, const Expr& e) final {
return Binary(op, e); return Binary(op, e);
} }
// Mod operator
Expr Mutate_(const Mod* op, const Expr& e) final { Expr Mutate_(const Mod* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e); return Binary(op, e);
} }
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Mul>(a.value, b.value);
} else if (is_const(b.value)) {
return SumModConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
}
Expr Mutate_(const And* op, const Expr& e) final { Expr Mutate_(const And* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<And>(); op = expr.as<And>();
...@@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator { ...@@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator {
private: private:
template<typename T> template<typename T>
Expr Binary(const T* op, const Expr& e) { Expr Binary(const T* op, Expr e) {
Expr a = this->Mutate(op->a); Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b); Expr b = this->Mutate(op->b);
BinaryExpr key{static_cast<int>(T::_type_info), a, b}; BinaryExpr key{static_cast<int>(T::_type_info), a, b};
...@@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator { ...@@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator {
std::vector<Var> var_rec_; std::vector<Var> var_rec_;
// level counter // level counter
int level_counter_{0}; int level_counter_{0};
// subroutine to do produce // get constant int value
Expr SumMulConst(ComExpr a, Expr v) { int64_t GetConstIntValue(const Expr& v) {
int64_t value = 0; int64_t value = 0;
const int64_t *v1 = as_const_int(v); const int64_t *v1 = as_const_int(v);
const uint64_t *v2 = as_const_uint(v); const uint64_t *v2 = as_const_uint(v);
...@@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator { ...@@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator {
static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
value = static_cast<int64_t>(*v2); value = static_cast<int64_t>(*v2);
} }
return value;
}
// subroutine to do produce a % v
Expr SumModConst(ComExpr a, Expr v) {
int64_t value = GetConstIntValue(v);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
int mod_level = 0;
n->base = a->base % value;
if (n->base != 0) mod_level = 1;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
e.scale = e.scale % value;
if (!EvalSet(v - e.value, var_range_).can_prove_positive()) {
mod_level = 2;
} else {
++mod_level;
}
n->elem.push_back(e);
}
// cannot remove mode because there are more than two parts
if (mod_level >= 2) {
Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
return Binary(ret.as<Mod>(), ret);
}
ret_entry_.sum = ComExpr(n);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// subroutine to do produce
Expr SumMulConst(ComExpr a, Expr v) {
int64_t value = GetConstIntValue(v);
if (value == 0) { if (value == 0) {
return make_zero(v.type()); return make_zero(v.type());
} }
...@@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator { ...@@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator {
for (auto& e : vsum->elem) { for (auto& e : vsum->elem) {
e.scale *= value; e.scale *= value;
} }
ret_entry_.sum = ComExpr(vsum);
ret_entry_.max_level = stack_.back().max_level; ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect; ret_entry_.has_side_effect = stack_.back().has_side_effect;
ret_entry_.sum = ComExpr(vsum);
auto it = cache_sum_.find(ret_entry_.sum); auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) { if (it != cache_sum_.end()) {
ret_entry_ = it->second; ret_entry_ = it->second;
...@@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator { ...@@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal; using CInternal = Canonical::Internal;
Canonical::Canonical() Canonical::Canonical(Map<Var, Range> vrange)
: ptr_(std::make_shared<Internal>()) {} : ptr_(std::make_shared<Internal>(vrange)) {}
Expr Canonical::Simplify(Expr expr) { Expr Canonical::Simplify(Expr expr) {
return ptr_->Mutate(expr); return ptr_->Mutate(expr);
...@@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) { ...@@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) {
} // namespace arith } // namespace arith
namespace ir { namespace ir {
Stmt CanonicalSimplify(Stmt stmt) { Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::Canonical().Simplify(stmt); return arith::Canonical(vrange).Simplify(stmt);
} }
Expr CanonicalSimplify(Expr expr) { Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return arith::Canonical().Simplify(expr); return arith::Canonical(vrange).Simplify(expr);
} }
template<typename T> template<typename T>
......
...@@ -22,7 +22,7 @@ namespace arith { ...@@ -22,7 +22,7 @@ namespace arith {
class Canonical { class Canonical {
public: public:
/*! \brief constructor */ /*! \brief constructor */
Canonical(); explicit Canonical(Map<Var, Range> var_range);
/*! /*!
* \brief simplify expression e. * \brief simplify expression e.
* \param expr The expression to be simplified. * \param expr The expression to be simplified.
......
...@@ -20,5 +20,23 @@ def test_simplify(): ...@@ -20,5 +20,23 @@ def test_simplify():
zz = zz.a zz = zz.a
assert zz.a == x and zz.b.value == 4 assert zz.a == x and zz.b.value == 4
def test_simplify_mod():
"""Not yet working, mock design"""
ib = tvm.ir_builder.create()
n = tvm.var('n')
j = tvm.var('j')
A = ib.pointer("float32", name="A")
with ib.for_range(0, 16, name="i") as i:
A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
assert diff.value == 0
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index == j
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_mod()
test_simplify() test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment