Commit 33910970 by Salem Derisavi Committed by Tianqi Chen

1) Make unroll code reusable 2) reduce non-determinisim in CanonicalSimplify (#701)

* 1) Refactored some parts of the unrolling code into their own methods so we can reuse unrolling functionality in other parts of the code. E.g., to explicitly unroll loops with count of 1 when they are programmatically created.
2) Reorder based on top operator before resorting to pointers, which causes non-determinism.

* Fixed lint errors
parent 154959b1
...@@ -29,6 +29,8 @@ struct ComExprEntry { ...@@ -29,6 +29,8 @@ struct ComExprEntry {
inline bool operator<(const ComExprEntry& other) const { inline bool operator<(const ComExprEntry& other) const {
if (level < other.level) return true; if (level < other.level) return true;
if (level > other.level) return false; if (level > other.level) return false;
if (value.type_index() < other.value.type_index()) return true;
if (value.type_index() > other.value.type_index()) return false;
return value.get() < other.value.get(); return value.get() < other.value.get();
} }
}; };
......
...@@ -30,17 +30,7 @@ class LoopUnroller : public IRMutator { ...@@ -30,17 +30,7 @@ class LoopUnroller : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) { Stmt Mutate_(const For* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>(); op = stmt.as<For>();
// constant folding. int value = GetExtent(op);
Expr extent = ir::Simplify(op->extent);
const IntImm* v1 = extent.as<IntImm>();
const UIntImm* v2 = extent.as<UIntImm>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
// condition for auto unroll // condition for auto unroll
bool auto_unroll = ( bool auto_unroll = (
op->for_type == ForType::Serial && op->for_type == ForType::Serial &&
...@@ -66,24 +56,7 @@ class LoopUnroller : public IRMutator { ...@@ -66,24 +56,7 @@ class LoopUnroller : public IRMutator {
} }
if (auto_unroll && explicit_unroll_) { if (auto_unroll && explicit_unroll_) {
using arith::ComputeExpr; return Unroll(op);
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
vmap.Set(lv,
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
} else {
unrolled = step;
}
}
return unrolled;
} else { } else {
if (auto_unroll) { if (auto_unroll) {
if (op->for_type != ForType::Unrolled) { if (op->for_type != ForType::Unrolled) {
...@@ -128,7 +101,47 @@ class LoopUnroller : public IRMutator { ...@@ -128,7 +101,47 @@ class LoopUnroller : public IRMutator {
} }
} }
Stmt Unroll(const For* op) {
using arith::ComputeExpr;
int value = GetExtent(op);
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
vmap.Set(lv,
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
} else {
unrolled = step;
}
}
return unrolled;
}
private: private:
// returns the extent of the loop if it's a constant integer, otherwise return -1
int GetExtent(const For* op) {
// constant folding.
Expr extent = ir::Simplify(op->extent);
const IntImm *v1 = extent.as<IntImm>();
const UIntImm *v2 = extent.as<UIntImm>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
return value;
}
// maximum number of step to perform auto unroll. // maximum number of step to perform auto unroll.
int auto_max_step_; int auto_max_step_;
int auto_max_depth_; int auto_max_depth_;
...@@ -162,5 +175,13 @@ Stmt UnrollLoop(Stmt stmt, ...@@ -162,5 +175,13 @@ Stmt UnrollLoop(Stmt stmt,
} }
} }
Stmt UnrollLoopExplicitly(Stmt stmt) {
const For* op = stmt.as<For>();
if (!op) {
LOG(FATAL) << "attempted to unroll a non-loop statement";
}
return LoopUnroller(0, 0, 0, false).Unroll(op);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment