Commit b3375702 by Li-Wen Chang Committed by Tianqi Chen

Add functionality to optionally disable Select rewriting (#2385)

parent 5d71e881
...@@ -223,6 +223,9 @@ class BuildConfigNode : public Node { ...@@ -223,6 +223,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to instrument loads and stores with check for out of the bounds. */ /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false; bool instrument_bound_checkers = false;
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
...@@ -236,6 +239,7 @@ class BuildConfigNode : public Node { ...@@ -236,6 +239,7 @@ class BuildConfigNode : public Node {
v->Visit("partition_const_loop", &partition_const_loop); v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
......
...@@ -126,7 +126,8 @@ class BuildConfig(NodeBase): ...@@ -126,7 +126,8 @@ class BuildConfig(NodeBase):
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": 1, "double_buffer_split_loop": 1,
"dump_pass_ir": False, "dump_pass_ir": False,
"instrument_bound_checkers": False "instrument_bound_checkers": False,
"disable_select_rewriting": False
} }
_dump_ir = DumpIR() _dump_ir = DumpIR()
...@@ -368,7 +369,8 @@ def lower(sch, ...@@ -368,7 +369,8 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt) if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3: for f in lower_phase3:
stmt = f(stmt) stmt = f(stmt)
# Instrument BoundCheckers # Instrument BoundCheckers
......
...@@ -381,7 +381,9 @@ Stmt BuildStmt(Schedule sch, ...@@ -381,7 +381,9 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::Simplify(stmt); stmt = ir::Simplify(stmt);
stmt = ir::LowerStorageAccessInfo(stmt); stmt = ir::LowerStorageAccessInfo(stmt);
stmt = ir::RemoveNoOp(stmt); stmt = ir::RemoveNoOp(stmt);
stmt = ir::RewriteUnsafeSelect(stmt);
if (!(config->disable_select_rewriting))
stmt = ir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers) if (config->instrument_bound_checkers)
stmt = ir::InstrumentBoundCheckers(stmt); stmt = ir::InstrumentBoundCheckers(stmt);
...@@ -534,7 +536,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -534,7 +536,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "restricted_func=" << op->restricted_func << ", "; p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir; p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << ")"; p->stream << ")";
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment