Unverified Commit 79e071c9 by Tianqi Chen Committed by GitHub

[ARITH][SCHEDULE] Update schedule to use the new analyzer (#3466)

parent dfc4f972
......@@ -516,6 +516,24 @@ class Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProve(const Expr& cond);
/*!
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Expr Simplify(const Expr& expr);
};
//-----------------------------------------------
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
......
......@@ -23,6 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
namespace tvm {
namespace arith {
......@@ -49,8 +50,13 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
}
void Analyzer::Bind(const VarExpr& v, const Range& range) {
CHECK(range.defined());
Var var(v.node_);
this->const_int_bound.Bind(var, range);
if (is_one(range->extent)) {
this->rewrite_simplify.Update(var, range->min);
this->canonical_simplify.Update(var, range->min);
}
// skip modular_set
// skip rewrite simplify
}
......@@ -82,5 +88,27 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
return false;
}
bool Analyzer::CanProve(const Expr& expr) {
if (const auto* ptr = expr.as<ir::UIntImm>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
return ptr->value != 0;
}
return false;
}
Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
res = this->canonical_simplify(res);
return res;
}
} // namespace arith
} // namespace tvm
......@@ -262,6 +262,7 @@ class SumExprNode : public CanonicalExprNode {
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
} else if (lhs->lower_factor == rhs->upper_factor &&
rhs->scale != 0 &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
// Rules used in the proof:
......
......@@ -42,11 +42,23 @@ ConstIntBound::ConstIntBound(
node_ = std::move(node);
}
inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
os << "pos_inf";
} else if (val == ConstIntBound::kNegInf) {
os << "neg_inf";
} else {
os << val;
}
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode *op, IRPrinter *p) {
p->stream << "ConstIntBound"
<< "[" << op->min_value << ", "
<< op->max_value << ']';
.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode* op, IRPrinter* p) {
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
p->stream << ',';
PrintBoundValue(p->stream, op->max_value);
p->stream << ']';
});
// internal entry for const int bound
......@@ -95,7 +107,10 @@ class ConstIntBoundAnalyzer::Impl :
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info)
<< "var \'" << var << "\' already updated.";
<< "Trying to update var \'" << var << "\'"
<< " with a different const bound: "
<< "original=" << ConstIntBound(it->second.min_value, it->second.max_value)
<< ", new=" << ConstIntBound(info.min_value, info.max_value);
}
}
var_map_[var] = info;
......
......@@ -105,7 +105,14 @@ TryCompare(const Expr& x, int64_t val) {
void RewriteSimplifier::Impl::
Update(const Var& var, const Expr& info, bool override) {
if (!override) {
CHECK(!var_map_.count(var));
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(Equal(it->second, info))
<< "Trying to update var \'" << var << "\'"
<< " with a different value: "
<< "original=" << it->second
<< ", new=" << info;
}
}
var_map_[var] = info;
}
......@@ -199,6 +206,9 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));
TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
}
// condition rules.
......@@ -477,6 +487,10 @@ Mutate_(const Div* op, const Expr& self) {
}
}
TVM_TRY_REWRITE(x / x, OneWithTypeLike(x));
TVM_TRY_REWRITE(x * c1 / x, c1);
TVM_TRY_REWRITE(c1 * x / x, c1);
// Rules involving 2-operands.
TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2,
c1.Eval()->value >= 0 &&
......@@ -684,6 +698,9 @@ Mutate_(const Mod* op, const Expr& self) {
if (mod->coeff % c1val == 0 &&
CanProveGreaterEqual(x.Eval(), 0)) {
return (mod->base % c1).Eval();
} else if (mod->coeff % c1val == 0 &&
mod->base % c1val == 0) {
return make_zero(ret.type());
}
}
}
......
......@@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator {
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}
template<typename TA>
PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 1);
}
};
......
......@@ -213,6 +213,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
// Prepare context
GraphContext ctx;
Array<Operation> roots;
arith::Analyzer analyzer;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
......@@ -233,16 +235,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, &ret);
// bind bound of root iter vars.
for (auto iv : stage->op->root_iter_vars()) {
auto it = ret.find(iv);
if (it != ret.end()) {
analyzer.Bind(iv->var, it->second);
}
}
// pass down to get bound of all iter vars.
PassDownDomain(stage, &ret);
PassDownDomain(stage, &ret, &analyzer);
for (IterVar iv : stage->env_threads) {
CHECK(iv->dom.defined());
ret[iv] = iv->dom;
}
}
for (auto& p : ret) {
ret[p.first] = Range::make_by_min_extent(ir::Simplify(p.second->min),
ir::Simplify(p.second->extent));
ret[p.first] = Range::make_by_min_extent(
analyzer.Simplify(p.second->min),
analyzer.Simplify(p.second->extent));
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
......
......@@ -34,24 +34,17 @@ namespace schedule {
using namespace ir;
using namespace arith;
// result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
}
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
Range r,
Analyzer* analyzer) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
analyzer->Bind(iv->var, r);
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
bool match = is_zero(it->second->min) &&
analyzer->CanProve(r->extent - it->second->extent == 0);
CHECK(match)
<< iv
<< " domain already inferred,"
......@@ -62,7 +55,12 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
return actx->Simplify((a + (b - 1)) / b);
};
auto& state = *p_state;
// forwar iteration on relations
for (IterVarRelation rel : stage->relations) {
......@@ -74,15 +72,16 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
Update(p_state, r->inner, Range::make_by_min_extent(0, r->factor));
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts));
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
Update(p_state, r->inner,
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->nparts)));
0, ceil_div(range_parent->extent, r->nparts)), actx);
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) {
......@@ -100,9 +99,9 @@ void PassDownDomain(const Stage& stage,
}
Update(p_state, r->rebased,
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
0, state.at(r->parent)->extent), actx);
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -111,7 +110,7 @@ void PassDownDomain(const Stage& stage,
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
}
}
}
......
......@@ -43,11 +43,13 @@ namespace schedule {
*
* \param stage The stage to operate on.
* \param p_state The state of the message passing.
* \param analyzer Analyzer context, storing information about bounds in p_state.
* \param allow_missing Whether allow missing value.
*/
void PassDownDomain(
const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* analyzer,
bool allow_missing = false);
/*!
......
......@@ -203,14 +203,16 @@ void PrepareAxisMapping(Stage orig_stage,
auto& vsub = *p_vsub;
auto& vsub2newvar = *p_vsub2newvar;
auto& predicates = *p_predicates;
arith::Analyzer analyzer;
for (IterVar iv : op->reduce_axis) {
red_axis.insert(iv);
}
for (IterVar iv : op->axis) {
dom_map[iv] = iv->dom;
analyzer.Bind(iv->var, iv->dom);
}
schedule::PassDownDomain(orig_stage, &dom_map, true);
schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
{
// The source->cache
std::unordered_map<IterVar, Expr> value_map;
......@@ -679,6 +681,8 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
<< "Factor axis touches normal axis.";
skip_bound_check.insert(iv);
}
// get analyzer.
arith::Analyzer analyzer;
// Get the replace index
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<IterVar, Expr> value_map;
......@@ -688,8 +692,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
} else {
skip_bound_check.insert(iv);
}
analyzer.Bind(iv->var, iv->dom);
}
schedule::PassDownDomain(reduce_stage, &dom_map, true);
schedule::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv)) {
Range dom = dom_map.at(iv);
......
......@@ -198,6 +198,12 @@ def test_complex_cases():
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
res3 = ((((((((((x*1024) + y)/65536) + ((((x*1024) + y) % 65536)/256))
+ ((((x*1024) + y) % 256)/16)) + (((x*1024) + y) % 16)) - (y/256)) -
((y % 256)/16)) - (y % 16)) - (x*4))
ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))
if __name__ == "__main__":
test_simplify_if_then_else()
......
......@@ -271,6 +271,8 @@ def test_mul_index_simplify():
def test_div_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(x / x, 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
......@@ -311,6 +313,7 @@ def test_div_index_simplify():
ck.verify((y + z * x) / z, y / z + x)
def test_mod_index_simplify():
ck = RewriteChecker()
x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment