Commit ea9c1c59 by Tianqi Chen, committed by GitHub

[SCHEDULE] More reliable bound inference on threading. (#84)
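Summary of the diff below: bound inference previously special-cased the stage's parent and handled the remaining consumers in a second pass. It now walks every consumer uniformly, fixing loop variables until the producer's attach point is found and relaxing the rest via the new NeedRelax rule, with the storage scope inferred from the attach path (InferStorageScope) when not set explicitly. SetScanAttach is dropped, and normalize() now rebases all non-inlined stages after inlining.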

parent 3ac94439
@@ -15,33 +15,56 @@
namespace tvm {
namespace schedule {
+using runtime::ThreadScope;
+using runtime::StorageScope;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph */
  FeedGraph feed_graph;
+  /*! \brief Attachment path */
+  AttachPath attach_path;
+  /*! \brief The bind map */
+  std::unordered_map<IterVar, IterVar> bind_map;
+  /*! \brief map from op to stage */
+  std::unordered_map<const Node*, Stage> op2stage_;
};
-// check if scope
-inline bool ScopeRelax(const IterVar& ivar,
-                       const std::unordered_map<IterVar, IterVar>& bind_map,
-                       const std::string& scope) {
-  using runtime::ThreadScope;
-  using runtime::StorageScope;
-  auto it = bind_map.find(ivar);
-  IterVar iv = ivar;
-  if (it != bind_map.end()) {
-    iv = it->second;
-  }
-  if (iv->thread_tag.length() == 0) return false;
-  if (scope.length() == 0) return false;
-  return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
-}
+bool NeedRelax(const IterVar& iv,
+               bool found_attach,
+               const std::unordered_map<IterVar, IterVar>& bind_map,
+               const runtime::StorageScope& scope) {
+  auto it = bind_map.find(iv);
+  const std::string& tag = (
+      it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+  if (tag.length() == 0 || tag == "pipeline") {
+    return !found_attach;
+  }
+  return scope.rank <= ThreadScope::make(tag).rank;
+}
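The rank comparison in NeedRelax is easiest to see with concrete numbers. Below is a standalone Python sketch of the rule; it is not part of the commit, and the rank tables are assumed from TVM's runtime scopes (storage: global=0, shared=1, local=2; threads: blockIdx*=0, threadIdx*=1):

STORAGE_RANK = {"global": 0, "shared": 1, "local": 2}
THREAD_RANK = {"blockIdx": 0, "threadIdx": 1}

def need_relax(thread_tag, found_attach, storage_scope):
    # Unbound loops and the virtual "pipeline" thread are relaxed only
    # while the producer's attach point has not yet been seen, i.e. for
    # loops nested inside the attachment.
    if thread_tag in ("", "pipeline"):
        return not found_attach
    # Relax a thread axis iff the storage is shared across it: shared
    # memory (rank 1) is common to all threadIdx (rank 1) but private to
    # each block (rank 0); local memory (rank 2) is private to everything.
    return STORAGE_RANK[storage_scope] <= THREAD_RANK[thread_tag.split(".")[0]]

assert need_relax("threadIdx.x", False, "shared")
assert not need_relax("blockIdx.x", False, "shared")
assert not need_relax("threadIdx.x", False, "local")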
+// infer storage scope, if not given
+StorageScope InferStorageScope(
+    const Stage& stage, const GraphContext& ctx) {
+  if (stage->scope.length() != 0) {
+    return StorageScope::make(stage->scope);
+  }
+  int max_rank = 0;
+  for (IterVar iv : ctx.attach_path.at(stage->op)) {
+    auto it = ctx.bind_map.find(iv);
+    const std::string& tag = (
+        it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+    if (tag != "pipeline" && tag.length() != 0) {
+      max_rank = std::max(max_rank, ThreadScope::make(tag).rank + 1);
+    }
+  }
+  StorageScope s; s.rank = max_rank;
+  return s;
+}
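When set_scope was never called, InferStorageScope falls back to the attach path: the deepest thread axis the stage is nested under determines a storage rank one level finer. A rough Python sketch of that fallback, under the same assumed rank tables as above:

THREAD_RANK = {"blockIdx": 0, "threadIdx": 1}
RANK_TO_SCOPE = {0: "global", 1: "shared", 2: "local"}

def infer_storage_scope(explicit_scope, attach_path_tags):
    # An explicit scope always wins.
    if explicit_scope:
        return explicit_scope
    max_rank = 0
    for tag in attach_path_tags:
        if tag and tag != "pipeline":
            # One rank finer than the deepest thread axis on the path.
            max_rank = max(max_rank, THREAD_RANK[tag.split(".")[0]] + 1)
    return RANK_TO_SCOPE[max_rank]

# A stage attached inside a threadIdx loop defaults to per-thread storage:
assert infer_storage_scope("", ["blockIdx.x", "threadIdx.x"]) == "local"
assert infer_storage_scope("", ["blockIdx.x"]) == "shared"
assert infer_storage_scope("shared", ["threadIdx.x"]) == "shared"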
void InferRootBound(const Stage& stage,
                    const GraphContext& ctx,
-                   const AttachPath& attach_path,
-                   const std::unordered_map<IterVar, IterVar>& bind_map,
                    std::unordered_map<IterVar, Range>* rmap) {
  CHECK_NE(stage->attach_type, kInline)
      << "call schedule.normalize before scheduleops";
@@ -59,73 +82,78 @@ void InferRootBound(const Stage& stage,
    }
    return;
  }
-  // parent stage, if any
-  Stage parent;
-  Stage attach_spec = stage.GetAttachSpec();
-  if (attach_spec->attach_type == kScope ||
-      attach_spec->attach_type == kScanUpdate) {
-    parent = attach_spec->attach_stage;
-  }
  // The tensor domain.
  std::unordered_map<Tensor, TensorDom> tmap;
-  // consumers other than parent
+  // The consumers of the op.
  std::unordered_set<Operation> consumers;
-  // initialize the result
-  bool direct_consume_by_parent = false;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
      for (const Operation& op : it->second) {
-        if (!parent.defined() || op != parent->op) {
-          consumers.insert(op);
-        } else {
-          direct_consume_by_parent = true;
-        }
+        consumers.insert(op);
      }
    } else {
      LOG(INFO) << "not in feed graph consumer = " << stage->op;
    }
  }
-  // The relax set
-  // Thie specifieds the iteration variables that need to be relaxed
-  // from the already inferred bounds.
-  std::unordered_map<const Variable*, IntSet> relax_set;
-  for (IterVar iv : attach_path.at(stage->op)) {
-    if (ScopeRelax(iv, bind_map, stage->scope)) {
-      relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
-    }
-  }
-  if (direct_consume_by_parent) {
-    // Bound inference logics in parent.
+  // storage scope.
+  runtime::StorageScope scope = InferStorageScope(stage, ctx);
+  // Bound prop by other consumers.
+  // - Compute bound by relaxation rules: NeedRelax
+  //   - For normal index, use relative location of loop nest.
+  //   - For thread index, use the thread scope.
+  //
+  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
+  // The parent set.
+  for (const Operation& op : consumers) {
+    std::unordered_map<const Variable*, IntSet> relax_set;
    std::unordered_map<IterVar, IntSet> up_state;
-    bool fix_value = true;
-    for (auto iv : parent->leaf_iter_vars) {
+    bool found_attach = false;
+    CHECK(ctx.op2stage_.count(op.get()));
+    const Stage& op_stage = ctx.op2stage_.at(op.get());
+    // Consumer nest
+    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
+      IterVar iv = op_stage->leaf_iter_vars[i - 1];
+      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+        found_attach = true;
+      }
      auto it = rmap->find(iv);
      CHECK(it != rmap->end());
-      Range vrange = it->second;
-      CHECK(is_zero(vrange->min))
-          << "InferBound requires every leaf iter var's min equals 0, "
-          << " call schedule.normalize to achieve this. "
-          << " stage=" << parent << ", vrange=" << vrange->min;
+      const Range& vrange = it->second;
      // special optimization to remove trivial loop
      if (is_one(vrange->extent)) {
        up_state[iv] = IntSet::single_point(vrange->min);
-      } else if (fix_value && !ScopeRelax(iv, bind_map, stage->scope)) {
+      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+        CHECK(is_zero(vrange->min))
+            << "InferBound requires every leaf iter var's min equals 0, "
+            << " call schedule.normalize to achieve this. ";
        up_state[iv] = IntSet::single_point(iv->var);
      } else {
        up_state[iv] = IntSet::range(vrange);
      }
-      if (attach_spec->attach_ivar == iv) {
-        fix_value = false;
-      }
    }
-    // get the bound of the root IterVars given current location.
-    PassUpDomain(parent, *rmap, &up_state);
+    // Consumer's attach nest
+    for (IterVar iv : ctx.attach_path.at(op)) {
+      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+        found_attach = true;
+      }
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+        relax_set[iv->var.get()] = IntSet::range(vrange);
+      }
+    }
+    CHECK(found_attach || stage_attach.size() == 0)
+        << "Invalid Schedule, cannot find the producer " << stage->op
+        << " along the loop nest specified by compute_at of consumer " << op;
+    // Get the domain of the consumer
+    PassUpDomain(op_stage, *rmap, &up_state);
+    // Relax if needed.
    std::unordered_map<const Variable*, IntSet> dom_map;
-    for (auto iv : parent->op->root_iter_vars()) {
+    for (auto iv : op->root_iter_vars()) {
      Range r;
      if (up_state.count(iv)) {
        r = up_state.at(iv).cover_range(iv->dom);
@@ -138,70 +166,35 @@ void InferRootBound(const Stage& stage,
        dom_map[iv->var.get()] = IntSet::range(r);
      }
    }
-    // prop from parent.
-    parent->op->PropBoundToInputs(parent->op, dom_map, &tmap);
-  }
-  // Bound prop by other consumers.
-  // To explain the the general logic, consider the example:
-  //
-  // for (i_outer, 0, 10) {
-  //   producer
-  //
-  //   for (i_inner, 0, 4) {
-  //      consumer op
-  //   }
-  // }
-  // - Get domain of each of consumer op, say [i_inner + i_outer*8, extent=4)
-  // - We need to relax it since the producer is attached at i_outer
-  // - Consumer's path is [i_inner, i_outer], then [i_inner] need to be relaxed
-  // - Traverse attach_path, relax until reaching the producer's attachment point.
-  for (const Operation& op : consumers) {
-    std::unordered_map<const Variable*, IntSet> dom_map;
-    bool found = false;
-    Array<IterVar> attach = attach_path.at(stage->op);
-    for (IterVar iv : attach_path.at(op)) {
-      if (attach.size() != 0 && iv == attach[0]) {
-        found = true; break;
-      }
-      Range vrange = rmap->at(iv);
-      CHECK(is_zero(vrange->min))
-          << "InferBound requires every leaf iter var's min equals 0, "
-          << "call schedule.normalize to achieve this.";
-      relax_set[iv->var.get()] = IntSet::range(vrange);
-    }
-    CHECK(found || attach.size() == 0)
-        << "Invalid Schedule, cannot find the producer " << stage->op
-        << " along the loop nest specified by compute_at of consumer " << op;
-    for (auto iv : op->root_iter_vars()) {
-      Range r = rmap->at(iv);
-      dom_map[iv->var.get()] = EvalSet(r, relax_set);
-    }
    op->PropBoundToInputs(op, dom_map, &tmap);
  }
  stage->op->GatherBound(stage->op, tmap, rmap);
}
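The deleted comment's example survives in spirit: relaxation now happens per consumer through NeedRelax and the consumer's attach path. As a concrete illustration in the Python API this commit's tests use (essentially the scenario of test_bound1; variable names are mine, and the API is the 2017-era one shown in the test file):

import tvm

m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
A1 = tvm.compute((m,), lambda i: A[i], name='A1')
A2 = tvm.compute((m,), lambda i: A1[i] + 1, name='A2')
s = tvm.Schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=4)
s[A1].compute_at(s[A2], xo)   # attach the producer at the outer loop
s.normalize()
bounds = tvm.schedule.InferBound(s)
# xi lies inside the attach point, so it is relaxed over [0, 4):
# A1 must produce 4 elements per iteration of the fixed xo.
assert bounds[A1.op.axis[0]].extent.value == 4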
Map<IterVar, Range> InferBound(const Schedule& sch) {
+  // Prepare context
+  GraphContext ctx;
  Array<Operation> roots;
  for (Operation op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
-  std::unordered_map<IterVar, IterVar> bind_map;
+  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
  for (Stage stage : sch->stages) {
    for (auto kv : stage->iter_var_attrs) {
      if (kv.second->bind_thread.defined()) {
-        CHECK(!bind_map.count(kv.first));
-        bind_map[kv.first] = kv.second->bind_thread;
+        CHECK(!ctx.bind_map.count(kv.first));
+        ctx.bind_map[kv.first] = kv.second->bind_thread;
      }
    }
+    ctx.op2stage_[stage->op.get()] = stage;
  }
-  GraphContext ctx;
-  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
-  AttachPath attach_path = CreateAttachPath(sch);
+  ctx.attach_path = CreateAttachPath(sch);
  // Run inference.
  std::unordered_map<IterVar, Range> ret;
  for (size_t i = sch->stages.size(); i != 0; --i) {
    const Stage& stage = sch->stages[i - 1];
-    InferRootBound(stage, ctx, attach_path, bind_map, &ret);
+    InferRootBound(stage, ctx, &ret);
    // pass down to get bound of all iter vars.
    PassDownDomain(stage, &ret);
    for (IterVar iv : stage->env_threads) {
......
@@ -154,24 +154,9 @@ Tensor Schedule::cache_write(const Tensor& tensor,
void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
-  std::unordered_map<const Node*, int> attach_mark;
-  for (Stage s : sch->stages) {
-    if (s->attach_type == kScope) {
-      attach_mark[s->attach_stage.get()] = 1;
-    }
-    if (s->op.as<ScanOpNode>()) {
-      attach_mark[s.get()] = 1;
-    }
-  }
-  for (Stage s : sch->groups) {
-    if (s->attach_type == kScope) {
-      attach_mark[s->attach_stage.get()] = 1;
-    }
-  }
  for (Stage s : sch->stages) {
-    if (!attach_mark.count(s.get())) continue;
+    if (s->attach_type == kInlinedAlready) continue;
    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
@@ -201,16 +186,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
  }
}
-void SetScanAttach(const Schedule& sch) {  // NOLINT(*)
-  for (Stage stage : sch->stages) {
-    if (stage->attach_type == kScanUpdate) {
-      const Stage& parent = stage->attach_stage;
-      stage->attach_ivar =
-          parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
-    }
-  }
-}
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();
  std::vector<Expr> new_body(sch->stages.size());
@@ -262,9 +237,8 @@ void InjectInline(ScheduleNode* sch) {
}
void Schedule::normalize() {
-  RebaseNonZeroMinLoop(*this);
-  SetScanAttach(*this);
  InjectInline(operator->());
+  RebaseNonZeroMinLoop(*this);
}
// Handle reduction factor.
......
@@ -148,7 +148,37 @@ def test_bound_nest_group():
    assert bounds[x1.op.axis[0]].extent.value == 1
    assert bounds[x1.op.axis[1]].extent == n

+def test_bound_nest_thread():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m), name='A')
+    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
+    A2 = tvm.compute((m,), lambda i: A1[i] + 2, name='A2')
+    A3 = tvm.compute((m,), lambda i: A2[i] + 3, name='A3')
+    s = tvm.Schedule(A3.op)
+    s[A2].set_scope("shared")
+    s[A1].set_scope("local")
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    bx, tx = s[A3].split(A3.op.axis[0], factor=32)
+    s[A3].bind(bx, block_x)
+    s[A3].bind(tx, thread_x)
+    s[A2].compute_at(s[A3], tx)
+    _, xi = s[A2].split(A2.op.axis[0], nparts=1)
+    s[A2].bind(xi, thread_x)
+    s[A1].compute_at(s[A3], tx)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[A1.op.axis[0]].extent.value==1)
+    assert(bounds[A2.op.axis[0]].extent.value==32)
+    assert(bounds[A3.op.axis[0]].extent == m)
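The expected extents follow from the scope ranks checked by NeedRelax: A1 is local, so neither blockIdx.x nor threadIdx.x is relaxed and each thread computes a single element (extent 1); A2 is shared, so threadIdx.x is relaxed across the 32 threads of a block (extent 32); A3 is the root output and keeps its full extent m.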
if __name__ == "__main__":
+    test_bound_nest_thread()
+    test_bound1()
    test_bound_nest_group()
    test_bound_group_schedule()
    test_bound_scan()
@@ -156,5 +186,4 @@ if __name__ == "__main__":
    test_bound_rfactor()
    test_bound_blur()
    test_bound_conv1d()
-    test_bound1()
    test_bound2()