Commit f8f02829 by Tianqi Chen Committed by GitHub

[SCHEDULE] Refactor bound inference logic (#41)

parent 5c07413c
......@@ -22,7 +22,7 @@ namespace schedule {
* \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable
*/
Map<IterVar, Range> InferBound(Schedule sch);
Map<IterVar, Range> InferBound(const Schedule& sch);
/*!
* \brief Schedule s' dependent operations.
......
......@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e);
......@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
return m.Eval(e);
return EvalSet(e, dmap);
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
const std::unordered_map<const Variable*, IntSet>& dom_map) {
IntSetEvaluator m(dom_map);
IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i;
......@@ -463,6 +457,15 @@ IntSet EvalSet(Range r,
return Combine<Add>(min_set, ext_set);
}
/*!
 * \brief Evaluate the set covered by a range under IterVar-keyed domains.
 *
 *  Re-keys dom_map by the Variable pointer behind each IterVar and
 *  forwards to the EvalSet overload that takes the Variable-keyed map.
 * \param r The range to be evaluated.
 * \param dom_map The domain of each IterVar.
 * \return The result of the Variable-keyed EvalSet overload.
 */
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map) {
  std::unordered_map<const Variable*, IntSet> var_dom;
  for (auto entry : dom_map) {
    const Variable* key = entry.first->var.as<Variable>();
    var_dom[key] = entry.second;
  }
  return EvalSet(r, var_dom);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set["
......@@ -470,6 +473,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->i.max << ']';
});
} // namespace arith
} // namespace tvm
......@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*!
* \brief Create an union set of all sets
......
......@@ -7,6 +7,8 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h>
#include <unordered_map>
#include <unordered_set>
#include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
......@@ -131,7 +133,6 @@ void PassUp(const FuseNode* s,
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
......@@ -180,69 +181,35 @@ void PassUp(const Stage& s,
}
}
/*!
* \brief Pass the bound of tensor read
* to the corresponding bound of the IterVar of operation
* \param tensor The tensor to be passed.
* \param dim_bounds The read index set on each dimension.
* \param The result IterVar bound .
*/
void PassToOperation(
const Tensor& tensor,
const std::vector<IntSet>& dim_bounds,
std::unordered_map<IterVar, std::vector<IntSet> >* result) {
// This is a push style operation, given output bound, push to the op IterVar bound.
// It cannot handle complicated cases where op bound is coupled with bounds of
// all of its outputs, without having a simple communicative union relation.
//
// Eventually, we need to change the inference to be a Pull style inference
if (tensor->op.as<ComputeOpNode>()) {
auto root_iter_vars = tensor->op->root_iter_vars();
const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
for (size_t i = 0; i < op->axis.size(); ++i) {
(*result)[op->axis[i]].push_back(dim_bounds[i]);
}
// reduction.
for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
(*result)[op->reduce_axis[i]].push_back(
IntSet::range(op->reduce_axis[i]->dom));
}
} else {
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
}
}
/*!
 * \brief Scratch structure that stores the domain of a Tensor
 *  as one list of IntSet per dimension.
 */
struct TensorDom {
  /*! \brief The domain data, indexed by dimension. */
  std::vector<std::vector<IntSet> > data;
  /*!
   * \brief Create a domain with one empty list per dimension.
   * \param ndim The number of dimensions.
   */
  explicit TensorDom(int ndim) : data(ndim) {}
};
/*!
* \brief Recursively propagate bound
* \param post_order The propagation order.
* \brief Propagate bound to target
* \param dom_map The domain map to be propagated
* \param out The tensor set to be passed
* \return The result bound
*/
std::unordered_map<IterVar, IntSet>
BoundProp(const Array<Operation>& post_order,
std::unordered_map<IterVar, std::vector<IntSet> > *p_state) {
std::unordered_map<IterVar, IntSet> result;
for (size_t i = post_order.size(); i != 0; --i) {
Operation op = post_order[i - 1];
void BoundProp(const Operation& op,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom> *out) {
if (op.as<ComputeOpNode>()) {
for (auto iv : op->root_iter_vars()) {
CHECK(p_state->count(iv))
<< "Bound of root operator must exists";
CHECK(!result.count(iv));
result[iv] = Union(p_state->at(iv));
}
auto fvisit = [p_state, &result](const NodeRef& n) {
auto fvisit = [&dom_map, out](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
std::vector<IntSet> arg_bounds;
if (t->op.defined() && out->count(t)) {
TensorDom& dom = out->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result));
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
}
PassToOperation(t, arg_bounds, p_state);
}
}
};
......@@ -252,10 +219,31 @@ BoundProp(const Array<Operation>& post_order,
} else {
LOG(FATAL) << "unknown operation mode " << op->type_key();
}
}
return result;
}
void InferOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
auto root_iter_vars = op->root_iter_vars();
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r;
}
for (size_t i = 0; i < compute->reduce_axis.size(); ++i) {
CHECK(!rmap->count(compute->reduce_axis[i]));
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<PlaceholderOpNode>()) {
// dp nothing
} else {
LOG(FATAL) << "unknown operation mode " << op->type_key();
}
}
// check if scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
......@@ -267,7 +255,17 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
}
void InferBound(const Stage& stage,
// The map between a tensor and the operations it feeds to
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
// AttachPath maps an op to a list of IterVars
// that represents the loop nest the op sits in, from innermost to outermost
using AttachPath = Map<Operation, Array<IterVar> >;
void InferRootBound(const Stage& stage,
const FeedGraph& feed_graph,
const AttachPath& attach_path,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
......@@ -277,15 +275,46 @@ void InferBound(const Stage& stage,
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
}
return;
}
if (stage->attach_type == kScope) {
// Infer root bounds for the attached node.
CHECK_EQ(stage->attach_type, kScope);
Stage parent = stage->attach_stage;
CHECK(parent.defined());
auto g = CreateReadGraph({parent->op});
auto post_order = PostDFSOrder({parent->op}, g);
std::unordered_map<IterVar, IntSet> up_state;
// The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap;
// consumers other than parent
std::unordered_set<Operation> consumers;
// initialize the result
bool direct_consume_by_parent = false;
for (int i = 0; i < stage->op->num_outputs(); ++i) {
Tensor t = stage->op.output(i);
tmap.emplace(t, TensorDom(t.ndim()));
auto it = feed_graph.find(t);
if (it != feed_graph.end()) {
for (const Operation& op : it->second) {
if (op != parent->op) {
consumers.insert(op);
} else {
direct_consume_by_parent = true;
}
}
}
}
// The relax set
// This specifies the iteration variables that need to be relaxed
// from the already inferred bounds.
std::unordered_map<const Variable*, IntSet> relax_set;
for (IterVar iv : attach_path.at(stage->op)) {
if (ScopeRelax(iv, stage->scope)) {
relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
}
}
if (direct_consume_by_parent) {
// Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
Range vrange = rmap->at(iv);
......@@ -305,48 +334,104 @@ void InferBound(const Stage& stage,
fix_value = false;
}
}
// get the bound of the root IterVars given the current condition
// get the bound of the root IterVars given current location.
PassUp(parent, *rmap, &up_state);
std::unordered_map<IterVar, std::vector<IntSet> > bp_state;
std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : parent->op->root_iter_vars()) {
CHECK(up_state.count(iv));
bp_state[iv] = {up_state.at(iv)};
Range r = up_state.at(iv).cover_range(iv->dom);
if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set);
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
auto result = BoundProp(post_order, &bp_state);
// Set relaxation for the threads in parent.
Map<IterVar, IntSet> relax_set;
Stage s = stage;
while (s->attach_type == kScope) {
s = s->attach_stage;
for (auto iv : s->leaf_iter_vars) {
if (ScopeRelax(iv, stage->scope)) {
relax_set.Set(iv, IntSet::range(rmap->at(iv)));
}
// prop from parent.
BoundProp(parent->op, dom_map, &tmap);
}
// Bound prop by other consumers.
// To explain the general logic, consider the example:
//
// for (i_outer, 0, 10) {
// producer
//
// for (i_inner, 0, 4) {
// consumer op
// }
// }
// - Get domain of each of consumer op, say [i_inner + i_outer*8, extent=4)
// - We need to relax it since the producer is attached at i_outer
// - Consumer's path is [i_inner, i_outer], then [i_inner] need to be relaxed
// - Traverse attach_path, relax until reaching the producer's attachment point.
for (const Operation& op : consumers) {
std::unordered_map<const Variable*, IntSet> dom_map;
bool found = false;
for (IterVar iv : attach_path.at(op)) {
if (iv == stage->attach_ivar) {
found = true; break;
}
for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
Range r = result.at(iv).cover_range(iv->dom);
if (relax_set.size() != 0) {
r = EvalSet(r, relax_set).cover_range(iv->dom);
Range vrange = rmap->at(iv);
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
relax_set[iv->var.get()] = IntSet::range(vrange);
}
CHECK(found)
<< "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : op->root_iter_vars()) {
Range r = rmap->at(iv);
dom_map[iv->var.get()] = EvalSet(r, relax_set);
}
(*rmap)[iv] = r;
BoundProp(op, dom_map, &tmap);
}
InferOpBound(stage->op, tmap, rmap);
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
auto g = CreateReadGraph(sch->roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
fg[t].push_back(kv.first);
}
// get range of all child iter vars.
PassDown(stage, rmap);
}
return fg;
}
// Build the AttachPath of a schedule: for the op of every stage, the list of
// IterVars of the loop nest the op sits in, from innermost to outermost.
AttachPath CreateAttachPath(const Schedule& sch) {
  AttachPath result;
  for (Stage stage : sch->stages) {
    Array<IterVar> loop_nest;
    Stage cur = stage;
    // Follow the chain of attach stages for as long as the current one is
    // attached at a scope.
    while (cur->attach_type == kScope) {
      IterVar attach_point = cur->attach_ivar;
      cur = cur->attach_stage;
      // Scan the leaf IterVars of the stage attached to from last to first,
      // collecting everything from the attach point onwards.
      bool reached = false;
      for (size_t k = cur->leaf_iter_vars.size(); k > 0; --k) {
        IterVar iv = cur->leaf_iter_vars[k - 1];
        if (iv == attach_point) reached = true;
        if (reached) loop_nest.push_back(iv);
      }
      CHECK(reached)
          << "Invalid Schedule: cannot find attach point " << attach_point
          << " in the schedule of " << cur->op;
    }
    result.Set(stage->op, loop_nest);
  }
  return result;
}
Map<IterVar, Range> InferBound(Schedule sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
// reverse post DFS order, from the outermost stage to the innermost
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
InferBound(stage, &ret);
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, feed_graph, attach_path, &ret);
// pass down to get bound of all iter vars.
PassDown(stage, &ret);
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment