Unverified Commit 32648950 by Tianqi Chen Committed by GitHub

[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager. (#5364)

- Migrate BoundCheckers and Simplify
- Migrate RewriteUnsafeSelect and RemoveNoOp
- Migrate UnrollLoop and StorageRewrite
- Migrate InjectDoubleBuffer and InjectVirtualThread
- Migrate LoopPartition and Vectorize
- Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin

We still keep the ir_pass registrations for now.
A separate PR is needed to refactor the parts before StorageFlatten.
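
For reference, a minimal sketch of the invocation change, using RemoveNoOp as the example (the same wrap-in-a-PrimFunc pattern applies to every migrated pass, and the legacy entry points keep working during the transition):

import tvm

# Legacy style: apply the stmt-level entry point directly to a Stmt.
stmt = tvm.tir.Evaluate(0)
ret_old = tvm.tir.ir_pass.RemoveNoOp(stmt)

# Unified-IR style: wrap the Stmt in a PrimFunc/IRModule, run the pass,
# then read the transformed body back out of the module.
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
ret_new = tvm.tir.transform.RemoveNoOp()(mod)["main"].body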
parent fbcf61ab
......@@ -53,7 +53,6 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
/*!
* \brief Find undefined vars in the statement.
* \param stmt The statement to be checked.
......
......@@ -203,59 +203,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
bool VerifyCompactBuffer(Stmt stmt);
/*!
* \brief Remove no-op statements from the Stmt.
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt RemoveNoOp(Stmt stmt);
/*!
* \brief Unroll the constant loops marked by unroll.
* This pass also automatically attaches a pragma unroll tag to loops that meet the standard.
*
* \param stmt The statement to be unrolled.
* \param auto_max_step The maximum step count before automatic unrolling stops.
* \param auto_max_depth The maximum depth before automatic unrolling stops.
* \param auto_max_extent The maximum extent of the loop we can unroll;
* this is a legacy option that does not take the total loop steps into account.
* \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);
/*!
* \brief Vectorize the constant loops.
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief Convert vectorized loops into serial loops.
* \param stmt The statement to skip vectorization on.
* \return Transformed stmt.
*/
Stmt SkipVectorize(Stmt stmt);
/*!
* \brief Instruments bound checkers.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
......@@ -263,84 +210,6 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return Transformed stmt.
*/
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const runtime::PackedFunc& fintrin);
/*!
* \brief Rewrite storage allocation pattern.
* Moves allocations to the outermost possible scope,
* trying to share space between allocations to form
* a static allocation plan when possible.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);
/*!
* \brief Partition loops in the stmt.
* \param stmt The stmt to do loop partition
* \param split_const_loop Flag to enable partitioning of const loops.
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);
/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*!
* \brief Detect and rewrite unsafe select expressions that contain memory accesses.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);
/*!
* \brief Lower attached storage access information.
* Do this pass after all storage access analyses finish.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
/*!
* \brief Decorate the stmt with a device scope; this is helpful for
* hardware accelerators without thread blocks.
*
......@@ -357,15 +226,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt HoistIfThenElse(Stmt stmt);
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
......
......@@ -59,6 +59,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<runtime::String>& required);
/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return The pass.
*/
TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
runtime::PackedFunc fintrin);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \return The pass.
*/
TVM_DLL Pass CoProcSync();
/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param attr_key The attribute key to be checked.
* \return The pass.
*/
TVM_DLL Pass LiftAttrScope(std::string attr_key);
/*!
* \brief Partition loops in the stmt.
*
* \param split_const_loop Flag to enable partitioning of const loops.
*
* \return The pass.
*/
TVM_DLL Pass LoopPartition(bool split_const_loop);
/*!
* \brief Lower vectorization loops.
*
* \param enable_vectorize Whether vectorization is enabled.
*
* \return The pass.
*/
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
/*!
* \brief Inject virtual thread loops.
*
* \return The pass.
*/
TVM_DLL Pass InjectVirtualThread();
/*!
* \brief Inject double buffer statements.
*
* \param split_loop_factor Loop splitting factor.
* \return The pass.
*/
TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
/*!
* \brief Rewrite storage allocation pattern.
* Moves allocations to the outermost possible scope,
* trying to share space between allocations to form
* a static allocation plan when possible.
*
* \return The pass.
*/
TVM_DLL Pass StorageRewrite();
/*!
* \brief Unroll the constant loops marked by unroll.
* This pass also automatically attaches a pragma unroll tag to loops that meet the standard.
*
* \param auto_max_step The maximum step count before automatic unrolling stops.
* \param auto_max_depth The maximum depth before automatic unrolling stops.
* \param auto_max_extent The maximum extent of the loop we can unroll;
* this is a legacy option that does not take the total loop steps into account.
* \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
* \return The pass.
*/
TVM_DLL Pass UnrollLoop(int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);
/*!
* \brief Remove no-op statements from the Stmt.
*
* \return The pass.
*/
TVM_DLL Pass RemoveNoOp();
/*!
* \brief Detect and rewrite unsafe select expressions that contain memory accesses.
*
* \return The pass.
*/
TVM_DLL Pass RewriteUnsafeSelect();
/*!
* \brief Run arithmetic simplifications on the statements and expressions.
*
* \return The pass.
*/
TVM_DLL Pass Simplify();
/*!
* \brief Instruments bound checkers.
*
* \return The pass.
*/
TVM_DLL Pass InstrumentBoundCheckers();
/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
*
......
......@@ -179,6 +179,7 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
......@@ -187,11 +188,14 @@ def lower(sch,
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
......
......@@ -60,6 +60,203 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
def InjectCopyIntrin(pragma_key, fintrin):
"""Inject virtual thread loops.
Parameters
----------
pragma_key : str
The pragma key for hint of copy.
fintrin : function
The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectCopyIntrin(pragma_key, fintrin)
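
A usage sketch, adapted from the copy-intrinsic tests updated in this PR; the "memcpy" pragma key and the trivial callback are illustrative, and a real callback would emit the target's copy intrinsic:

import tvm
from tvm import te

m, l = te.var("m"), te.var("l")
A = te.placeholder((m, l), name="A")
B = te.compute((m, l), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# Mark the copy stage so the injector can match it by pragma key.
s[B].pragma(B.op.axis[0], "memcpy")

bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="B")
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)

def cb(src, dst, pad_before, pad_after, pad_value):
    # Receives the matched src/dst buffers; return the lowered copy stmt.
    return tvm.tir.Evaluate(0)

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body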
def CoProcSync():
"""Detect and insert sync points to co-processor.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CoProcSync()
def LiftAttrScope(attr_key):
"""Lift common attrs with attr_key to outer scope.
Parameters
----------
attr_key : str
The attribute key to be checked.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LiftAttrScope(attr_key)
def LoopPartition(split_const_loop):
"""Inject virtual thread loops.
Parameters
----------
split_const_loop : bool
Flag to enable partition for const loop.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LoopPartition(split_const_loop)
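
A usage sketch, following the const-loop-partition tests in this PR: splitting a loop of extent 21 by a factor of 4 introduces likely-conditions that the pass partitions away when split_const_loop is True.

import tvm
from tvm import te

n = 21
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
T = te.compute((n,), lambda i: A[i] + B[i], name="T")
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)

s = s.normalize()
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
mod = tvm.tir.transform.LoopPartition(True)(mod)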
def VectorizeLoop(enable_vectorize=True):
"""Lower vectorization loops.
Parameters
----------
enable_vectorize : bool
Whether vectorization is enabled.
Will lower to a scalar loop when turned off.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.VectorizeLoop(enable_vectorize)
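
A usage sketch mirroring test_vectorize_loop as updated in this PR: the inner for_type="vectorize" loop becomes a single Ramp/Broadcast store, while constructing the pass with enable_vectorize=False would instead rewrite it to a serial loop.

import tvm
from tvm import te

n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as i:
    with ib.for_range(0, 4, for_type="vectorize") as j:
        A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt.body.index, tvm.tir.Ramp)
assert isinstance(stmt.body.value, tvm.tir.Broadcast)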
def InjectVirtualThread():
"""Inject virtual thread loops.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectVirtualThread()
def InjectDoubleBuffer(split_loop_factor):
"""Inject double buffer statements.
Parameters
----------
split_loop_factor : int
Loop splitting factor.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectDoubleBuffer(split_loop_factor)
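
A self-contained sketch whose setup is recalled from the full test_double_buffer (only partially visible in this diff); the "double_buffer_scope" attribute on B's allocation is what the injector keys on, and split_loop_factor=2 splits the producer loop accordingly:

import tvm
from tvm import te

ib = tvm.tir.ir_builder.create()
tx = te.thread_axis("threadIdx.x")
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
ib.scope_attr(tx, "thread_extent", 1)
with ib.for_range(0, 100) as i:
    B = ib.allocate("float32", 4, name="B", scope="shared")
    with ib.new_scope():
        # This attribute marks B for double buffering.
        ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
        with ib.for_range(0, 4) as j:
            B[j] = A[i * 4 + j]
    with ib.for_range(0, 4) as j:
        C[j] = B[j] + 1
stmt = ib.get()

mod = tvm.IRModule({"db": tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)})
mod = tvm.tir.transform.InjectDoubleBuffer(2)(mod)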
def StorageRewrite():
"""Rewrite storage allocation pattern.
Moves allocations to the outermost possible scope,
trying to share space between allocations to form
a static allocation plan when possible.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.StorageRewrite()
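
A usage sketch following test_alloc_seq from this PR: two allocations that live in disjoint loop nests within the same storage scope are merged into a single allocation.

import tvm
from tvm import te

ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
    with ib.for_range(0, 10, name="j") as j:
        A = ib.allocate("float32", 200, name="A", scope="local.L0A")
        A[j] = 1.2
    with ib.for_range(0, 10, name="j") as j:
        A = ib.allocate("float32", 200, name="B", scope="local.L0A")
        A[j] = 1.3
body = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body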
def UnrollLoop(auto_max_step,
auto_max_depth,
auto_max_extent,
explicit_unroll):
"""Unroll the constant loop marked by unroll.
This pass also automatically attach pragma unroll tag to loops which meets the standard.
Parameters
----------
auto_max_step : int
The maximum step count before automatic unrolling stops.
auto_max_depth : int
The maximum depth before automatic unrolling stops.
auto_max_extent : int
The maximum extent of the loop we can unroll.
This is a legacy option that does not take the total loop steps into account.
explicit_unroll : bool
Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.UnrollLoop(
auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
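
A usage sketch adapted from test_unroll_loop in this PR: the inner constant-extent loop carries an explicit "unroll" mark; with explicit_unroll=False the pass keeps the loop but leaves the unrolled annotation for codegen.

import tvm
from tvm import te

ib = tvm.tir.ir_builder.create()
n = te.var("n")
Ab = tvm.tir.decl_buffer((n,), "int64")
Aptr = ib.buffer_ptr(Ab)
with ib.for_range(n, n + 2, name="i") as i:
    with ib.for_range(0, 8, name="j", for_type="unroll") as j:
        Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body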
def RemoveNoOp():
"""Remove No Op from the Stmt.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RemoveNoOp()
def RewriteUnsafeSelect():
"""Detect and rewrite unsafe select that contains memory access.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RewriteUnsafeSelect()
def Simplify():
"""Run arithmetic simplifications on the statements and expressions.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.Simplify()
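
A usage sketch reconstructing test_stmt_simplify as updated by this PR (the loop construction is recalled from that test, not shown in full here): after substituting the let-bound n = 10, the i < 12 guard is provably true and gets folded away, leaving a bare Store in the loop body.

import tvm
from tvm import te

ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
    with ib.if_scope(i < 12):
        A[i] = C[i]
body = tvm.tir.LetStmt(n, 10, ib.get())

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body))
body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body, tvm.tir.Store)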
def InstrumentBoundCheckers():
"""Instruments bound checkers.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InstrumentBoundCheckers()
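
A usage sketch following the updated bound-checker tests in this PR: the pass is normally enabled through the build config and picked up by lower(), rather than invoked directly.

import tvm
from tvm import te

n = 21
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1, name="B")
s = te.create_schedule(B.op)

with tvm.target.build_config(instrument_bound_checkers=True):
    mod = tvm.driver.lower(s, [A, B], name="main")
stmt = mod["main"].body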
def LowerCustomDatatypes():
"""Lower custom datatypes.
......
......@@ -25,6 +25,7 @@
#define TVM_ARITH_COMPUTE_EXPR_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <limits>
#include <algorithm>
......
......@@ -109,64 +109,6 @@ void GetBinds(const Array<te::Tensor>& args,
}
}
/*!
* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes.
* \param sch The schedule to build.
* \param args The arguments for the schedule.
* \param binds Buffer assignments.
* \param loop_partition True if the LoopPartition pass should be included.
* \param out_arg_list Returns the arguments for the Stmt.
* \param config The build configuration.
* \return The built Stmt.
*/
tir::Stmt BuildStmt(te::Schedule sch,
const Array<te::Tensor>& args,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool loop_partition,
Array<ObjectRef> *out_arg_list,
const BuildConfig& config) {
sch = sch.normalize();
// Phase 0
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);
stmt = tir::InjectPrefetch(stmt);
bool compact = tir::VerifyCompactBuffer(stmt);
Map<te::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
// Phase 1
stmt = tir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
stmt = tir::CanonicalSimplify(stmt);
if (loop_partition) {
stmt = tir::LoopPartition(stmt, config->partition_const_loop);
}
if (config->disable_vectorize) {
stmt = tir::SkipVectorize(stmt);
} else {
stmt = tir::VectorizeLoop(stmt);
}
stmt = tir::InjectVirtualThread(stmt);
stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = tir::StorageRewrite(stmt);
stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config->auto_unroll_max_extent, config->unroll_explicit);
// Phase 2
stmt = tir::Simplify(stmt);
stmt = tir::RemoveNoOp(stmt);
if (!(config->disable_select_rewriting))
stmt = tir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers)
stmt = tir::InstrumentBoundCheckers(stmt);
return stmt;
}
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
......@@ -176,7 +118,7 @@ transform::Pass BindTarget(Target target) {
template<typename FCond>
transform::Pass FilterBy(FCond fcond) {
transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
......@@ -184,18 +126,14 @@ transform::Pass FilterBy(FCond fcond) {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
tir::Stmt stmt,
const std::string& name,
const BuildConfig& config) {
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
......@@ -216,10 +154,64 @@ IRModule lower(te::Schedule sch,
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}
IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
sch = sch.normalize();
// Phase 0
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);
stmt = tir::InjectPrefetch(stmt);
bool compact = tir::VerifyCompactBuffer(stmt);
Map<te::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
// Phase 1
stmt = tir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
// convert to IRModule.
auto mod = BuildIRModule(out_arg_list, stmt, name, config);
auto pass_list = Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(
tir::transform::UnrollLoop(config->auto_unroll_max_step,
config->auto_unroll_max_depth,
config->auto_unroll_max_extent,
config->unroll_explicit));
// Phase 2
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
if (!(config->disable_select_rewriting)) {
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
}
if (config->instrument_bound_checkers) {
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
}
// run
auto optimize = transform::Sequential(pass_list);
mod = optimize(std::move(mod));
return mod;
}
std::pair<IRModule, IRModule>
split_dev_host_funcs(IRModule mod_mixed,
const Target& target,
......@@ -242,7 +234,7 @@ split_dev_host_funcs(IRModule mod_mixed,
mod_mixed = opt_mixed(std::move(mod_mixed));
auto host_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
......@@ -258,7 +250,7 @@ split_dev_host_funcs(IRModule mod_mixed,
// device pipeline
auto device_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
......
......@@ -114,27 +114,12 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(RewriteUnsafeSelect);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
......@@ -22,8 +22,11 @@
*/
// Instrument checkers for out of the bounds access.
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
#include <unordered_map>
......@@ -173,8 +176,8 @@ class BoundChecker : public StmtExprMutator {
}
// Try to simplify index and bound.
index = tir::Simplify(index);
upper_bound = tir::Simplify(upper_bound);
index = analyzer_.Simplify(index);
upper_bound = analyzer_.Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
index = CastNode::make(DataType::Int(64), index);
......@@ -201,6 +204,8 @@ class BoundChecker : public StmtExprMutator {
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
// internal analyzer
arith::Analyzer analyzer_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
......@@ -209,5 +214,29 @@ Stmt InstrumentBoundCheckers(Stmt stmt) {
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers")
.set_body_typed(InstrumentBoundCheckers);
namespace transform {
Pass InstrumentBoundCheckers() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
BoundCollector bound_collector;
// At first walk recursively and collect bound attributes.
bound_collector(n->body);
n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
.set_body_typed(InstrumentBoundCheckers);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -20,13 +20,14 @@
/*!
* \file coproc_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
#include "../pass/ir_util.h"
#include "../pass/storage_access.h"
namespace tvm {
namespace tir {
......@@ -677,5 +678,24 @@ Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.CoProcSync")
.set_body_typed(CoProcSync);
namespace transform {
Pass CoProcSync() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = CoProcSyncInserter().Insert(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
}
TVM_REGISTER_GLOBAL("tir.transform.CoProcSync")
.set_body_typed(CoProcSync);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -21,10 +21,11 @@
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include "../../arith/pattern_match.h"
namespace tvm {
......@@ -196,5 +197,26 @@ Stmt InjectCopyIntrin(Stmt stmt,
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin")
.set_body_typed(InjectCopyIntrin);
namespace transform {
Pass InjectCopyIntrin(std::string pragma_key,
PackedFunc flower_copy_fromto) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = CopyIntrinInjector(
pragma_key, flower_copy_fromto)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin")
.set_body_typed(InjectCopyIntrin);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -21,10 +21,12 @@
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm {
......@@ -273,5 +275,26 @@ class DoubleBufferInjector : public StmtExprMutator {
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer")
.set_body_typed(InjectDoubleBuffer);
namespace transform {
Pass InjectDoubleBuffer(int split_loop) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer")
.set_body_typed(InjectDoubleBuffer);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -20,8 +20,10 @@
/*!
* \file inject_virtual_thread.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "../../arith/compute_expr.h"
......@@ -500,5 +502,24 @@ Stmt InjectVirtualThread(Stmt stmt) {
return ConvertSSA(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread")
.set_body_typed(InjectVirtualThread);
namespace transform {
Pass InjectVirtualThread() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body)));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread")
.set_body_typed(InjectVirtualThread);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -23,9 +23,10 @@
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include "ir_util.h"
#include "../pass/ir_util.h"
namespace tvm {
namespace tir {
......@@ -191,5 +192,24 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope")
.set_body_typed(LiftAttrScope);
namespace transform {
Pass LiftAttrScope(std::string attr_key) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope")
.set_body_typed(LiftAttrScope);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -20,9 +20,11 @@
/*!
* \file loop_partition.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
#include <unordered_map>
#include <unordered_set>
......@@ -500,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = tir::Simplify(middle_interval.min());
body_begin = analyzer_.Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
......@@ -525,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = tir::Simplify(middle_interval.max() + 1);
post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
......@@ -588,7 +590,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
for_node->for_type, for_node->device_api, body);
}
}
......@@ -610,5 +612,25 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
return stmt;
}
TVM_REGISTER_GLOBAL("ir_pass.LoopPartition")
.set_body_typed(LoopPartition);
namespace transform {
Pass LoopPartition(bool split_const_loop) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = LoopPartition(std::move(n->body), split_const_loop);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LoopPartition")
.set_body_typed(LoopPartition);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -143,6 +143,8 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo")
.set_body_typed(LowerStorageAccessInfo);
namespace transform {
......
......@@ -395,6 +395,10 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
return DataTypeRewriter(target_bits)(stmt);
}
TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType")
.set_body_typed(NarrowDataType);
namespace transform {
Pass NarrowDataType(int target_bits) {
......
......@@ -21,8 +21,11 @@
* \file remove_no_op.cc
* \brief Remove no-op statements from the stmt
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
......@@ -147,5 +150,25 @@ class NoOpRemover : public StmtMutator {
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover()(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp")
.set_body_typed(RemoveNoOp);
namespace transform {
Pass RemoveNoOp() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = NoOpRemover()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
}
TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp")
.set_body_typed(RemoveNoOp);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -21,9 +21,10 @@
* \file unsafe_select_rewrite.cc
* \brief Rewrite unsafe select expressions.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
......@@ -132,5 +133,24 @@ Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter()(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect")
.set_body_typed(RewriteUnsafeSelect);
namespace transform {
Pass RewriteUnsafeSelect() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = UnsafeSelectRewriter()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
}
TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect")
.set_body_typed(RewriteUnsafeSelect);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -18,17 +18,19 @@
*/
/*!
* \file stmt_simplify.cc
* \file simplify.cc
* \brief Statement simplifier based on analyzer
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
#include "ir_mutator_with_analyzer.h"
#include "../../arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
......@@ -125,5 +127,23 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(std::move(stmt), vrange);
}
namespace transform {
Pass Simplify() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
arith::Analyzer analyzer;
n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
}
TVM_REGISTER_GLOBAL("tir.transform.Simplify")
.set_body_typed(Simplify);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -22,16 +22,18 @@
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
......@@ -1039,5 +1041,26 @@ Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite")
.set_body_typed(StorageRewrite);
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
n->body = VectorAllocRewriter()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite")
.set_body_typed(StorageRewrite);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -22,8 +22,11 @@
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
#include <unordered_map>
......@@ -201,13 +204,31 @@ Stmt UnrollLoop(Stmt stmt,
}
}
Stmt UnrollLoopExplicitly(Stmt stmt) {
const ForNode* op = stmt.as<ForNode>();
if (!op) {
LOG(FATAL) << "attempted to unroll a non-loop statement";
}
return LoopUnroller(0, 0, 0, false).Unroll(op);
}
TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop")
.set_body_typed(UnrollLoop);
namespace transform {
Pass UnrollLoop(int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = UnrollLoop(std::move(n->body),
auto_max_step,
auto_max_depth,
auto_max_extent,
explicit_unroll);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
}
TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop")
.set_body_typed(UnrollLoop);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -21,9 +21,11 @@
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
#include <unordered_map>
......@@ -539,8 +541,9 @@ class VectorizeSkipper : public StmtMutator {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
op->body);
return ForNode::make(op->loop_var, op->min, op->extent,
ForType::Serial, op->device_api,
op->body);
} else {
return stmt;
}
......@@ -551,5 +554,32 @@ Stmt SkipVectorize(Stmt stmt) {
return VectorizeSkipper()(std::move(stmt));
}
TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop")
.set_body_typed(VectorizeLoop);
TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize")
.set_body_typed(SkipVectorize);
namespace transform {
// TODO(tvm-team): Make it a target property.
Pass VectorizeLoop(bool enable_vectorize) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = LoopVectorizer()(std::move(n->body));
} else {
n->body = VectorizeSkipper()(std::move(n->body));
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
}
TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop")
.set_body_typed(VectorizeLoop);
} // namespace transform
} // namespace tir
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_virtual_thread():
m = te.var('m')
A = te.placeholder((m, ), name='A')
A1 = te.compute((m,), lambda i: A[i], name='A1')
A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2')
s = te.create_schedule(A2.op)
vx = te.thread_axis("vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
s[A2].bind(xo, vx)
xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
print(stmt)
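
# The function above still exercises the legacy ir_pass entry points. A
# hypothetical helper (not part of this commit) showing the migrated
# equivalent of the InjectVirtualThread call, following the pattern used
# in the other virtual-thread tests updated by this PR:
def migrated_inject_vthread(Ab, A2b, stmt):
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, A2b], stmt))
    return tvm.tir.transform.InjectVirtualThread()(mod)["main"].body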
if __name__ == "__main__":
test_virtual_thread()
......@@ -37,7 +37,10 @@ def test_coproc_sync():
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j + k * 10] + 2
stmt = ib.get()
stmt = tvm.tir.ir_pass.CoProcSync(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
body = stmt.body.body.body
blist = tvm.tir.stmt_list(body)
assert(blist[1].value.name == "cop.coproc_read_barrier")
......@@ -65,7 +68,10 @@ def test_coproc_sync2():
ib.scope_attr(cp, "coproc_scope", 2)
A[ty] = 1.0
stmt = ib.get()
stmt = tvm.tir.ir_pass.CoProcSync(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
def test_coproc_sync3():
def __check_list(tvm_array, py_list):
......@@ -91,7 +97,10 @@ def test_coproc_sync3():
A[0] = 0.0
stmt = ib.get()
stmt = tvm.tir.ir_pass.CoProcSync(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
slist = tvm.tir.stmt_list(stmt[0].body.body)
push_st = slist[2]
slist = tvm.tir.stmt_list(slist[-1])
......
......@@ -35,7 +35,10 @@ def test_copy2d():
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0)
stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
def test_copy_pad():
m = te.var('m')
......@@ -59,7 +62,10 @@ def test_copy_pad():
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.tir.Evaluate(0)
stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
def test_single_point_test():
A = te.placeholder((1,), name='A')
......@@ -78,7 +84,10 @@ def test_single_point_test():
assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0)
stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
def assert_expr_equal(a, b):
assert tvm.tir.ir_pass.Simplify(a - b).value == 0
......@@ -111,7 +120,11 @@ def test_copy_pad_split():
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
if __name__ == "__main__":
......
......@@ -36,13 +36,19 @@ def test_double_buffer():
C[j] = B[j] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
mod = tvm.IRModule({
"db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
})
opt = tvm.transform.Sequential(
[tvm.tir.transform.InjectDoubleBuffer(2),
tvm.tir.transform.Simplify()])
mod = opt(mod)
stmt = mod["db"].body
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
......
......@@ -40,9 +40,14 @@ def test_vthread():
C[i * nthread + tx] = B[i] + 1
return ib.get()
stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body
assert stmt.body.body.extents[0].value == 2
stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread"))
stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
assert len(stmt.body.body.extents) == 3
......@@ -67,16 +72,20 @@ def test_vthread_extern():
A[tx] = tx + 1.0
B[ty] = ty + 1.0
ib.emit(tvm.tir.call_extern("int32", "Run",
abuffer.access_ptr("r"),
bbuffer.access_ptr("r"),
cbuffer.access_ptr("rw")))
abuffer.access_ptr("r"),
bbuffer.access_ptr("r"),
cbuffer.access_ptr("rw")))
return ib.get()
stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
assert stmt.body.body.extents[0].value == 2
assert stmt.body.body.body.body.body.body.extents[0].value == 2
assert len(stmt.body.body.body.body.body.body.extents) == 3
def test_vthread_if_then_else():
nthread = 2
tx = te.thread_axis("vthread")
......@@ -92,7 +101,10 @@ def test_vthread_if_then_else():
with ib.if_scope(i == 0):
B[i] = A[i * nthread + tx] + 2
stmt = ib.get()
stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
tvm.tir.PrimFunc([], stmt)))["main"].body
assert stmt.body.body.body[0].else_case != None
assert stmt.body.body.body[1].else_case == None
......
......@@ -18,32 +18,12 @@ import pytest
import tvm
from tvm import te
import numpy as np
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, te.tensor.Tensor):
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
stmt = tvm.tir.ir_pass.RemoveNoOp(stmt)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
return stmt
@pytest.mark.xfail
def test_out_of_bounds_llvm(index_a, index_b):
......@@ -72,7 +52,6 @@ def test_in_bounds_llvm():
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [A, B, C], simple_mode=True)
print (stmt)
fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
ctx = tvm.context(tgt, 0)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
......@@ -93,7 +72,6 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [a, b, c], simple_mode=True)
print (stmt)
f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
ctx = tvm.cpu(0)
n = nn
......@@ -192,13 +170,11 @@ def test_in_bounds_const_loop_partition_ir():
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
bounds = tvm.te.schedule.InferBound(s)
stmt = lower (s, [A, B, T])
# num_attributes = num_buffers * num_splits = 2 * 3
# before instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 0)
stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt)
with tvm.target.build_config(instrument_bound_checkers=True,
partition_const_loop=True):
mod = tvm.driver.lower(s, [A, B, T], name="main")
stmt = mod["main"].body
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
......@@ -209,7 +185,8 @@ def test_in_bounds_const_loop_partition_ir():
def test_in_bounds_const_loop_partition_llvm():
with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
with tvm.target.build_config(instrument_bound_checkers=True,
partition_const_loop=True):
n = 21
A = te.placeholder((n, ), name='A')
B = te.placeholder((n, ), name='B')
......
......@@ -35,7 +35,10 @@ def test_coproc_lift():
A[j] = A[j] + 3
A[j] = A[j] + 3
body = ib.get()
body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
assert body.body.body.node == cp
# only able to lift to the common pattern of the last two fors.
......@@ -52,7 +55,10 @@ def test_coproc_lift():
A[i] = A[i] + 2
body = ib.get()
body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
assert body.body.body.body[1].node == cp
assert len(body.body.body.body) == 2
......
......@@ -36,16 +36,24 @@ def test_remove_no_op():
k, 0, m, 0, 0,
tvm.tir.IfThenElse(
(i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
ret = tvm.tir.ir_pass.RemoveNoOp(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert(isinstance(ret, tvm.tir.Evaluate))
store = tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1)
stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
assert(tvm.tir.ir_pass.RemoveNoOp(stmt2) == store)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2))
ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert(ret == store)
# remove zero extent loop
stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
ret = tvm.tir.ir_pass.RemoveNoOp(stmt3)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert(isinstance(ret, tvm.tir.Evaluate))
......
......@@ -23,14 +23,22 @@ def test_rewrite_Select():
A = ib.allocate("float32", 100, name="A", scope="global")
i = te.var("i")
y = tvm.tir.Select(i > 1, A[i-1], 1.0)
yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
z = tvm.tir.Select(
tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z)
a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z)
aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
assert isinstance(aa, tvm.tir.Select)
......
......@@ -27,7 +27,9 @@ def test_stmt_simplify():
A[i] = C[i]
body = tvm.tir.LetStmt(n, 10, ib.get())
body = tvm.tir.ir_pass.CanonicalSimplify(body)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C, n], body))
body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body, tvm.tir.Store)
......@@ -44,7 +46,9 @@ def test_thread_extent_simplify():
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.tir.LetStmt(n, 10, ib.get())
body = tvm.tir.ir_pass.CanonicalSimplify(body)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C, n], body))
body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body.body.body, tvm.tir.Store)
......
......@@ -33,9 +33,12 @@ def test_storage_share():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
......@@ -72,7 +75,10 @@ def test_alloc_seq():
A[j] = 1.3
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
......@@ -129,7 +135,10 @@ def test_alloc_different_dtypes():
body = stmt_generater(dtype_list, length)
offset = offset_generater(dtype_list, length)
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
tvm.tir.ir_pass.PostOrderVisit(body, verify)
length = 1024
......@@ -160,9 +169,12 @@ def test_inplace_rule():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
......@@ -192,9 +204,12 @@ def test_storage_combine():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
......@@ -226,9 +241,12 @@ def test_storage_share_gpu():
Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
......@@ -248,7 +266,9 @@ def test_parallel_alloc():
A[j] = A[j] + 2
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
assert (isinstance(body.body.body, tvm.tir.Allocate))
ib = tvm.tir.ir_builder.create()
......@@ -262,7 +282,9 @@ def test_parallel_alloc():
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
......@@ -289,9 +311,12 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
......@@ -381,10 +406,13 @@ def test_inplace_rule3():
B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
# verify only have one allocations.
# verify inplace folding works
def verify(n):
......@@ -411,7 +439,10 @@ def test_alloc_seq_type():
A2[j] = A[j]
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
......@@ -440,7 +471,10 @@ def test_alloc_seq_type2():
C[j] = 1.2
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
......@@ -469,7 +503,9 @@ def test_reuse_small_buffer():
E[j] = C[j]
body = ib.get()
body = tvm.tir.ir_pass.StorageRewrite(body)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
......@@ -519,14 +555,15 @@ def test_large_input():
if __name__ == "__main__":
test_storage_share()
test_alloc_seq()
test_alloc_different_dtypes()
test_inplace_rule()
test_storage_share()
test_parallel_alloc()
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
test_exceed_mem()
test_inplace_rule3()
test_alloc_seq_type()
......
......@@ -46,7 +46,11 @@ def test_unroll_loop():
wrapped = ib.get()
wrapped = tvm.tir.SeqStmt([wrapped, stmt])
assert isinstance(ret, tvm.tir.For)
ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body
# ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
assert isinstance(ret[0], tvm.tir.For)
assert ret[0].for_type == tvm.tir.For.Unrolled
assert isinstance(ret[1], tvm.tir.For)
......@@ -65,7 +69,11 @@ def test_unroll_fake_loop():
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body
# ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret[0], tvm.tir.Store)
def test_unroll_single_count_loops():
......@@ -78,8 +86,10 @@ def test_unroll_single_count_loops():
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
# all parameters to UnrollLoop are default values except for
# auto_unroll_max_extent which has been set to 1 (default:0)
after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True)
assert after_unroll_stmt == stmt
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body
assert ret == stmt
if __name__ == "__main__":
test_unroll_loop()
......
......@@ -28,12 +28,16 @@ def test_vectorize_loop():
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
assert isinstance(stmt.body.value, tvm.tir.Broadcast)
def test_vectorize_vector():
dtype = 'int64'
n = te.var('n')
......@@ -44,7 +48,10 @@ def test_vectorize_vector():
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
......@@ -63,13 +70,17 @@ def test_vectorize_with_if():
with ib.if_scope(i < n):
A[i] = 2.0
stmt = ib.get()
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.IfThenElse)
assert isinstance(stmt.then_case.index, tvm.tir.Ramp)
assert isinstance(stmt.then_case.value, tvm.tir.Add)
assert stmt.then_case.value.dtype == "float32x4"
assert isinstance(stmt.else_case, tvm.tir.For)
def test_vectorize_with_le_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
......@@ -78,9 +89,13 @@ def test_vectorize_with_le_cond():
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
def test_vectorize_with_ge_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
......@@ -89,9 +104,13 @@ def test_vectorize_with_ge_cond():
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
def test_vectorize_if_then_else():
n = te.var('n')
x = te.var('x')
......@@ -102,7 +121,10 @@ def test_vectorize_if_then_else():
i > 0,
A[i] + 1, A[i])
stmt = ib.get()
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
......@@ -114,8 +136,12 @@ def test_vectorize_if_then_else():
k > 0,
A[k * 4 + i], 0)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)
......