Commit d56c777a by Lianmin Zheng Committed by Tianqi Chen

support to keep trivial loops with extent of 1 (#877)

parent b21aee7d
......@@ -117,11 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param del_trivial_loop Whether eliminate trivial loop with extent of 1
* \return A statement that add production and wraps consumer.
*/
virtual Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const = 0;
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const = 0;
static constexpr const char* _type_key = "Operation";
......@@ -160,7 +162,8 @@ class PlaceholderOpNode : public OperationNode {
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -211,7 +214,8 @@ class ComputeOpNode : public OperationNode {
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -282,7 +286,8 @@ class ScanOpNode : public OperationNode {
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -345,7 +350,8 @@ class ExternOpNode : public OperationNode {
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......
......@@ -29,9 +29,10 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool del_trivial_loop);
/*!
* \brief To automatically inline the element-wise operations.
......
......@@ -24,6 +24,14 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
AutoInlineInjective(args[0]);
});
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], true);
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
......@@ -43,7 +51,6 @@ REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
} // namespace tvm
......@@ -211,7 +211,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0
auto bounds = schedule::InferBound(sch);
auto stmt = schedule::ScheduleOps(sch, bounds);
auto stmt = schedule::ScheduleOps(sch, bounds, true);
stmt = ir::InjectPrefetch(stmt);
// Phase 1
......
......@@ -305,9 +305,10 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
// grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
......@@ -387,28 +388,30 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
return MakeCrossThreadReduction(this, stage, dom_map, del_trivial_loop);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map);
return MakeTensorize(this, stage, dom_map, del_trivial_loop);
} else {
return MakeComputeStmt(this, stage, dom_map);
return MakeComputeStmt(this, stage, dom_map, del_trivial_loop);
}
}
ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, del_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
......@@ -450,7 +453,7 @@ ComputeLoopNest ComputeLoopNest::make(
}
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap));
skip_iter, &(ret.init_vmap), del_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
......
......@@ -37,12 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op.
* \param stage The scxhedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
* \return The constructed loop nest
*/
static ComputeLoopNest make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
};
/*!
......@@ -50,23 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \return The created statement.
*/
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \return The created statement.
*/
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
......@@ -13,14 +13,15 @@ using namespace ir;
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
Array<Expr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, del_trivial_loop);
auto conds = schedule::MakeBoundCheck(
stage, dom_map, value_map, false,
std::unordered_set<IterVar>());
......
......@@ -128,7 +128,8 @@ Stmt ExternOpNode::BuildRealize(
Stmt ExternOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
......
......@@ -23,7 +23,8 @@ MakeLoopNest(const Stage& stage,
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map) {
std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = Evaluate::make(0);
// create the loop nest
......@@ -75,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
}
}
if (is_one(dom->extent)) {
if (del_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min;
......@@ -130,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (is_one(dom->extent)) {
if (del_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = var;
......
......@@ -29,6 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
*/
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
......@@ -36,7 +37,8 @@ MakeLoopNest(const Stage& stage,
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map);
std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop);
/*!
* \brief Create a nest of if checking the predicates.
......
......@@ -78,7 +78,8 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt PlaceholderOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
return Stmt();
}
} // namespace tvm
......@@ -252,7 +252,8 @@ Stmt ScanOpNode::BuildRealize(
Stmt ScanOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmt::make(
stage->op, attr::scan_update_scope, this->scan_axis->var,
......@@ -270,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> vmap;
std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap);
stage, dom_map, 0, false, empty, &vmap, del_trivial_loop);
nest[begin_scan].push_back(init);
nest.push_back(
op::MakeIfNest(
......
......@@ -369,14 +369,15 @@ Stmt TransformUpdate(const Stage& stage,
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
TensorIntrin intrin = stage->iter_var_attrs.at(
stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data.
......
......@@ -22,8 +22,9 @@ using namespace ir;
Stmt MakePipeline(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map,
Stmt consumer) {
Stmt producer = s->op->BuildProvide(s, dom_map);
Stmt consumer,
bool del_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, del_trivial_loop);
if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer);
}
......@@ -68,7 +69,7 @@ class InjectAttach : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
MakePipeline(stage_, dom_map_, op->body, true));
}
}
return stmt;
......@@ -107,7 +108,7 @@ class InjectScanStep : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
MakePipeline(stage_, dom_map_, op->body, true));
}
}
return stmt;
......@@ -324,7 +325,7 @@ class SchedulePostProc : public IRMutator {
};
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) {
Schedule sch, Map<IterVar, Range> dom_map_, bool del_trivial_loop) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
......@@ -374,7 +375,7 @@ Stmt ScheduleOps(
// do nothing
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body);
body = MakePipeline(s, dom_map, body, del_trivial_loop);
} else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment