Unverified Commit e5efc632 by Tianqi Chen Committed by GitHub

[ARITH] Simplify let (#3568)

parent d82db909
......@@ -73,6 +73,15 @@ struct ModularSetAnalyzer::Entry {
bool is_const() const {
return coeff == 0;
}
bool operator==(const Entry& other) const {
return coeff == other.coeff && base == other.base;
}
bool operator==(const ModularSet& other) const {
return other.defined() &&
coeff == other->coeff && base == other->base;
}
};
class ModularSetAnalyzer::Impl :
......@@ -85,7 +94,14 @@ class ModularSetAnalyzer::Impl :
const ModularSet& info,
bool override) {
if (!override) {
CHECK(!var_map_.count(var));
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info)
<< "Trying to update var \'" << var << "\'"
<< " with a different const bound: "
<< "original=" << ModularSet(it->second.coeff, it->second.base)
<< ", new=" << info;
}
}
var_map_[var] = Entry(info->coeff, info->base);
}
......
......@@ -1730,6 +1730,33 @@ Mutate_(const Call* op, const Expr& self) {
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
// For now assume value does not has side-effect.
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
parent_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}
Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
}
return self;
}
Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order
Expr res = expr;
......
......@@ -68,6 +68,8 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
protected:
/*! \brief internal structure for comparison. */
......
......@@ -50,11 +50,26 @@ class StmtSimplifier : public IRMutator {
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Var loop_var(op->loop_var.node_);
analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_.Bind(op->var, value);
return this->Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
......
......@@ -798,6 +798,11 @@ def test_logical_simplify():
ck.verify(tvm.expr.Or(2 <= x, x <= 1), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x != 1, x == 2), x != 1)
def test_let_simplify():
ck = RewriteChecker()
x, y = tvm.var("x"), tvm.var("y")
z = tvm.expr.Let(x, 1, x + 1)
ck.verify(z + z, 4)
if __name__ == "__main__":
test_floordiv_index_simplify()
......@@ -813,3 +818,4 @@ if __name__ == "__main__":
test_mod_index_simplify()
test_select_simplify()
test_logical_simplify()
test_let_simplify()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
def test_stmt_simplify():
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope(i < 12):
A[i] = C[i]
body = tvm.stmt.LetStmt(n, 10, ib.get())
body = tvm.ir_pass.CanonicalSimplify(body)
assert isinstance(body.body, tvm.stmt.Store)
if __name__ == "__main__":
test_stmt_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment