Commit 5f1816db by xqdan Committed by Tianqi Chen

enable partition const loop with build flag (#732)

* [SCHEDULE]enable partition const loop with build flag (#719)

    * enable partition loop with build flag

    * add a testcase, and modify LoopPartition related cases

    * add documentation for split_const_loop
parent 66fa0c3d
......@@ -116,6 +116,9 @@ struct BuildConfig {
/*! \brief Whether to detect global barrier */
bool detect_global_barrier = false;
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
BuildConfig() {
}
};
......
......@@ -289,9 +289,10 @@ Stmt StorageRewrite(Stmt stmt);
/*!
* \brief partition loops in the stmt
* \param stmt The stmt to do loop partition
* \param split_const_loop flag to enable partition for const loop
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt);
Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*!
* \brief Detect and insert sync points to co-processor.
......
......@@ -32,6 +32,7 @@ class BuildConfig(object):
"auto_unroll_max_extent": 0,
"unroll_explicit": True,
"detect_global_barrier": False,
"partition_const_loop": False,
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
......@@ -88,6 +89,9 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=False
Whether to detect global barrier.
partition_const_loop: bool, default=False
Whether to partition const loops.
data_alignment: int, optional
The alignment of data pointer in bytes.
If -1 is passed, the alignment will be set to TVM's internal default.
......@@ -219,7 +223,7 @@ def lower(sch,
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
......
......@@ -119,7 +119,7 @@ REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS2(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
......
......@@ -208,7 +208,7 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::StorageFlatten(stmt, out_binds, 64);
stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) {
stmt = ir::LoopPartition(stmt);
stmt = ir::LoopPartition(stmt, config.partition_const_loop);
}
stmt = ir::VectorizeLoop(stmt);
stmt = ir::InjectVirtualThread(stmt);
......
......@@ -45,10 +45,12 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
class CandidateSelector final : public IRVisitor {
public:
using VarIsUsed = bool;
CandidateSelector() {}
explicit CandidateSelector(bool split_const_loop)
: split_const_loop_(split_const_loop) {}
void Visit_(const For* op) {
if (!is_const(op->min) || !is_const(op->extent)) {
// also partition const loops when split_const_loop_ is set
if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
const Variable* var = op->loop_var.get();
record_.insert({var, false});
IRVisitor::Visit_(op);
......@@ -67,7 +69,7 @@ class CandidateSelector final : public IRVisitor {
CHECK(iv);
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
if ((scope.rank == 0) && !is_const(op->value)) {
if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
record_.insert({var.get(), false});
IRVisitor::Visit_(op);
if (record_.at(var.get()) && !no_split_) {
......@@ -115,6 +117,7 @@ class CandidateSelector final : public IRVisitor {
private:
bool in_likely_{false};
bool no_split_{false};
bool split_const_loop_{false};
std::unordered_map<const Variable*, VarIsUsed> record_;
};
......@@ -392,8 +395,8 @@ class RemoveLikelyTags : public IRMutator {
}
};
Stmt LoopPartition(Stmt stmt) {
CandidateSelector selector;
Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
CandidateSelector selector(split_const_loop);
selector.Visit(stmt);
stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
stmt = RemoveLikelyTags().Mutate(stmt);
......
......@@ -27,7 +27,7 @@ def test_add_pipeline():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
......
......@@ -19,7 +19,7 @@ def lower(sch, args):
sch = sch.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
......@@ -37,7 +37,22 @@ def test_basic():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
def test_const_loop():
    """Run LoopPartition with const-loop splitting enabled on a fixed-size add.

    Builds an element-wise sum of two length-21 placeholders, splits its
    single axis by a factor of 4, lowers the schedule to a statement, and
    applies LoopPartition with the second argument set to True followed by
    Simplify. The check is that the printed form of the leading partition
    contains no 'if'.
    """
    length = 21
    lhs = tvm.placeholder((length, ), name='A')
    rhs = tvm.placeholder((length, ), name='B')
    total = tvm.compute((length, ), lambda i: lhs[i] + rhs[i])
    sched = tvm.create_schedule(total.op)
    # 21 is not a multiple of the split factor 4; presumably this is what
    # leaves a bounds guard for the partition pass to remove -- TODO confirm.
    outer, inner = sched[total].split(total.op.axis[0], factor=4)
    inferred = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, inferred)
    # True turns on partitioning of loops with constant bounds.
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    leading_part = stmt.body.body.body.first
    assert 'if' not in str(leading_part)
......@@ -53,7 +68,7 @@ def test_multi_loop():
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
......@@ -73,7 +88,7 @@ def test_multi_if():
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.first))
......@@ -92,7 +107,7 @@ def test_thread_axis():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
......@@ -127,7 +142,7 @@ def test_select():
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(i*4+j<n), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
......@@ -158,12 +173,13 @@ def test_everything_during_deduction():
# this guard will produce everything during deduction
ib.emit(tvm.make.Evaluate(m))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
if __name__ == "__main__":
test_basic()
test_const_loop()
test_multi_loop()
test_multi_if()
test_thread_axis()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment