Unverified Commit 92439166 by Tianqi Chen Committed by GitHub

[ARITH] Introduce base-class IRMutatorWithAnalyzer for scope dependent analysis (#3969)

parent cdbf4d85
...@@ -52,10 +52,10 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) { ...@@ -52,10 +52,10 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
void Analyzer::Bind(const VarExpr& v, const Range& range) { void Analyzer::Bind(const VarExpr& v, const Range& range) {
CHECK(range.defined()); CHECK(range.defined());
Var var(v.node_); Var var(v.node_);
this->const_int_bound.Bind(var, range);
if (is_one(range->extent)) { if (is_one(range->extent)) {
this->rewrite_simplify.Update(var, range->min); this->Bind(var, range->min);
this->canonical_simplify.Update(var, range->min); } else {
this->const_int_bound.Bind(var, range);
} }
// skip modular_set // skip modular_set
// skip rewrite simplify // skip rewrite simplify
......
...@@ -744,8 +744,8 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -744,8 +744,8 @@ Mutate_(const Div* op, const Expr& self) {
return std::move(lhs); return std::move(lhs);
} }
// both lhs and extra are non-negative // both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) && if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) { analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval); lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) { if (const auto* pconst = temp.as<IntImm>()) {
...@@ -761,7 +761,7 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -761,7 +761,7 @@ Mutate_(const Div* op, const Expr& self) {
} }
} else { } else {
// if a >= 0 && a < cval, then result == 0 // if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a)); auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) { if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type()); return make_zero(a.type());
} }
...@@ -809,7 +809,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -809,7 +809,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else { } else {
// if 0 <= extra < cval, it means the extra can be eliminated. // if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && parent_->CanProveGreaterEqual(temp, 0))) { if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) {
lhs.CopyOnWrite()->AddToSelf( lhs.CopyOnWrite()->AddToSelf(
SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
} }
...@@ -817,7 +817,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -817,7 +817,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
return std::move(lhs); return std::move(lhs);
} else { } else {
// if a >= 0 && a < cval, then result == 0 // if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a)); auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) { if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type()); return make_zero(a.type());
} }
...@@ -908,8 +908,8 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -908,8 +908,8 @@ Mutate_(const Mod* op, const Expr& self) {
return make_zero(a.type()); return make_zero(a.type());
} }
// both lhs and extra are non-negative // both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) && if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) { analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (temp.as<IntImm>()) { if (temp.as<IntImm>()) {
return temp % c1.Eval(); return temp % c1.Eval();
...@@ -927,7 +927,7 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -927,7 +927,7 @@ Mutate_(const Mod* op, const Expr& self) {
} }
// Simplify the offset constant if necessary. // Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto cbound = parent_->const_int_bound(Normalize(a)); auto cbound = analyzer_->const_int_bound(Normalize(a));
int64_t new_base = psum->base % cval; int64_t new_base = psum->base % cval;
if (cbound->min_value >= 0 && if (cbound->min_value >= 0 &&
cbound->min_value - psum->base + new_base >= 0) { cbound->min_value - psum->base + new_base >= 0) {
...@@ -937,7 +937,7 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -937,7 +937,7 @@ Mutate_(const Mod* op, const Expr& self) {
} }
} else { } else {
// if a >= 0 && a < cval, then result == 0 // if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a)); auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) { if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a; return a;
} }
...@@ -980,7 +980,7 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -980,7 +980,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
} else { } else {
// If temp < cval && temp >=0 then can remove the mod. // If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT && if (TryCompare(temp, cval) == kLT &&
parent_->CanProveGreaterEqual(temp, 0)) { analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp; return temp;
} else { } else {
// contonue to use logic below. // contonue to use logic below.
...@@ -997,7 +997,7 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -997,7 +997,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
} else { } else {
// if a >= 0 && a < cval, then result == a // if a >= 0 && a < cval, then result == a
auto cbound = parent_->const_int_bound(Normalize(a)); auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) { if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a; return a;
} }
...@@ -1087,12 +1087,8 @@ SimplifyReduceCombiner(const Reduce* op) { ...@@ -1087,12 +1087,8 @@ SimplifyReduceCombiner(const Reduce* op) {
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) { Mutate_(const Reduce* op, const Expr& self) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
parent_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary. // Recursively call simplification when necessary.
Expr ret = IRMutator::Mutate_(op, self); Expr ret = RewriteSimplifier::Impl::Mutate_(op, self);
op = ret.as<Reduce>(); op = ret.as<Reduce>();
// already been simplified by const reduction axis removal // already been simplified by const reduction axis removal
if (op == nullptr) return ret; if (op == nullptr) return ret;
...@@ -1121,7 +1117,6 @@ void CanonicalSimplifier::Update(const Var& var, ...@@ -1121,7 +1117,6 @@ void CanonicalSimplifier::Update(const Var& var,
impl_->Update(var, info, override); impl_->Update(var, info, override);
} }
CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) { : impl_(new Impl(parent)) {
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/ir_mutator_with_analyzer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include "ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
using namespace ir;
Stmt IRMutatorWithAnalyzer::
Mutate_(const For* op, const Stmt& s) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var,
Range::make_by_min_extent(0, op->value));
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Call* op, const Expr& self) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->args[1]);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
return self;
} else {
return Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
}
return IRMutator::Mutate_(op, self);
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->true_value);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
// normal path
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return self;
} else {
return Select::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Reduce* op, const Expr& self) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
return IRMutator::Mutate_(op, self);
}
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/ir_mutator_with_analyzer.h
* \brief IR mutator base-class with an analyzer context.
*/
#ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
namespace tvm {
namespace arith {
/*!
* \brief IRMutator with an analyzer context.
*
* This class can sub-classed by ir mutators that need an analyzer.
* It will populates scope-related info such as bounds of loop-variables and constraints
* for the analyzer, so that the child class can do accurate context-dependent analysis.
*
* \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/
class IRMutatorWithAnalyzer : public ir::IRMutator {
public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}
// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override;
Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override;
Expr Mutate_(const ir::Let* op, const Expr& self) override;
Expr Mutate_(const ir::Select* op, const Expr& self) override;
Expr Mutate_(const ir::Call* op, const Expr& self) override;
Expr Mutate_(const ir::Reduce* op, const Expr& self) override;
protected:
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.cc * \file rewrite_simplify.cc
* \brief Rewrite-rule based simplification. * \brief Rewrite-rule based simplification.
*/ */
...@@ -80,7 +79,7 @@ TryCompare(const Expr& x, int64_t val) { ...@@ -80,7 +79,7 @@ TryCompare(const Expr& x, int64_t val) {
return kLT; return kLT;
} }
} }
ConstIntBound dbound = parent_->const_int_bound(diff); ConstIntBound dbound = analyzer_->const_int_bound(diff);
if (dbound->min_value > val) { if (dbound->min_value > val) {
return kGT; return kGT;
} }
...@@ -94,7 +93,7 @@ TryCompare(const Expr& x, int64_t val) { ...@@ -94,7 +93,7 @@ TryCompare(const Expr& x, int64_t val) {
return kLE; return kLE;
} }
if (val == 0) { if (val == 0) {
ModularSet dmod = parent_->modular_set(diff); ModularSet dmod = analyzer_->modular_set(diff);
if (dmod->base != 0) { if (dmod->base != 0) {
return kNE; return kNE;
} }
...@@ -490,7 +489,7 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -490,7 +489,7 @@ Mutate_(const Div* op, const Expr& self) {
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) { if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = parent_->modular_set(b1.Eval()); ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val; int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
...@@ -692,7 +691,7 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -692,7 +691,7 @@ Mutate_(const Mod* op, const Expr& self) {
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) { if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = parent_->modular_set(b1.Eval()); ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val; int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0) { if (bmod->coeff % c2val == 0) {
...@@ -740,7 +739,7 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -740,7 +739,7 @@ Mutate_(const Mod* op, const Expr& self) {
// try modular analysis // try modular analysis
if ((x % c1).Match(ret)) { if ((x % c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval()); ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && if (mod->coeff % c1val == 0 &&
c1val > 0 && c1val > 0 &&
...@@ -777,7 +776,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -777,7 +776,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval()); ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
...@@ -923,7 +922,7 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -923,7 +922,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
return broadcast(floormod(b1, c2), lanes).Eval(); return broadcast(floormod(b1, c2), lanes).Eval();
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval()); ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0) { if (bmod->coeff % c2val == 0) {
...@@ -956,7 +955,7 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -956,7 +955,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
// try modular analysis // try modular analysis
if (floormod(x, c1).Match(ret)) { if (floormod(x, c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval()); ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && c1val > 0) { if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval(); return floormod(mod->base, c1).Eval();
...@@ -990,8 +989,8 @@ Mutate_(const Min* op, const Expr& self) { ...@@ -990,8 +989,8 @@ Mutate_(const Min* op, const Expr& self) {
TVM_TRY_REWRITE(min(x, x), x); TVM_TRY_REWRITE(min(x, x), x);
// constant int bound // constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a); ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b); ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
if (a_bound->max_value <= b_bound->min_value) { if (a_bound->max_value <= b_bound->min_value) {
return op->a; return op->a;
} }
...@@ -1175,8 +1174,8 @@ Mutate_(const Max* op, const Expr& self) { ...@@ -1175,8 +1174,8 @@ Mutate_(const Max* op, const Expr& self) {
TVM_TRY_REWRITE(max(x, x), x); TVM_TRY_REWRITE(max(x, x), x);
// constant int bound // constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a); ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b); ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
if (a_bound->min_value >= b_bound->max_value) { if (a_bound->min_value >= b_bound->max_value) {
return op->a; return op->a;
} }
...@@ -1658,32 +1657,9 @@ Mutate_(const Or* op, const Expr& self) { ...@@ -1658,32 +1657,9 @@ Mutate_(const Or* op, const Expr& self) {
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) { Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition); Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->true_value);
}
{
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
// normal path
Expr ret;
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
ret = self;
} else {
ret = Select::make(cond, true_value, false_value);
}
op = ret.as<Select>(); op = ret.as<Select>();
if (op == nullptr) return ret;
// Pattern var to match any expression // Pattern var to match any expression
PVar<Expr> x, y; PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y); TVM_TRY_REWRITE(select(x, y, y), y);
...@@ -1693,37 +1669,9 @@ Mutate_(const Select* op, const Expr& self) { ...@@ -1693,37 +1669,9 @@ Mutate_(const Select* op, const Expr& self) {
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) { Mutate_(const Call* op, const Expr& self) {
// add condition context to if_then_else // add condition context to if_then_else
Expr ret; Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->args[1]);
}
{
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
ret = self;
} else {
ret = Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
} else {
ret = IRMutator::Mutate_(op, self);
}
op = ret.as<Call>(); op = ret.as<Call>();
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0]; return op->args[0];
} }
...@@ -1731,23 +1679,6 @@ Mutate_(const Call* op, const Expr& self) { ...@@ -1731,23 +1679,6 @@ Mutate_(const Call* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
// For now assume value does not has side-effect.
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
parent_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}
Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) { Mutate_(const Variable* op, const Expr& self) {
Var var = GetRef<Var>(op); Var var = GetRef<Var>(op);
auto it = var_map_.find(var); auto it = var_map_.find(var);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.h * \file rewrite_simplify.h
* \brief Rewrite-rule based simplification. * \brief Rewrite-rule based simplification.
*/ */
...@@ -31,6 +30,7 @@ ...@@ -31,6 +30,7 @@
#include <unordered_map> #include <unordered_map>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
#include "ir_mutator_with_analyzer.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -42,10 +42,12 @@ using namespace ir; ...@@ -42,10 +42,12 @@ using namespace ir;
* *
* This class can be inheritated for other simplifiers. * This class can be inheritated for other simplifiers.
*/ */
class RewriteSimplifier::Impl : public IRMutator { class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public: public:
using IRMutatorWithAnalyzer::Mutate_;
explicit Impl(Analyzer* parent) explicit Impl(Analyzer* parent)
: parent_(parent) {} : IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override); void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override; Expr Mutate_(const Add* op, const Expr& self) override;
...@@ -68,7 +70,6 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -68,7 +70,6 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Not* op, const Expr& self) override; Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override; Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override; Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override; Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override; Expr Mutate_(const Cast* op, const Expr& self) override;
...@@ -83,8 +84,6 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -83,8 +84,6 @@ class RewriteSimplifier::Impl : public IRMutator {
kLE, kLE,
kNE kNE
}; };
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth. // counter to record recursive rewrite depth.
int recur_depth_{0}; int recur_depth_{0};
// internal variable map // internal variable map
...@@ -103,7 +102,7 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -103,7 +102,7 @@ class RewriteSimplifier::Impl : public IRMutator {
private: private:
// Whether x >= val // Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) { bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val); return analyzer_->CanProveGreaterEqual(x, val);
} }
// Whether x == val // Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) { bool CanProveEqual(const Expr& x, int64_t val) {
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file stmt_simplify.cc * \file stmt_simplify.cc
* \brief Statement simplifier based on analyzer * \brief Statement simplifier based on analyzer
*/ */
...@@ -28,113 +27,32 @@ ...@@ -28,113 +27,32 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include "ir_mutator_with_analyzer.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace ir;
class StmtSimplifier : public IRMutator { class StmtSimplifier : public IRMutatorWithAnalyzer {
public: public:
using IRMutator::Mutate; explicit StmtSimplifier(Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::Mutate;
using Parent::Mutate_;
Expr Mutate(Expr expr) final { Expr Mutate(Expr expr) final {
return analyzer_.Simplify(expr); return analyzer_->Simplify(expr);
} }
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) { Stmt Simplify(Stmt stmt) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt); return Mutate(stmt);
} }
Stmt Mutate_(const For* op, const Stmt& s) final {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_.Bind(op->var, value);
return this->Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
Range dom = Range::make_by_min_extent(0, op->value);
var_dom_[iv->var.get()] = dom;
analyzer_.Bind(iv->var, dom);
}
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
// AssertStmt
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
With<ConstraintContext> ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
// eliminate useless stores // eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) { Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>(); op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) { if (const Load* load = op->value.as<Load>()) {
...@@ -145,11 +63,6 @@ class StmtSimplifier : public IRMutator { ...@@ -145,11 +63,6 @@ class StmtSimplifier : public IRMutator {
} }
return stmt; return stmt;
} }
protected:
Analyzer analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
}; };
} // namespace arith } // namespace arith
...@@ -157,8 +70,11 @@ class StmtSimplifier : public IRMutator { ...@@ -157,8 +70,11 @@ class StmtSimplifier : public IRMutator {
namespace ir { namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::StmtSimplifier().Simplify( arith::Analyzer analyzer;
stmt, vrange); for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return arith::StmtSimplifier(&analyzer).Simplify(stmt);
} }
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) { Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
...@@ -179,8 +95,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) { ...@@ -179,8 +95,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
} }
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) { Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::StmtSimplifier().Simplify( return CanonicalSimplify(stmt, vrange);
stmt, vrange);
} }
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -30,5 +30,23 @@ def test_stmt_simplify(): ...@@ -30,5 +30,23 @@ def test_stmt_simplify():
assert isinstance(body.body, tvm.stmt.Store) assert isinstance(body.body, tvm.stmt.Store)
def test_thread_extent_simplify():
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
n = tvm.var("n")
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
ib.scope_attr(tx, "thread_extent", n)
ib.scope_attr(tx, "thread_extent", n)
ib.scope_attr(ty, "thread_extent", 1)
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.stmt.LetStmt(n, 10, ib.get())
body = tvm.ir_pass.CanonicalSimplify(body)
assert isinstance(body.body.body.body, tvm.stmt.Store)
if __name__ == "__main__": if __name__ == "__main__":
test_stmt_simplify() test_stmt_simplify()
test_thread_extent_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment