Commit 5f1816db by xqdan, committed by Tianqi Chen

enable partition const loop with build flag (#732)

* [SCHEDULE]enable partition const loop with build flag (#719)

    * enable partition loop with build flag

    * add a testcase, and modify LoopPartition related cases

    * add document for split_const_loop
parent 66fa0c3d
@@ -116,6 +116,9 @@ struct BuildConfig {
   /*! \brief Whether to detect global barrier */
   bool detect_global_barrier = false;
+  /*! \brief Whether to partition const loop */
+  bool partition_const_loop = false;
   BuildConfig() {
   }
 };
@@ -289,9 +289,10 @@ Stmt StorageRewrite(Stmt stmt);
 /*!
  * \brief partition loops in the stmt
  * \param stmt The stmt to do loop partition
+ * \param split_const_loop flag to enable partition for const loop
  * \return Transformed stmt.
  */
-Stmt LoopPartition(Stmt stmt);
+Stmt LoopPartition(Stmt stmt, bool split_const_loop);
 /*!
  * \brief Detect and insert sync points to co-processor.
@@ -32,6 +32,7 @@ class BuildConfig(object):
         "auto_unroll_max_extent": 0,
         "unroll_explicit": True,
         "detect_global_barrier": False,
+        "partition_const_loop": False,
         "offset_factor": 0,
         "data_alignment": -1,
         "restricted_func": True,
@@ -88,6 +89,9 @@ def build_config(**kwargs):
     detect_global_barrier: bool, default=True
         Whether detect global barrier.
+    partition_const_loop: bool, default=False
+        Whether partition const loop
     data_alignment: int, optional
         The alignment of data pointer in bytes.
         If -1 is passed, the alignment will be set to TVM's internal default.
@@ -219,7 +223,7 @@ def lower(sch,
             stmt = f(stmt)
     # Phase 2
     if not simple_mode:
-        stmt = ir_pass.LoopPartition(stmt)
+        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
     stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
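On the user side, the new option is toggled through build_config like the other flags above. A minimal sketch of that usage, assuming an illustrative element-wise compute and split (none of the tensor names or the split factor below come from this commit):

```python
import tvm

n = 21
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)

# partition_const_loop defaults to False; enabling it here makes Phase 2 of
# lower() call ir_pass.LoopPartition(stmt, True) for this scope.
with tvm.build_config(partition_const_loop=True):
    lowered = tvm.lower(s, [A, B], name="vadd")
```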
@@ -119,7 +119,7 @@ REGISTER_PASS1(LowerStorageAccessInfo);
 REGISTER_PASS1(InjectVirtualThread);
 REGISTER_PASS1(InjectPrefetch);
 REGISTER_PASS2(InjectDoubleBuffer);
-REGISTER_PASS1(LoopPartition);
+REGISTER_PASS2(LoopPartition);
 REGISTER_PASS1(RemoveNoOp);
 REGISTER_PASS2(SplitPipeline);
 REGISTER_PASS2(LiftAttrScope);
@@ -208,7 +208,7 @@ Stmt BuildStmt(Schedule sch,
   stmt = ir::StorageFlatten(stmt, out_binds, 64);
   stmt = ir::CanonicalSimplify(stmt);
   if (loop_partition) {
-    stmt = ir::LoopPartition(stmt);
+    stmt = ir::LoopPartition(stmt, config.partition_const_loop);
   }
   stmt = ir::VectorizeLoop(stmt);
   stmt = ir::InjectVirtualThread(stmt);
@@ -45,10 +45,12 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
 class CandidateSelector final : public IRVisitor {
  public:
   using VarIsUsed = bool;
-  CandidateSelector() {}
+  explicit CandidateSelector(bool split_const_loop)
+      : split_const_loop_(split_const_loop) {}

   void Visit_(const For* op) {
-    if (!is_const(op->min) || !is_const(op->extent)) {
+    // partition const loop when sets split_const_loop_
+    if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
       const Variable* var = op->loop_var.get();
       record_.insert({var, false});
       IRVisitor::Visit_(op);
@@ -67,7 +69,7 @@ class CandidateSelector final : public IRVisitor {
       CHECK(iv);
       Var var = iv->var;
       runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
-      if ((scope.rank == 0) && !is_const(op->value)) {
+      if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
         record_.insert({var.get(), false});
         IRVisitor::Visit_(op);
         if (record_.at(var.get()) && !no_split_) {
@@ -115,6 +117,7 @@ class CandidateSelector final : public IRVisitor {
  private:
   bool in_likely_{false};
   bool no_split_{false};
+  bool split_const_loop_{false};
   std::unordered_map<const Variable*, VarIsUsed> record_;
 };
@@ -392,8 +395,8 @@ class RemoveLikelyTags : public IRMutator {
   }
 };

-Stmt LoopPartition(Stmt stmt) {
-  CandidateSelector selector;
+Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
+  CandidateSelector selector(split_const_loop);
   selector.Visit(stmt);
   stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
   stmt = RemoveLikelyTags().Mutate(stmt);
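The same behaviour can be exercised by calling the pass directly with the new argument. A hedged sketch mirroring the added unit test (the tensor names and split factor are illustrative): with split_const_loop=True, a loop whose min and extent are constants is now a partition candidate, so the tail guard is hoisted out of the main loop.

```python
import tvm

n = 21
A = tvm.placeholder((n,), name='A')
T = tvm.compute((n,), lambda i: A[i] * 2, name='T')
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)

# False keeps the old behaviour: the const loop is skipped, and the
# likely(...) guard for the tail iterations stays inside the loop body.
before = tvm.ir_pass.Simplify(tvm.ir_pass.LoopPartition(stmt, False))
# True partitions the const loop into a guard-free main loop plus a tail part.
after = tvm.ir_pass.Simplify(tvm.ir_pass.LoopPartition(stmt, True))
print(before)
print(after)
```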
@@ -27,7 +27,7 @@ def test_add_pipeline():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
     Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
@@ -19,7 +19,7 @@ def lower(sch, args):
     sch = sch.normalize()
     bounds = tvm.schedule.InferBound(sch)
     stmt = tvm.schedule.ScheduleOps(sch, bounds)
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
@@ -37,7 +37,22 @@ def test_basic():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert('if' not in str(stmt.body.body.body.first))
+
+def test_const_loop():
+    n = 21
+    A = tvm.placeholder((n, ), name='A')
+    B = tvm.placeholder((n, ), name='B')
+    T = tvm.compute((n, ), lambda i: A[i]+B[i])
+    s = tvm.create_schedule(T.op)
+    xo, xi = s[T].split(T.op.axis[0], factor=4)
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.body.body.first))
@@ -53,7 +68,7 @@ def test_multi_loop():
         with ib.else_scope():
             ib.emit(tvm.make.Evaluate(n))
     stmt = ib.get()
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
@@ -73,7 +88,7 @@ def test_multi_if():
         with ib.else_scope():
             ib.emit(tvm.make.Evaluate(n))
     stmt = ib.get()
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.first))
@@ -92,7 +107,7 @@ def test_thread_axis():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.body.body.first))
@@ -127,7 +142,7 @@ def test_select():
             ib.emit(tvm.make.Evaluate(
                 tvm.make.Select(ib.likely(i*4+j<n), m, n)))
     stmt = ib.get()
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
@@ -158,12 +173,13 @@ def test_everything_during_deduction():
             # this guard will produce everything during deduction
             ib.emit(tvm.make.Evaluate(m))
     stmt = ib.get()
-    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))

 if __name__ == "__main__":
     test_basic()
+    test_const_loop()
     test_multi_loop()
     test_multi_if()
     test_thread_axis()