Unverified Commit e6da0f0b by Tianqi Chen Committed by GitHub

[ARITH] CanonicalSimplifier, better folding, eliminate store. (#3464)

parent 75c29c6a
...@@ -256,9 +256,14 @@ class SumExprNode : public CanonicalExprNode { ...@@ -256,9 +256,14 @@ class SumExprNode : public CanonicalExprNode {
SplitExpr& rhs = args[j]; SplitExpr& rhs = args[j];
if (!lhs->IndexEqual(rhs)) break; if (!lhs->IndexEqual(rhs)) break;
if (lhs->upper_factor < rhs->lower_factor) break; if (lhs->upper_factor < rhs->lower_factor) break;
if (lhs->lower_factor == rhs->upper_factor && if (lhs->upper_factor == rhs->upper_factor &&
lhs->scale % rhs->scale == 0 && lhs->lower_factor == rhs->lower_factor) {
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) { // folding same co-efficient.
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
} else if (lhs->lower_factor == rhs->upper_factor &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
// Rules used in the proof: // Rules used in the proof:
// //
// Rule 1: (x % (c * s)) / c = (x / c) % s // Rule 1: (x % (c * s)) / c = (x / c) % s
......
...@@ -106,6 +106,19 @@ class StmtSimplifier : public IRMutator { ...@@ -106,6 +106,19 @@ class StmtSimplifier : public IRMutator {
} }
} }
// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
return Evaluate::make(0);
}
}
return stmt;
}
protected: protected:
Analyzer analyzer_; Analyzer analyzer_;
// variable domain // variable domain
......
...@@ -187,6 +187,18 @@ def test_simplify_if_then_else(): ...@@ -187,6 +187,18 @@ def test_simplify_if_then_else():
ck.verify(res, 0) ck.verify(res, 0)
def test_complex_cases():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) +
((((((x*128) + y) % 36)*2) + 1)/2))
- (((x*128) + y) % 1296)) + 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1)
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_if_then_else() test_simplify_if_then_else()
test_div_simplify() test_div_simplify()
...@@ -195,3 +207,4 @@ if __name__ == "__main__": ...@@ -195,3 +207,4 @@ if __name__ == "__main__":
test_mul_sum_simplify() test_mul_sum_simplify()
test_split_index_simplify() test_split_index_simplify()
test_canonical_mixed() test_canonical_mixed()
test_complex_cases()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment