Commit b3375702 by Li-Wen Chang Committed by Tianqi Chen

Add functionality to optionally disable Select rewriting (#2385)

parent 5d71e881
......@@ -223,6 +223,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false;
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
......@@ -236,6 +239,7 @@ class BuildConfigNode : public Node {
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
}
static constexpr const char* _type_key = "BuildConfig";
......
......@@ -126,7 +126,8 @@ class BuildConfig(NodeBase):
"restricted_func": True,
"double_buffer_split_loop": 1,
"dump_pass_ir": False,
"instrument_bound_checkers": False
"instrument_bound_checkers": False,
"disable_select_rewriting": False
}
_dump_ir = DumpIR()
......@@ -368,6 +369,7 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
......
......@@ -381,6 +381,8 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::Simplify(stmt);
stmt = ir::LowerStorageAccessInfo(stmt);
stmt = ir::RemoveNoOp(stmt);
if (!(config->disable_select_rewriting))
stmt = ir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers)
......@@ -534,7 +536,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir;
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << ")";
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment