Commit 2bb1d8e4 by Tianqi Chen Committed by GitHub

[ARITH] Upgrade CanonicalSimplify to Simplify Mod (#676)

parent 2e3f8e74
......@@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
Stmt CanonicalSimplify(Stmt stmt,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
* \param expr The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized expression.
*/
Expr CanonicalSimplify(Expr expr);
Expr CanonicalSimplify(Expr expr,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Deep compare lhs and rhs
......
......@@ -33,9 +33,17 @@ TVM_REGISTER_API("ir_pass.Simplify")
TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = CanonicalSimplify(args[0].operator Stmt());
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
} else {
*ret = CanonicalSimplify(args[0].operator Stmt());
}
} else {
*ret = CanonicalSimplify(args[0].operator Expr());
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
} else {
*ret = CanonicalSimplify(args[0].operator Expr());
}
}
});
......
......@@ -129,6 +129,11 @@ inline Expr Binary_(const T* op,
// internal of canonical engine.
class Canonical::Internal : public IRMutator {
public:
explicit Internal(Map<Var, Range> vrange) {
for (auto kv : vrange) {
SetRange(kv.first, kv.second, 0);
}
}
// stack entry.
struct StackEntry {
int max_level{0};
......@@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator {
Expr Mutate_(const Div* op, const Expr& e) final {
return Binary(op, e);
}
// Mod operator
Expr Mutate_(const Mod* op, const Expr& e) final {
return Binary(op, e);
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Mul>(a.value, b.value);
} else if (is_const(b.value)) {
return SumModConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
}
Expr Mutate_(const And* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<And>();
......@@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator {
private:
template<typename T>
Expr Binary(const T* op, const Expr& e) {
Expr Binary(const T* op, Expr e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
BinaryExpr key{static_cast<int>(T::_type_info), a, b};
......@@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator {
std::vector<Var> var_rec_;
// level counter
int level_counter_{0};
// subroutine to do produce
Expr SumMulConst(ComExpr a, Expr v) {
// get constant int value
int64_t GetConstIntValue(const Expr& v) {
int64_t value = 0;
const int64_t *v1 = as_const_int(v);
const uint64_t *v2 = as_const_uint(v);
......@@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator {
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
value = static_cast<int64_t>(*v2);
}
return value;
}
// subroutine to do produce a % v
Expr SumModConst(ComExpr a, Expr v) {
int64_t value = GetConstIntValue(v);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
int mod_level = 0;
n->base = a->base % value;
if (n->base != 0) mod_level = 1;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
e.scale = e.scale % value;
if (!EvalSet(v - e.value, var_range_).can_prove_positive()) {
mod_level = 2;
} else {
++mod_level;
}
n->elem.push_back(e);
}
// cannot remove mode because there are more than two parts
if (mod_level >= 2) {
Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
return Binary(ret.as<Mod>(), ret);
}
ret_entry_.sum = ComExpr(n);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// subroutine to do produce
Expr SumMulConst(ComExpr a, Expr v) {
int64_t value = GetConstIntValue(v);
if (value == 0) {
return make_zero(v.type());
}
......@@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator {
for (auto& e : vsum->elem) {
e.scale *= value;
}
ret_entry_.sum = ComExpr(vsum);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
ret_entry_.sum = ComExpr(vsum);
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
......@@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal;
Canonical::Canonical()
: ptr_(std::make_shared<Internal>()) {}
Canonical::Canonical(Map<Var, Range> vrange)
: ptr_(std::make_shared<Internal>(vrange)) {}
Expr Canonical::Simplify(Expr expr) {
return ptr_->Mutate(expr);
......@@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) {
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt) {
return arith::Canonical().Simplify(stmt);
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(stmt);
}
Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr);
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(expr);
}
template<typename T>
......
......@@ -22,7 +22,7 @@ namespace arith {
class Canonical {
public:
/*! \brief constructor */
Canonical();
explicit Canonical(Map<Var, Range> var_range);
/*!
* \brief simplify expression e.
* \param expr The expression to be simplified.
......
......@@ -20,5 +20,23 @@ def test_simplify():
zz = zz.a
assert zz.a == x and zz.b.value == 4
def test_simplify_mod():
"""Not yet working, mock design"""
ib = tvm.ir_builder.create()
n = tvm.var('n')
j = tvm.var('j')
A = ib.pointer("float32", name="A")
with ib.for_range(0, 16, name="i") as i:
A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
assert diff.value == 0
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index == j
if __name__ == "__main__":
test_simplify_mod()
test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment