Commit 4b1d3d87 by Logan Weber Committed by Tianqi Chen

Add `SkipVectorize` pass (#3222)

parent 3d1d17e3
......@@ -61,6 +61,7 @@ tvm.ir_pass
tvm.ir_pass.CanonicalSimplify
tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.SkipVectorize
tvm.ir_pass.UnrollLoop
tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite
......
......@@ -246,6 +246,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
......@@ -260,6 +263,7 @@ class BuildConfigNode : public Node {
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
}
static constexpr const char* _type_key = "BuildConfig";
......
......@@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt,
/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief convert vectorized loops into serialized loops
* \param stmt The statement to skip vectorization on.
* \return Transformed stmt.
*/
Stmt SkipVectorize(Stmt stmt);
/*!
* \brief instruments bound checkers.
* \param stmt The statment to be instrumented.
* \return Instrumented Stmt.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
......@@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
......@@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt,
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);
......@@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);
......@@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt);
/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
......@@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);
......@@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
......@@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt DecorateDeviceScope(Stmt stmt);
......@@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt);
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
* The function signature have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
......
......@@ -143,7 +143,8 @@ class BuildConfig(NodeBase):
"double_buffer_split_loop": 1,
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False
"disable_select_rewriting": False,
"disable_vectorize": False
}
_dump_ir = DumpIR()
......@@ -384,7 +385,10 @@ def lower(sch,
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
stmt = ir_pass.VectorizeLoop(stmt)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
......
......@@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch,
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
}
stmt = ir::VectorizeLoop(stmt);
if (config->disable_vectorize) {
stmt = ir::SkipVectorize(stmt);
} else {
stmt = ir::VectorizeLoop(stmt);
}
stmt = ir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt);
......@@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << ")";
});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt);
}
class VectorizeSkipper : public IRMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>();
if (op->for_type == ForType::Vectorized) {
return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
op->body);
} else {
return stmt;
}
}
};
Stmt SkipVectorize(Stmt stmt) {
return VectorizeSkipper().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment