Commit 4b1d3d87 by Logan Weber Committed by Tianqi Chen

Add `SkipVectorize` pass (#3222)

parent 3d1d17e3
...@@ -61,6 +61,7 @@ tvm.ir_pass ...@@ -61,6 +61,7 @@ tvm.ir_pass
tvm.ir_pass.CanonicalSimplify tvm.ir_pass.CanonicalSimplify
tvm.ir_pass.StorageFlatten tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop tvm.ir_pass.VectorizeLoop
tvm.ir_pass.SkipVectorize
tvm.ir_pass.UnrollLoop tvm.ir_pass.UnrollLoop
tvm.ir_pass.ThreadSync tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite tvm.ir_pass.StorageRewrite
......
...@@ -246,6 +246,9 @@ class BuildConfigNode : public Node { ...@@ -246,6 +246,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable select rewriting. */ /*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false; bool disable_select_rewriting = false;
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
...@@ -260,6 +263,7 @@ class BuildConfigNode : public Node { ...@@ -260,6 +263,7 @@ class BuildConfigNode : public Node {
v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting); v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
......
...@@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt, ...@@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt,
/*! /*!
* \brief vectorize the constant loops * \brief vectorize the constant loops
* \param stmt The statment to be vectorized. * \param stmt The statement to be vectorized.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt VectorizeLoop(Stmt stmt); Stmt VectorizeLoop(Stmt stmt);
/*!
 * \brief Convert vectorized loops into serial loops.
 *
 * Run in place of VectorizeLoop when the build config's
 * disable_vectorize flag is set.
 *
 * \param stmt The statement whose vectorized loops should be serialized.
 * \return Transformed stmt.
 */
Stmt SkipVectorize(Stmt stmt);
/*!
* \brief instruments bound checkers. * \brief instruments bound checkers.
* \param stmt The statment to be instrumented. * \param stmt The statement to be instrumented.
* \return Instrumented Stmt. * \return Instrumented stmt.
*/ */
Stmt InstrumentBoundCheckers(Stmt stmt); Stmt InstrumentBoundCheckers(Stmt stmt);
/*! /*!
* \brief Inject virtual thread loops into stmt. * \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed. * \param stmt The statement to be transformed.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt InjectVirtualThread(Stmt stmt); Stmt InjectVirtualThread(Stmt stmt);
/*! /*!
* \brief Inject prefetch instructions into stmt. * \brief Inject prefetch instructions into stmt.
* \param stmt The statment to be transformed. * \param stmt The statement to be transformed.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt InjectPrefetch(Stmt stmt); Stmt InjectPrefetch(Stmt stmt);
/*! /*!
* \brief Inject double buffer into stmt. * \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed. * \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor. * \param split_loop Loop splitting factor.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
...@@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); ...@@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*! /*!
* \brief Inject copy intrinsics with optional pad. * \brief Inject copy intrinsics with optional pad.
* *
* \param stmt The statment to be transformed. * \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy. * \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature * \param fintrin The function with signature
* *
...@@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt, ...@@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt,
* Trying to share space between allocations to make * Trying to share space between allocations to make
* a static allocation plan when possible. * a static allocation plan when possible.
* *
* \param stmt The stmt to be trasnformed * \param stmt The stmt to be transformed
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt StorageRewrite(Stmt stmt); Stmt StorageRewrite(Stmt stmt);
...@@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop); ...@@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*! /*!
* \brief Detect and insert sync points to co-processor. * \brief Detect and insert sync points to co-processor.
* *
* \param stmt The stmt to be trasnformed * \param stmt The stmt to be transformed
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt CoProcSync(Stmt stmt); Stmt CoProcSync(Stmt stmt);
...@@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt); ...@@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt);
/*! /*!
* \brief Lift common attrs with attr_key to outer scope. * \brief Lift common attrs with attr_key to outer scope.
* *
* \param stmt The stmt to be trasnformed * \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked. * \param attr_key The attribute key to be checked.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
...@@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); ...@@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*! /*!
* \brief Detect and rewrite unsafe select that contains memory access. * \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten. * \param stmt The statement to be rewritten.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt RewriteUnsafeSelect(Stmt stmt); Stmt RewriteUnsafeSelect(Stmt stmt);
...@@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt); ...@@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
* \brief Lower attached storage access information. * \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish. * Do this pass after all storage access analysis finish.
* *
* \param stmt The stmt to be trasnformed * \param stmt The stmt to be transformed
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt LowerStorageAccessInfo(Stmt stmt); Stmt LowerStorageAccessInfo(Stmt stmt);
...@@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt); ...@@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
* \brief Decorate the stmt with a device scope, this is helpful for * \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks. * hardware accelerator without thread blocks.
* *
* \param stmt The stmt to be trasnformed * \param stmt The stmt to be transformed
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt DecorateDeviceScope(Stmt stmt); Stmt DecorateDeviceScope(Stmt stmt);
...@@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt); ...@@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt);
* \return a LoweredFunc with the specified signature. * \return a LoweredFunc with the specified signature.
* *
* \note * \note
* The function signiture have two cases * The function signature have two cases
* *
* let num_packed_args = len(api_args) - num_unpacked_args; * let num_packed_args = len(api_args) - num_unpacked_args;
* *
......
...@@ -143,7 +143,8 @@ class BuildConfig(NodeBase): ...@@ -143,7 +143,8 @@ class BuildConfig(NodeBase):
"double_buffer_split_loop": 1, "double_buffer_split_loop": 1,
"dump_pass_ir": False, "dump_pass_ir": False,
"instrument_bound_checkers": False, "instrument_bound_checkers": False,
"disable_select_rewriting": False "disable_select_rewriting": False,
"disable_vectorize": False
} }
_dump_ir = DumpIR() _dump_ir = DumpIR()
...@@ -384,7 +385,10 @@ def lower(sch, ...@@ -384,7 +385,10 @@ def lower(sch,
# Phase 2 # Phase 2
if not simple_mode: if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
stmt = ir_pass.VectorizeLoop(stmt) if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt) stmt = ir_pass.StorageRewrite(stmt)
......
...@@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch, ...@@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch,
if (loop_partition) { if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop); stmt = ir::LoopPartition(stmt, config->partition_const_loop);
} }
stmt = ir::VectorizeLoop(stmt); if (config->disable_vectorize) {
stmt = ir::SkipVectorize(stmt);
} else {
stmt = ir::VectorizeLoop(stmt);
}
stmt = ir::InjectVirtualThread(stmt); stmt = ir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt); stmt = ir::StorageRewrite(stmt);
...@@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
// The preceding field is printed without a trailing ", ", so emit the
// separator here to keep the two fields from running together.
p->stream << ", disable_vectorize=" << op->disable_vectorize;
p->stream << ")"; p->stream << ")";
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) { ...@@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt); return LoopVectorizer().Mutate(stmt);
} }
/*!
 * \brief Mutator that turns every For loop marked ForType::Vectorized
 *  into an otherwise identical ForType::Serial loop.
 */
class VectorizeSkipper : public IRMutator {
 public:
  Stmt Mutate_(const For* op, const Stmt& s) final {
    // Let the base mutator handle the loop first, then inspect its result.
    Stmt mutated = IRMutator::Mutate_(op, s);
    const For* loop = mutated.as<For>();
    // Anything that is not a vectorized loop passes through untouched.
    if (loop->for_type != ForType::Vectorized) {
      return mutated;
    }
    // Rebuild the loop with the same fields, changing only the loop type.
    return For::make(loop->loop_var, loop->min, loop->extent, ForType::Serial,
                     loop->device_api, loop->body);
  }
};
/*!
 * \brief Apply VectorizeSkipper to the given statement.
 * \param stmt The statement to transform.
 * \return The statement produced by the mutator.
 */
Stmt SkipVectorize(Stmt stmt) {
  VectorizeSkipper skipper;
  return skipper.Mutate(stmt);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment