Unverified Commit 585f9ce6 by Lianmin Zheng Committed by GitHub

Tighten split's extent (#4931)

* Set a split node's range to the minimum of the parent extent and the split factor (or nparts), but only when PassDownDomain is called with allow_missing == false, i.e. by InferBound. Add a helper PassUpThreadBinding() that builds a map telling whether an IterVar has at least one leaf IterVar deriving from it that is bound to a thread. Add two unit tests. (A sketch of the new behavior follows this list.)

* Enhance LoopVectorizer to handle vectorizing by 0. Found at least one such case in topi/tests/python/test_topi_transform.py::test_tile.

* Revert the changes to vectorize_loop.cc; instead, when the parent's extent is zero, set the split's range to the factor or nparts.

* Update with comments.

* Refactor the extent-tightening predicate.

* Fix reference types.

* Integrate tvm.te changes.

* Trivial comment change to trigger CI.

* Trivial comment correction to trigger testing.
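The sketch below illustrates the new behavior. It is a simplified adaptation of the test_bound_split_ext_less_than_factor test added in this commit (the intermediate stage and compute_at used in the test are omitted here, and the tensor names are illustrative): splitting an axis of extent 8 by factor=32 now yields an inner IterVar extent of 8 rather than 32, because no leaf IterVar derived from the split is bound to a thread.

    import tvm
    from tvm import te

    m = 8                                   # parent extent, smaller than the split factor
    A = te.placeholder((m,), name='A')
    B = te.compute((m,), lambda i: A[i] * 2, name='B')
    s = te.create_schedule([B.op])
    xo, xi = s[B].split(s[B].op.axis[0], factor=32)

    bounds = tvm.te.schedule.InferBound(s)
    # Before this change xi got extent 32; now it is tightened to min(8, 32) = 8.
    assert bounds[xi].extent.value == m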
parent 5a0f39b5
@@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
  }
}

/*!
 * \brief Upward propagation of whether an IterVar derives at least one leaf IterVar that binds
 * to a thread.
 *
 * \param stage The stage to operate on.
 * \param p_state The propagation result of each IterVar.
 */
void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
  auto bound_to_thread = [&stage](const IterVar& iv) {
    bool bound = false;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      bound = (*it).second->bind_thread.defined();
    }
    return bound;
  };
  auto& state = *p_state;
  // Fill p_state with leaf itervars
  for (const IterVar& iv : stage->leaf_iter_vars) {
    state[iv] = bound_to_thread(iv);
  }
  // Traverse the graph bottom-up to propagate thread binding information
  for (size_t i = stage->relations.size(); i != 0; --i) {
    IterVarRelation rel = stage->relations[i - 1];
    if (const SplitNode* s = rel.as<SplitNode>()) {
      state[s->parent] = state[s->inner] || state[s->outer];
    } else if (const FuseNode* s = rel.as<FuseNode>()) {
      state[s->inner] = state[s->fused];
      state[s->outer] = state[s->fused];
    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
      state[s->parent] = state[s->rebased];
    } else if (rel.as<SingletonNode>()) {
    } else {
      LOG(FATAL) << "unknown relation type";
    }
  }
}
void PassDownDomain(const Stage& stage,
                    std::unordered_map<IterVar, Range>* p_state,
                    arith::Analyzer* actx,
                    bool allow_missing) {
  auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
    if (actx->CanProve(indexmod(a, b) == 0)) {
      return actx->Simplify(indexdiv(a, b));
    }
    return actx->Simplify(indexdiv(a + (b - 1), b));
  };
  auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
    if (actx->CanProve(a < b)) {
      return actx->Simplify(a);
    }
    return actx->Simplify(b);
  };
  std::unordered_map<IterVar, bool> dominating_thread;
  PassUpThreadBinding(stage, &dominating_thread);
  auto& state = *p_state;
  // forward iteration on relations
  for (IterVarRelation rel : stage->relations) {
@@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
      }
      CHECK(!state.count(r->inner));
      const Range& range_parent = state.at(r->parent);
      // Tighten iv's extent to min(parent_extent, factor_or_nparts), but only if all of the
      // following conditions are met:
      // 1. No leaf IterVar derived from iv binds to any thread. People may use split to force
      //    an IterVar's extent to match the number of allocated threads when fusing stages
      //    that require different numbers of threads. We don't want to change these extents.
      // 2. allow_missing is false, i.e. PassDownDomain is called by the final InferBound,
      //    rather than by an early compiler phase such as rfactor(). We don't want to tighten
      //    an IterVar in an early phase that allows missing IterVars, because it may be bound
      //    to a thread later.
      // 3. range_parent's extent is not 0. At least one topi test has a case where a tensor
      //    has one zero-sized dimension. Split creates iv with a positive extent to avoid a
      //    zero-extent IterVar. We don't touch it.
      auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
        return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
                   ? factor_or_nparts
                   : minimum_or_later(range_parent->extent, factor_or_nparts);
      };
      if (r->factor.defined()) {
        Update(p_state, r->inner,
               Range::make_by_min_extent(
                   0, resolve_min_extent_for_split(r->inner, r->factor)),
               actx);
        Update(p_state, r->outer,
               Range::make_by_min_extent(
                   0, ceil_div(range_parent->extent, r->factor)), actx);
      } else {
        Update(p_state, r->outer,
               Range::make_by_min_extent(
                   0, resolve_min_extent_for_split(r->outer, r->nparts)),
               actx);
        Update(p_state, r->inner,
               Range::make_by_min_extent(
                   0, ceil_div(range_parent->extent, r->nparts)), actx);
...
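Condition 1 in the comment above means a split whose derived leaf IterVar is bound to a thread keeps the untightened extent. Below is a minimal sketch of that case; it is not part of the commit's test suite, and the tensor names, the threadIdx.x binding, and the expected value are illustrative assumptions based on the condition described above.

    import tvm
    from tvm import te

    m = 8
    A = te.placeholder((m,), name='A')
    B = te.compute((m,), lambda i: A[i] * 2, name='B')
    s = te.create_schedule([B.op])
    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
    # A leaf IterVar derived from the split is now thread-bound (assumed setup).
    s[B].bind(xi, te.thread_axis("threadIdx.x"))

    bounds = tvm.te.schedule.InferBound(s)
    # The extent is not tightened to 8; it stays at the split factor so it still
    # matches the number of launched threads.
    assert bounds[xi].extent.value == 32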
@@ -70,6 +70,32 @@ def test_bound3():
    assert(bounds[A1.op.axis[0]].extent.value==32)
    assert(bounds[A1.op.axis[1]].extent.value==16)

def test_bound_split_ext_less_than_factor():
    m = 8
    I = te.placeholder((m,), name='I')
    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
    s = te.create_schedule([E.op])
    xo, xi = s[E].split(s[E].op.axis[0], factor=32)
    s[EF].compute_at(s[E], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xi].extent.value == m

def test_bound_split_ext_less_than_nparts():
    m = 8
    I = te.placeholder((m,), name='I')
    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
    s = te.create_schedule([E.op])
    xo, xi = s[E].split(s[E].op.axis[0], nparts=32)
    s[EF].compute_at(s[E], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xo].extent.value == m

def test_bound_split_divisible():
    m = te.var('m')
    l = te.var('l')
...