Commit 515d4b6f by Tianqi Chen Committed by GitHub

[PASS] More simplifier for mod and div (#1100)

* [PASS] More simplifier for mod and div

* fix testcase
parent 1f0ca085
...@@ -312,9 +312,23 @@ class Canonical::Internal : public IRMutator { ...@@ -312,9 +312,23 @@ class Canonical::Internal : public IRMutator {
return e; return e;
} }
} }
// binary ops // Div operator
Expr Mutate_(const Div* op, const Expr& e) final { Expr Mutate_(const Div* op, const Expr& e) final {
return Binary(op, e); if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Div>(a.value, b.value);
} else if (is_const(b.value)) {
return SumDivConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
} }
// Mod operator // Mod operator
Expr Mutate_(const Mod* op, const Expr& e) final { Expr Mutate_(const Mod* op, const Expr& e) final {
...@@ -445,29 +459,80 @@ class Canonical::Internal : public IRMutator { ...@@ -445,29 +459,80 @@ class Canonical::Internal : public IRMutator {
} }
return value; return value;
} }
// subroutine to do produce a % v // Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0
Expr SumModConst(ComExpr a, Expr v) { // return true if such detection is successful
int64_t value = GetConstIntValue(v); // return false if it is not.
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
int mod_level = 0; const Expr& coeff) {
n->base = a->base % value; Type type = coeff.type();
if (n->base != 0) mod_level = 1; int64_t value = GetConstIntValue(coeff);
for (auto e : a->elem) { if (value < 0) return {};
if (e.scale % value == 0) continue; std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>();
e.scale = e.scale % value; std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>();
if (!EvalSet(v - e.value, var_range_).can_prove_positive()) { if (a->base % value == 0) {
mod_level = 2; xnode->base = a->base;
} else {
ynode->base = a->base;
}
for (const auto& e : a->elem) {
if (e.scale % value == 0) {
xnode->elem.push_back(e);
} else { } else {
++mod_level; ynode->elem.push_back(e);
} }
n->elem.push_back(e);
} }
// cannot remove mode because there are more than two parts Expr yres = Sum2Expr(ComExpr(ynode), type);
if (mod_level >= 2) { IntSet yset = EvalSet(yres, var_range_);
// This relies on the integer division rounds down
// Most cases it is good for integer division.
if (yset.min().type() == type &&
can_prove(yset.min() >= make_zero(type)) &&
yset.max().type() == type &&
can_prove(yset.max() < coeff)) {
xnode->base /= value;
for (auto &e : xnode->elem) {
e.scale /= value;
}
return {ComExpr(xnode), ComExpr(ynode)};
} else {
return {};
}
}
// subroutine to do produce a % v
Expr SumModConst(ComExpr a, Expr v) {
std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) {
int64_t value = GetConstIntValue(v);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
n->base = a->base % value;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
e.scale = e.scale % value;
n->elem.push_back(e);
}
Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
return Binary(ret.as<Mod>(), ret); return Binary(ret.as<Mod>(), ret);
} }
ret_entry_.sum = ComExpr(n); ret_entry_.sum = pair[1];
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// subroutine to do produce a % v
Expr SumDivConst(ComExpr a, Expr v) {
std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) {
Expr ret = Sum2Expr(a, v.type()) / v;
return Binary(ret.as<Div>(), ret);
}
ret_entry_.sum = pair[0];
ret_entry_.max_level = stack_.back().max_level; ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect; ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum); auto it = cache_sum_.find(ret_entry_.sum);
......
...@@ -279,7 +279,9 @@ class WarpMemoryRewriter : private IRMutator { ...@@ -279,7 +279,9 @@ class WarpMemoryRewriter : private IRMutator {
Stmt Rewrite(Stmt stmt) { Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt; if (warp_size_ == 1) return stmt;
return this->Mutate(stmt); stmt = this->Mutate(stmt);
stmt = CanonicalSimplify(stmt);
return stmt;
} }
private: private:
......
...@@ -37,6 +37,26 @@ def test_simplify_mod(): ...@@ -37,6 +37,26 @@ def test_simplify_mod():
assert index == j assert index == j
def test_modular():
rx = tvm.var("rx")
ry = tvm.var("ry")
y = tvm.var("y")
x = tvm.var("x")
vmap = {rx: tvm.Range(tvm.const(0), tvm.const(3)),
ry: tvm.Range(tvm.const(0), tvm.const(3)),
y: tvm.Range(tvm.const(0), tvm.const(2)),
x: tvm.Range(tvm.const(0), tvm.const(14))}
idx = ry * 16 + rx + y * 16 + x
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_mod() test_simplify_mod()
test_modular()
test_simplify() test_simplify()
...@@ -33,7 +33,6 @@ def test_bound(): ...@@ -33,7 +33,6 @@ def test_bound():
ret = tvm.ir_pass.Simplify(m % 10, vrange) ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m assert ret == m
def test_canonical(): def test_canonical():
x = tvm.var("x") x = tvm.var("x")
z = tvm.const(3) z = tvm.const(3)
...@@ -54,6 +53,7 @@ def test_canonical(): ...@@ -54,6 +53,7 @@ def test_canonical():
assert (tvm.ir_pass.Equal(ret1, ret2)) assert (tvm.ir_pass.Equal(ret1, ret2))
if __name__ == "__main__": if __name__ == "__main__":
test_modular()
test_bound() test_bound()
test_basic() test_basic()
test_simplify() test_simplify()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment