Unverified Commit 37e57548 by yongfeng-nv Committed by GitHub

Improve IntervalSet's floormod (#5367)

parent 4a3fece7
......@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
*/
TVM_DLL void Bind(const Var& var, const Range& range);
TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
private:
friend class Analyzer;
......@@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const PrimExpr& expr);
void Bind(const Var& var, const PrimExpr& expr, bool override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
......@@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const Range& range);
void Bind(const Var& var, const Range& range, bool override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Map<Var, Range>& variables);
void Bind(const Map<Var, Range>& variables, bool override = false);
/*!
* \brief Whether can we prove expr >= val.
......@@ -443,6 +447,19 @@ class TVM_DLL Analyzer {
*/
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove expr < val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param upper_bound The upper bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
......
......@@ -153,6 +153,13 @@ class IntSet : public ObjectRef {
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
*
* \param dom_map The domain map to convert.
* \return The converted map.
*/
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
......@@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
......@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
*/
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
......@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
*/
IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
......
......@@ -36,31 +36,31 @@ Analyzer::Analyzer()
int_set(this) {
}
void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
this->modular_set.Update(var, this->modular_set(new_expr), override);
this->rewrite_simplify.Update(var, new_expr, override);
this->canonical_simplify.Update(var, new_expr, override);
}
void Analyzer::Bind(const Var& var, const Range& range) {
void Analyzer::Bind(const Var& var, const Range& range, bool override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min);
this->Bind(var, range->min, override);
} else {
this->const_int_bound.Bind(var, range);
this->const_int_bound.Bind(var, range, override);
}
// skip modular_set
// skip rewrite simplify
}
void Analyzer::Bind(const Map<Var, Range>& variables) {
void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second);
this->Bind(iter.first, iter.second, override);
}
}
......@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
return false;
}
bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value < upper_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->max_value < upper_bound) return true;
return false;
}
bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
......
......@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
}
};
void Bind(const Var& var, const Range& range) {
void Bind(const Var& var, const Range& range, bool override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, false);
Update(var, ret, override);
}
void Update(const Var& var,
......@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
auto everything = Everything(op->dtype);
CHECK(
(val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
(val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
......@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_->Update(var, info, override);
}
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
impl_->Bind(var, range);
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
impl_->Bind(var, range, override);
}
std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
......
......@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
if (divisor.as<tir::IntImmNode>()) {
// a mod b = a - (a / b) * b if a_max / b == a_min / b
auto qmax = floordiv(a->max_value, divisor);
auto qmin = floordiv(a->min_value, divisor);
if (analyzer->CanProve(qmax == qmin)) {
auto tmax = a->max_value - divisor * qmin;
auto tmin = a->min_value - divisor * qmin;
return IntervalSet(tmin, tmax);
}
}
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
PrimExpr bound = abs(divisor) - 1;
......
......@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs(
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
......@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs(
PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
// We must update bound's ends in pairs. Here is an counter example: shape_i is
// [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
// [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
// awkward for further analysis.
if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
(analyzer->CanProve(shape_i_min_value >= min_value) &&
analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
......
......@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set.
for (const Operation& op : consumers) {
std::unordered_map<const VarNode*, IntSet> relax_set;
Map<Var, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get()));
......@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage,
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
relax_set[iv->var.get()] = IntSet::range(vrange);
relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) {
relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
}
}
}
......@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage,
// Relax if needed.
std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto entry : *rmap) {
analyzer.Bind(entry.first->var, entry.second);
}
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
......@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage,
r = iv->dom;
}
if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set);
dom_map[iv->var.get()] = IntSet::interval(
analyzer.int_set(r->min, relax_set).min(),
analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
analyzer.Bind(iv->var, r);
analyzer.Bind(iv->var, r, true);
}
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
......
......@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck(
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<PrimExpr> preds;
std::unordered_map<const VarNode*, IntSet> iset_dmap;
Map<Var, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
}
for (auto entry : dom_map) {
analyzer.Bind(entry.first->var, entry.second);
}
for (const IterVar& iv : stage->all_iter_vars) {
......@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck(
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min;
PrimExpr vmax = EvalSet(value, iset_dmap).max();
PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
......@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck(
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap);
IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min();
PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
......
......@@ -90,6 +90,20 @@ def test_mod():
flm = tvm.te.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
floordiv = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
(0, 7))
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
def test_max_min():
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_bound_tile_mod():
def compute(M_tiles, N_tiles, factor, dtype):
# Algo
M = M_tiles * factor
N = N_tiles * factor
A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
s = tvm.te.create_schedule(C.op)
return s, A, C
def schedule(s, factor, padding, A, C):
C_local = s.cache_write(C, "local")
n, m = C.op.axis
bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
nio, nii = s[C].split(ni, 2)
n = s[C].fuse(nii, mi)
C_shared = s.cache_write(C, "shared")
bn, bm, ni, mi = C_shared.op.axis
s[C_shared].storage_align(ni, factor * 2, padding)
n, m = s[C].op.axis
bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
s[C].set_scope("global")
niio, niii = s[C].split(ni, 32)
s[C_shared].compute_at(s[C], niio)
return s
s, A, C = compute(2, 2, 128, "float16")
s = schedule(s, 128, 8, A, C)
bounds = tvm.te.schedule.InferBound(s)
check = (bounds[s.stages[2].op.axis[2]].extent == 16)
if(not check):
print(tvm.lower(s, [A, C], simple_mode=True))
assert(check)
if __name__ == "__main__":
test_bound_tile_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment