Unverified Commit 2421a854 by yongfeng-nv Committed by GitHub

Set split node's range to minimum of ext and split factor or split nparts, but…

Set split node's range to minimum of ext and split factor or split nparts, but only when PassDownDomain is called with allow_missing == false, i.e. by InferBound. Add a helper PassUpThreadBinding() to get a map telling whether an IterVar has at least one leaf IterVar deriving from it binding to a thread. Add two unit tests. (#5044)
parent 683ed4a3
......@@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
}
}
/*!
* \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
* a thread.
*
* \param stage The stage to operate on.
* \param p_state The propagation result of each IterVar.
*/
void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
auto bound_to_thread = [&stage](const IterVar& iv) {
bool bound = false;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
bound = (*it).second->bind_thread.defined();
}
return bound;
};
auto& state = *p_state;
// Fill p_state with leaf itervars
for (const IterVar& iv : stage->leaf_iter_vars) {
state[iv] = bound_to_thread(iv);
}
// Traverse the graph bottom-up to propagate thread binding information
for (size_t i = stage->relations.size(); i != 0; --i) {
IterVarRelation rel = stage->relations[i - 1];
if (const SplitNode* s = rel.as<SplitNode>()) {
state[s->parent] = state[s->inner] || state[s->outer];
} else if (const FuseNode* s = rel.as<FuseNode>()) {
state[s->inner] = state[s->fused];
state[s->outer] = state[s->fused];
} else if (const RebaseNode* s = rel.as<RebaseNode>()) {
state[s->parent] = state[s->rebased];
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify(indexdiv(a + (b - 1), b));
};
auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(a < b)) {
return actx->Simplify(a);
}
return actx->Simplify(b);
};
std::unordered_map<IterVar, bool> dominating_thread;
PassUpThreadBinding(stage, &dominating_thread);
auto& state = *p_state;
// forwar iteration on relations
for (IterVarRelation rel : stage->relations) {
......@@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
}
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
// Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
// following conditions are met:
// 1. No leaf IterVar derived from iv binds to any thread. People may use split
// to force an IterVar extent to match the number of allocated threads to fuse stages
// that require different number of threads. We don't want to change these extents.
// 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
// rather than by an early compiler phase, such as rfactor(). We don't want to tighten an
// IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
// 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one
// zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent
// IterVar. We don't touch it.
auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
? factor_or_nparts
: minimum_or_later(range_parent->extent, factor_or_nparts);
};
if (r->factor.defined()) {
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
Range::make_by_min_extent(
0, resolve_min_extent_for_split(r->inner, r->factor)),
actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, resolve_min_extent_for_split(r->outer, r->nparts)),
actx);
Update(p_state, r->inner,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->nparts)), actx);
......
......@@ -70,6 +70,32 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
def test_bound_split_ext_less_than_factor():
m = 8
I = te.placeholder((m,), name='I')
EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
s = te.create_schedule([E.op])
xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
s[EF].compute_at(s[E], xo)
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xi].extent.value == m
def test_bound_split_ext_less_than_naprts():
m = 8
I = te.placeholder((m,), name='I')
EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
s = te.create_schedule([E.op])
xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)
s[EF].compute_at(s[E], xo)
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent.value == m
def test_bound_split_divisible():
m = te.var('m')
l = te.var('l')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment