Commit a2aa154c by Tianqi Chen Committed by GitHub

[UNROLL] New unroll option (#647)

parent c6a1241e
......@@ -205,10 +205,16 @@ Stmt NarrowChannelAccess(Stmt stmt);
* \param stmt The statment to be unrolled.
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_min_depth The minimum depth before we can start automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is an legacy option that donot take the loop total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_min_depth, bool explicit_unroll);
Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_min_depth,
int auto_max_extent,
bool explicit_unroll);
/*!
* \brief vectorize the constant loops
......
......@@ -28,7 +28,8 @@ class BuildConfig(object):
current = None
defaults = {
"auto_unroll_max_step": 0,
"auto_unroll_max_depth": 4,
"auto_unroll_max_depth": 8,
"auto_unroll_max_extent": 0,
"unroll_explicit": True,
"detect_global_barrier": False,
"offset_factor": 0,
......@@ -227,6 +228,7 @@ def lower(sch,
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
......
......@@ -91,7 +91,7 @@ REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS5(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI);
......
......@@ -19,9 +19,11 @@ class LoopUnroller : public IRMutator {
public:
explicit LoopUnroller(int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) {
}
......@@ -42,10 +44,13 @@ class LoopUnroller : public IRMutator {
// condition for auto unroll
bool auto_unroll = (
op->for_type == ForType::Serial &&
normal_loop_depth_ == 0 &&
value >= 0 &&
unroll_depth_ <= auto_max_depth_ &&
value * step_count_ <= auto_max_step_);
normal_loop_depth_ == 0 &&
unroll_depth_ <= auto_max_depth_);
auto_unroll = auto_unroll && (
value * step_count_ <= auto_max_step_||
value <= auto_max_extent_);
if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0)
......@@ -127,6 +132,9 @@ class LoopUnroller : public IRMutator {
// maximum number of step to perform auto unroll.
int auto_max_step_;
int auto_max_depth_;
// max extent of loop to auto unroll
// this not not count the total steps, only count the number of loops
int auto_max_extent_;
bool explicit_unroll_;
// Number of normal loops in scope
int normal_loop_depth_{0};
......@@ -140,10 +148,12 @@ class LoopUnroller : public IRMutator {
Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll) {
Stmt ret = LoopUnroller(
auto_max_step,
auto_max_depth,
auto_max_extent,
explicit_unroll).Mutate(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
......
......@@ -14,11 +14,11 @@ def test_unroll_loop():
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
assert isinstance(stmt, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, True)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
assert not isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, True)
ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True)
assert isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, False)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False)
assert isinstance(ret, tvm.stmt.For)
assert ret.for_type == tvm.stmt.For.Unrolled
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment