Commit a2aa154c by Tianqi Chen Committed by GitHub

[UNROLL] New unroll option (#647)

parent c6a1241e
...@@ -205,10 +205,16 @@ Stmt NarrowChannelAccess(Stmt stmt); ...@@ -205,10 +205,16 @@ Stmt NarrowChannelAccess(Stmt stmt);
* \param stmt The statment to be unrolled. * \param stmt The statment to be unrolled.
* \param auto_max_step The maximum step before stop attach automatic unroll * \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_min_depth The minimum depth before we can start automatic unroll * \param auto_min_depth The minimum depth before we can start automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is an legacy option that donot take the loop total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_min_depth, bool explicit_unroll); Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_min_depth,
int auto_max_extent,
bool explicit_unroll);
/*! /*!
* \brief vectorize the constant loops * \brief vectorize the constant loops
......
...@@ -28,7 +28,8 @@ class BuildConfig(object): ...@@ -28,7 +28,8 @@ class BuildConfig(object):
current = None current = None
defaults = { defaults = {
"auto_unroll_max_step": 0, "auto_unroll_max_step": 0,
"auto_unroll_max_depth": 4, "auto_unroll_max_depth": 8,
"auto_unroll_max_extent": 0,
"unroll_explicit": True, "unroll_explicit": True,
"detect_global_barrier": False, "detect_global_barrier": False,
"offset_factor": 0, "offset_factor": 0,
...@@ -227,6 +228,7 @@ def lower(sch, ...@@ -227,6 +228,7 @@ def lower(sch,
stmt, stmt,
cfg.auto_unroll_max_step, cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit) cfg.unroll_explicit)
for f in lower_phase2: for f in lower_phase2:
stmt = f(stmt) stmt = f(stmt)
......
...@@ -91,7 +91,7 @@ REGISTER_PASS4(Inline); ...@@ -91,7 +91,7 @@ REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten); REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform); REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop); REGISTER_PASS5(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin); REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync); REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI); REGISTER_PASS5(MakeAPI);
......
...@@ -19,9 +19,11 @@ class LoopUnroller : public IRMutator { ...@@ -19,9 +19,11 @@ class LoopUnroller : public IRMutator {
public: public:
explicit LoopUnroller(int auto_max_step, explicit LoopUnroller(int auto_max_step,
int auto_max_depth, int auto_max_depth,
int auto_max_extent,
bool explicit_unroll) bool explicit_unroll)
: auto_max_step_(auto_max_step), : auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth), auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) { explicit_unroll_(explicit_unroll) {
} }
...@@ -42,10 +44,13 @@ class LoopUnroller : public IRMutator { ...@@ -42,10 +44,13 @@ class LoopUnroller : public IRMutator {
// condition for auto unroll // condition for auto unroll
bool auto_unroll = ( bool auto_unroll = (
op->for_type == ForType::Serial && op->for_type == ForType::Serial &&
normal_loop_depth_ == 0 &&
value >= 0 && value >= 0 &&
unroll_depth_ <= auto_max_depth_ && normal_loop_depth_ == 0 &&
value * step_count_ <= auto_max_step_); unroll_depth_ <= auto_max_depth_);
auto_unroll = auto_unroll && (
value * step_count_ <= auto_max_step_||
value <= auto_max_extent_);
if (op->for_type == ForType::Unrolled) { if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0) CHECK_GE(value, 0)
...@@ -127,6 +132,9 @@ class LoopUnroller : public IRMutator { ...@@ -127,6 +132,9 @@ class LoopUnroller : public IRMutator {
// maximum number of step to perform auto unroll. // maximum number of step to perform auto unroll.
int auto_max_step_; int auto_max_step_;
int auto_max_depth_; int auto_max_depth_;
// max extent of loop to auto unroll
// this not not count the total steps, only count the number of loops
int auto_max_extent_;
bool explicit_unroll_; bool explicit_unroll_;
// Number of normal loops in scope // Number of normal loops in scope
int normal_loop_depth_{0}; int normal_loop_depth_{0};
...@@ -140,10 +148,12 @@ class LoopUnroller : public IRMutator { ...@@ -140,10 +148,12 @@ class LoopUnroller : public IRMutator {
Stmt UnrollLoop(Stmt stmt, Stmt UnrollLoop(Stmt stmt,
int auto_max_step, int auto_max_step,
int auto_max_depth, int auto_max_depth,
int auto_max_extent,
bool explicit_unroll) { bool explicit_unroll) {
Stmt ret = LoopUnroller( Stmt ret = LoopUnroller(
auto_max_step, auto_max_step,
auto_max_depth, auto_max_depth,
auto_max_extent,
explicit_unroll).Mutate(stmt); explicit_unroll).Mutate(stmt);
if (!ret.same_as(stmt)) { if (!ret.same_as(stmt)) {
return ConvertSSA(ret); return ConvertSSA(ret);
......
...@@ -14,11 +14,11 @@ def test_unroll_loop(): ...@@ -14,11 +14,11 @@ def test_unroll_loop():
tvm.make.Load(dtype, Ab.data, i) + 1, tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1))) j + 1)))
assert isinstance(stmt, tvm.stmt.For) assert isinstance(stmt, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, True) ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
assert not isinstance(ret, tvm.stmt.For) assert not isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, True) ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True)
assert isinstance(ret, tvm.stmt.For) assert isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, False) ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False)
assert isinstance(ret, tvm.stmt.For) assert isinstance(ret, tvm.stmt.For)
assert ret.for_type == tvm.stmt.For.Unrolled assert ret.for_type == tvm.stmt.For.Unrolled
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment