Commit cfdc5119 by Lianmin Zheng Committed by Tianqi Chen

delete init part when keeping trivial loop (#1031)

parent ca9ec009
......@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provides the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param del_trivial_loop Whether eliminate trivial loop with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
* \return A statement that adds production and wraps consumer.
*/
virtual Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const = 0;
bool debug_keep_trivial_loop) const = 0;
static constexpr const char* _type_key = "Operation";
......@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode {
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode {
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode {
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode {
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......
......@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1 during lowering.
* This is a debug feature for dataflow/axis analysis.
* Note: If this is true, the lowered IR may be incorrect,
* because we will also delete the init part of the reduction.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool del_trivial_loop);
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
/*!
* \brief To automatically inline the element-wise operations.
......
......@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], true);
*ret = ScheduleOps(args[0], args[1], false);
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
......
......@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0
auto bounds = schedule::InferBound(sch);
auto stmt = schedule::ScheduleOps(sch, bounds, true);
auto stmt = schedule::ScheduleOps(sch, bounds, false);
stmt = ir::InjectPrefetch(stmt);
// Phase 1
......
......@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
bool debug_keep_trivial_loop) {
// grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
......@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = op::Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide);
return MergeNest(common, Block::make(init, provide));
if (debug_keep_trivial_loop) {
return MergeNest(common, provide);
} else {
return MergeNest(common, Block::make(init, provide));
}
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) {
......@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map, del_trivial_loop);
return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map, del_trivial_loop);
return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
} else {
return MakeComputeStmt(this, stage, dom_map, del_trivial_loop);
return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
}
}
......@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, del_trivial_loop);
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
debug_keep_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
......@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make(
}
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap), del_trivial_loop);
skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
......
......@@ -37,14 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op.
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
* \return The constructed loop nest
*/
static ComputeLoopNest make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
bool debug_keep_trivial_loop);
};
/*!
......@@ -52,27 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
* \return The created statement.
*/
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
bool debug_keep_trivial_loop);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
* \return The created statement.
*/
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
bool debug_keep_trivial_loop);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
......@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
bool debug_keep_trivial_loop) {
Array<Expr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, del_trivial_loop);
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
auto conds = schedule::MakeBoundCheck(
stage, dom_map, value_map, false,
std::unordered_set<IterVar>());
......
......@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt ExternOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
......
......@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop) {
bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = Evaluate::make(0);
// create the loop nest
......@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
}
}
if (del_trivial_loop && is_one(dom->extent)) {
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min;
......@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (del_trivial_loop && is_one(dom->extent)) {
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = var;
......
......@@ -29,7 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether to create a new loop variable.
* \param skip_iter The iteration variables to skip.
* \param p_value_map The result value of each IterVar.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
*/
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
......@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop);
bool debug_keep_trivial_loop);
/*!
* \brief Create a nest of if checking the predicates.
......
......@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt PlaceholderOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
bool debug_keep_trivial_loop) const {
return Stmt();
}
} // namespace tvm
......@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize(
Stmt ScanOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmt::make(
stage->op, attr::scan_update_scope, this->scan_axis->var,
......@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> vmap;
std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap, del_trivial_loop);
stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
nest[begin_scan].push_back(init);
nest.push_back(
op::MakeIfNest(
......
......@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage,
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
bool debug_keep_trivial_loop) {
std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
TensorIntrin intrin = stage->iter_var_attrs.at(
stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data.
......
......@@ -23,8 +23,8 @@ using namespace ir;
Stmt MakePipeline(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map,
Stmt consumer,
bool del_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, del_trivial_loop);
bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer);
}
......@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator {
InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop)
bool debug_keep_trivial_loop)
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
del_trivial_loop_(del_trivial_loop) {}
debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
......@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator {
const Stage& attach_spec_;
// domain map
const std::unordered_map<IterVar, Range>& dom_map_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
// Whether to keep trivial loops with extent of 1 during lowering.
// This is a debug feature for dataflow/axis analysis.
bool debug_keep_trivial_loop_;
};
// inject the operator's realization on the stmt.
......@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator {
const Operation& scan_op,
const std::unordered_map<IterVar, Range>& dom_map,
bool is_init,
bool del_trivial_loop)
bool debug_keep_trivial_loop)
: stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {}
dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
......@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator {
const std::unordered_map<IterVar, Range>& dom_map_;
// whether it is init.
bool is_init_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
// Whether to keep trivial loops with extent of 1 during lowering.
// This is a debug feature for dataflow/axis analysis.
bool debug_keep_trivial_loop_;
};
// Postprocessing of schedule op
......@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator {
};
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_, bool del_trivial_loop) {
Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
......@@ -372,14 +374,14 @@ Stmt ScheduleOps(
if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop);
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop);
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.update";
......@@ -387,11 +389,11 @@ Stmt ScheduleOps(
// do nothing
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body, del_trivial_loop);
body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
} else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop);
InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment