Unverified Commit c3511c5e by Tianqi Chen Committed by GitHub

[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (#5372)

* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.

te::Tensor is an useful object for tensor expression, but brings
un-necessary reverse dependency in TIR nodes such as Provide and Realize.

This PR is a first step to remove this dependency. We will use Buffer in all the places
where the te::Tensor was used. The rough correspondence are:

- Provide -> BufferStore
- Realize -> BufferRealize
- HalideCall -> BufferLoad.

After this change, we can not use IRModule of PrimFuncs cleanly to represent TIR
at any point of the optimizations. Buffer will serve as the abstraction for the TIR data
models to represent the intermediate storages and their constraints.

We still keep Realize/HalideCall and Provide as TIR nodes for now to make the change minimum.
Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR
generated by TE(which contains these nodes) to the TIR.

The TIR optimizations are now mostly migrated to to the pass manager.
Followup PRs are needed to migrate the remaining few passes.

* Fix dev tutorial
parent a4902e05
......@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
* \param body The given statement.
* \param tensor The name of the calls or provides.
* \param consider_calls If calls (read) are considered.
* \param consider_provides If provides (write) are considered.
* \param buffer The buffer to check the access info.
* \param consider_loads If loads are considered.
* \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body,
const te::Tensor &tensor,
bool consider_calls,
bool consider_provides);
Domain DomainTouched(const Stmt& body,
const tir::Buffer& buffer,
bool consider_loads,
bool consider_stores);
} // namespace arith
} // namespace tvm
......
......@@ -70,7 +70,7 @@ class ObjAllocatorBase {
static_assert(std::is_base_of<Object, T>::value,
"make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
std::forward<Args>(args)...);
ptr->type_index_ = T::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
......
......@@ -29,6 +29,7 @@
#define TVM_TE_SCHEDULE_PASS_H_
#include <tvm/te/schedule.h>
#include <tvm/tir/function.h>
namespace tvm {
namespace te {
......@@ -55,6 +56,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
*
* Perform this translation before running any TIR optimizations.
*
* List of actions taken by the function:
* - Remove occurences of te::Tensor, te::Operation in the IR
* and replace them by corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*
* \param arg_list Array of Tensor/Var/Buffer arguments to the function.
* \param body The body of the function.
* \param bindings potential Tensor to Buffer bindings for the Tensors in the body.
*/
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);
/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
......
......@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Halide-style call, evaluates func(args). */
/*!
* \brief Halide-style call, evaluates func(args).
* \note Deprecated, move to BufferLoad in the future.
*/
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
......@@ -707,9 +710,15 @@ class CallNode : public PrimExprNode {
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
/*! \brief The function to be called. */
/*!
* \brief The function to be called.
* \note Deprecated, move to BufferLoad in the future.
*/
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
/*!
* \brief The output value index if func's value is a tuple.
* \note Deprecated, move to BufferLoad in the future.
*/
int value_index{0};
void VisitAttrs(AttrVisitor* v) {
......
......@@ -165,22 +165,6 @@ Stmt Inline(Stmt stmt,
PrimExpr body);
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param stmt The stmt to be trasnformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<te::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);
/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be trasnformed.
......@@ -203,13 +187,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
bool VerifyCompactBuffer(Stmt stmt);
/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
......
......@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
* \endcode
* \sa BufferLoad
*/
class BufferStore;
class BufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
......@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};
/*!
* \brief Managed reference to BufferStoreNode.
* \sa BufferStoreNode
*/
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer,
......@@ -290,7 +293,79 @@ class BufferStore : public Stmt {
};
/*!
* \brief Annotate the region where the buffer need to
* be read and write in the body.
* We only need to allocate the space for the corresponding region.
*
* \note There should be at most one BufferRealize for each buffer.
* BufferRealize is not necessary for external buffers,
* since they are assumed to be fully allocated.
*
* \sa BufferLoad, BufferStore
*/
class BufferRealizeNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Buffer buffer;
/*! \brief Bounds to be realized */
Array<Range> bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}
bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
return
equal(buffer, other->buffer) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}
BufferRealizeNode() = default;
BufferRealizeNode(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body)
: buffer(buffer), bounds(bounds),
condition(condition), body(body) {}
static constexpr const char* _type_key = "BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
};
/*!
* \brief Managed reference to BufferRealizeNode.
* \sa BufferRealizeNode
*/
class BufferRealize : public Stmt {
public:
TVM_DLL explicit BufferRealize(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
};
/*!
* \brief Store value into mult-dimensional array defined by func.
*
* \note Deprecated, move to BufferStore in the future.
*/
class ProvideNode : public StmtNode {
public:
......@@ -430,6 +505,8 @@ class FreeNode : public StmtNode {
/*!
* \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*
* \note Deprecated, move to BufferRealize in the future.
*/
class RealizeNode : public StmtNode {
public:
......@@ -747,51 +824,51 @@ class ForNode : public StmtNode {
};
/*!
* \brief A prefetch hint of func.
* \brief A prefetch hint for abuffer
*/
class PrefetchNode : public StmtNode {
public:
/*! \brief The function to be prefetched. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType dtype;
Buffer buffer;
/*! \brief Bounds to be prefetched. */
Region bounds;
Array<Range> bounds;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("dtype", &dtype);
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
}
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(buffer, other->buffer) &&
equal(bounds, other->bounds);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
Region bounds);
PrefetchNode() = default;
PrefetchNode(Buffer buffer, Array<Range> bounds)
: buffer(buffer), bounds(bounds) {}
static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};
/*!
* \brief Managed reference to PrefetchNode.
* \sa PrefetchNode
*/
class Prefetch : public Stmt {
public:
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};
/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
struct TensorKey {
......
......@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
......@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
return vtable;
}
};
......@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
......@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
......
......@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<runtime::String>& required);
/*!
* \brief Inject prefetch instructions into stmt.
*
* \return The pass.
*/
TVM_DLL Pass InjectPrefetch();
// TODO(tvm-team): consolidate configs to the PassContext
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
*
* \return The Pass
*/
TVM_DLL Pass StorageFlatten(int cache_line_size,
bool create_bound_attribute = false);
/*!
* \brief Inject copy intrinsics with optional pad.
*
......
......@@ -31,7 +31,6 @@ import numpy as np
import tvm._ffi
from tvm import target as _target
from tvm.tir import ir_pass
from tvm.te import schedule
from tvm.driver import build_module
......@@ -46,10 +45,12 @@ def ana_lower(sch, args,
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
mod = tvm.IRModule.from_expr(func._move())
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
assert simple_mode
return stmt
return mod["main"].body
try:
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(
......
......@@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None):
def form_body(sch):
"""According to the given schedule, form the raw body
"""According to the given schedule, form a function.
Parameters
----------
sch : tvm.te.schedule.Schedule
......@@ -99,13 +100,31 @@ def form_body(sch):
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch,
args,
name="default_function",
name="main",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
......@@ -154,56 +173,57 @@ def lower(sch,
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
# Start the new style pass manager.
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
pass_list = [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
]
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
pass_list += [(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))]
pass_list += [
tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
tvm.tir.transform.InjectVirtualThread(),
tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
tvm.tir.transform.StorageRewrite(),
tvm.tir.transform.UnrollLoop(
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
]
# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
pass_list += [
tvm.tir.transform.Simplify(),
tvm.tir.transform.RemoveNoOp(),
]
for f in lower_phase3:
stmt = f(stmt)
if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
if simple_mode:
return stmt
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
if cfg.restricted_func:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
optimize = tvm.transform.Sequential(pass_list)
mod = optimize(mod)
return mod
......
......@@ -157,11 +157,6 @@ class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
Some typical usage of the sequential pass are:
1. Users provide a list of passes for optimization.
2. Only an optimization level is provided so that the backend system has
to glob all passes at this level and below to perform the optimizations.
Note that users can also provide a series of passes that they don't want to
apply when running a sequential pass. Pass dependency will be resolved in
the backend as well.
......@@ -173,6 +168,9 @@ class Sequential(Pass):
opt_level : Optional[int]
The optimization level of this sequential pass.
The opt_level of a default sequential pass is set to 0.
Note that some of the passes within the Sequantial may still not be executed
if their opt_level is higher than the provided opt_level.
name : Optional[str]
The name of the sequential pass.
......
......@@ -28,7 +28,8 @@ from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt
from .stmt import Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc
......
......@@ -161,6 +161,29 @@ class BufferStore(Stmt):
@tvm._ffi.register_object
class BufferRealize(Stmt):
"""Buffer realize node.
Parameters
----------
buffer : Buffer
The buffer.
bounds : List[Range]
The value we to be stored.
condition : PrimExpr
The realize condition.
body : Stmt
The body of the statement.
"""
def __init__(self, buffer, bounds, condition, body):
self.__init_handle_by_constructor__(
_ffi_api.BufferRealize, buffer, bounds, condition, body)
@tvm._ffi.register_object
class Provide(Stmt):
"""Provide node.
......@@ -348,21 +371,15 @@ class Prefetch(Stmt):
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type to be prefetched.
buffer : Buffer
The buffer to be prefetched.
bounds : list of Range
The bounds to be prefetched.
"""
def __init__(self, func, value_index, dtype, bounds):
def __init__(self, buffer, bounds):
self.__init_handle_by_constructor__(
_ffi_api.Prefetch, func, value_index, dtype, bounds)
_ffi_api.Prefetch, buffer, bounds)
def stmt_seq(*args):
......
......@@ -60,6 +60,38 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
def InjectPrefetch():
"""Inject prefetch instructions into stmt.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectPrefetch()
def StorageFlatten(cache_line_size, create_bound_attribute=False):
"""Flatten the multi-dimensional read/write to 1D.
Parameters
----------
cache_line_size: int
The size of CPU cache line.
create_bound_attribute:
Whether to create bound attributes.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute)
def InjectCopyIntrin(pragma_key, fintrin):
"""Inject virtual thread loops.
......
......@@ -36,10 +36,14 @@ namespace arith {
using namespace tir;
// Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public StmtExprVisitor {
class BufferTouchedDomain final : public StmtExprVisitor {
public:
FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool consider_provides)
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}
BufferTouchedDomain(const Buffer &buffer,
bool consider_loads,
bool consider_stores)
: buffer_(buffer),
consider_loads_(consider_loads),
consider_stores_(consider_stores) {}
Domain Find(const Stmt& stmt) {
operator()(stmt);
......@@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
}
void VisitExpr_(const CallNode* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
void VisitExpr_(const BufferLoadNode* op) final {
if (consider_loads_ && buffer_.same_as(op->buffer)) {
Touch(op->indices);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const ProvideNode* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
void VisitStmt_(const BufferStoreNode* op) final {
if (consider_stores_ && buffer_.same_as(op->buffer)) {
Touch(op->indices);
}
StmtExprVisitor::VisitStmt_(op);
}
......@@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
}
const te::Tensor &tensor_;
bool consider_calls_, consider_provides_;
const Buffer &buffer_;
bool consider_loads_, consider_stores_;
std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
Domain DomainTouched(Stmt stmt,
const te::Tensor &tensor,
bool consider_calls,
bool consider_provides) {
return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
Domain DomainTouched(const Stmt& stmt,
const Buffer& buffer,
bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
}
TVM_REGISTER_GLOBAL("arith.DomainTouched")
......
......@@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) {
}
IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
tir::Stmt stmt,
const std::string& name,
const BuildConfig& config) {
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
for (auto var : out_arg_list) {
if (auto* n = var.as<tir::VarNode>()) {
params.push_back(GetRef<tir::Var>(n));
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
}
auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}
IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
......@@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch,
sch = sch.normalize();
// Phase 0
// Before TIR transformation.
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);
stmt = tir::InjectPrefetch(stmt);
bool compact = tir::VerifyCompactBuffer(stmt);
Map<te::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
// Phase 1
stmt = tir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
// build the function
tir::PrimFunc f = te::SchedulePostProcToPrimFunc(
out_arg_list, std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
// convert to IRModule.
auto mod = BuildIRModule(out_arg_list, stmt, name, config);
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
auto pass_list = Array<tvm::transform::Pass>();
// Phase 0
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(
tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
// Phase 1
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
......
......@@ -132,8 +132,8 @@ MakeLoopNest(const Stage& stage,
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
nest[i + 1].emplace_back(
AttrStmtNode::make(it_attr->prefetch_data[j],
tir::attr::prefetch_scope,
it_attr->prefetch_offset[j], no_op));
tir::attr::prefetch_scope,
it_attr->prefetch_offset[j], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread" ||
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file schedule_postproc_to_primfunc.cc
*
* \brief Translate the function body generated by ScheduleOps
* with te related dialects that incorporates Tensor
* into the Stmts to a PrimFunc.
*
* Perform this translation before running any TIR optimizations.
*
* Rationale: The body generated by ScheduleOps is not
* a formal PrimFunc and cannot be used for further optimization.
* This function canonicalize that body and creates a formal PrimFunc.
*
* List of actions taken by the function:
* - Remove occurences of te::Tensor, te::Operation in the IR
* and replace them by corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/te/operation.h>
#include <utility>
#include <unordered_map>
namespace tvm {
namespace te {
// create a buffer for tensor.
Buffer CreateBufferFor(const Tensor& tensor) {
std::string name = tensor->op->name;
if (tensor->op->num_outputs() != 1) {
name += ".v" + std::to_string(tensor->value_index);
}
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name);
return buffer;
}
// A remapper that maps tensor to buffer
class TensorToBufferMapper : public StmtExprMutator {
public:
explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
: buffer_map_(buffer_map) {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
// TODO(tvm-team): remove realize_scope, turn the info into
// Buffer's scope field in this pass.
if (op->attr_key == tir::attr::realize_scope ||
op->attr_key == tir::attr::double_buffer_scope) {
Stmt body = op->body;
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
body = AttrStmtNode::make(
buffer, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
Tensor tensor = Downcast<Tensor>(tuple[1]);
return AttrStmtNode::make(
Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)},
op->attr_key, op->value, op->body);
} else if (op->attr_key == tir::attr::buffer_dim_align||
op->attr_key == tir::attr::prefetch_scope) {
Tensor tensor = Downcast<Tensor>(op->node);
Buffer buffer = GetOrAllocBuffer(tensor);
return AttrStmtNode::make(
buffer, op->attr_key, op->value, op->body);
} else {
return ret;
}
}
Stmt VisitStmt_(const RealizeNode* op) final {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetOrAllocBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<RealizeNode>();
return BufferRealize(buffer, op->bounds, op->condition, op->body);
}
Stmt VisitStmt_(const ProvideNode* op) final {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<ProvideNode>();
return BufferStore(buffer, op->value, op->args);
}
PrimExpr VisitExpr_(const CallNode* op) final {
auto ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<CallNode>();
if (op->call_type == CallNode::Halide) {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetBuffer(tensor);
return tir::BufferLoad(buffer, op->args);
} else {
return ret;
}
}
private:
Buffer GetOrAllocBuffer(const Tensor& tensor) {
return GetBuffer(tensor, true);
}
Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
auto it = buffer_map_.find(tensor);
if (it != buffer_map_.end()) return it->second;
CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
auto buffer = CreateBufferFor(tensor);
buffer_map_[tensor] = buffer;
return buffer;
}
// maps tensor to buffer.
std::unordered_map<Tensor, Buffer> buffer_map_;
};
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
std::unordered_map<Tensor, Buffer> extern_buffer;
if (extern_buffer_opt.defined()) {
auto v = extern_buffer_opt.value();
extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
}
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
for (auto var : arg_list) {
if (auto* n = var.as<tir::VarNode>()) {
params.push_back(GetRef<tir::Var>(n));
} else if (auto* n = var.as<te::TensorNode>()) {
te::Tensor tensor = GetRef<te::Tensor>(n);
CHECK(!extern_buffer.count(tensor));
tir::Buffer buffer = CreateBufferFor(tensor);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
extern_buffer[tensor] = buffer;
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
}
body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
return tir::PrimFunc(params, body, VoidType(), buffer_map);
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
.set_body_typed(SchedulePostProcToPrimFunc);
} // namespace te
} // namespace tvm
......@@ -646,6 +646,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferLoadNode*>(node.get());
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = ";
......
......@@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);
Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
CHECK(bounds[i]->min.dtype().is_scalar());
CHECK(bounds[i]->extent.dtype().is_scalar());
}
ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
node->func = std::move(func);
node->value_index = value_index;
node->dtype = dtype;
node->bounds = std::move(bounds);
return Stmt(node);
Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
data_ = make_object<PrefetchNode>(buffer, bounds);
}
TVM_REGISTER_GLOBAL("tir.Prefetch")
.set_body_typed(PrefetchNode::make);
.set_body_typed([](Buffer buffer, Array<Range> bounds) {
return Prefetch(buffer, bounds);
});
SeqStmt::SeqStmt(Array<Stmt> seq) {
......@@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore")
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
BufferRealize::BufferRealize(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body) {
data_ = make_object<BufferRealizeNode>(
buffer, bounds, condition, body);
}
TVM_REGISTER_GLOBAL("tir.BufferRealize")
.set_body_typed([](Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body) {
return BufferRealize(buffer, bounds, condition, body);
});
TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
// Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -433,6 +442,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) p->stream << ", ";
}
p->stream << "]";
p->stream << " = ";
p->Print(op->value);
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent();
......@@ -459,6 +483,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferRealizeNode*>(node.get());
p->PrintIndent();
p->stream << "buffer_realize " << op->buffer->name << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
p->stream << ", ";
p->Print(op->bounds[i]->extent);
p->stream << "]";
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (!is_one(op->condition)) {
p->stream << " if ";
p->Print(op->condition);
}
p->stream << " {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent();
......@@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "(";
p->stream << "prefetch " << op->buffer << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
......@@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (op->func->num_outputs() != 1) {
p->stream << ".value[" << op->value_index << "]";
}
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
......@@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) {
}
void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
this->VisitExpr(op->value);
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
});
this->VisitExpr(op->condition);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
......@@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
}
Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
PrimExpr value = this->VisitExpr(op->value);
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
if (indices.same_as(op->indices)) {
if (value.same_as(op->value) &&
indices.same_as(op->indices)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
n->indices = std::move(indices);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
PrimExpr condition = this->VisitExpr(op->condition);
Stmt body = this->VisitStmt(op->body);
if (bounds.same_as(op->bounds) &&
condition.same_as(op->condition) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->bounds = std::move(bounds);
n->condition = std::move(condition);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<PrimExpr> args = Internal::Mutate(this, op->args);
PrimExpr value = this->VisitExpr(op->value);
......
......@@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});
TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
*ret = StorageFlatten(args[0], args[1], args[2]);
} else {
*ret = StorageFlatten(args[0], args[1], args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
([](const Stmt& stmt,
......@@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(VerifyCompactBuffer);
......
......@@ -21,9 +21,12 @@
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/bound.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
......@@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator {
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
if (op && op->attr_key == attr::prefetch_scope) {
te::Tensor ts = Downcast<te::Tensor>(op->node);
Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
Domain domain = DomainTouched(op->body, ts, true, false);
Domain domain = DomainTouched(op->body, buffer, true, false);
Region region;
auto iter_var = loop_nest_.back().get();
......@@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator {
for (Range r : domain) {
if (!r.defined()) {
LOG(WARNING) << "Cannot decide prefetch region for " << ts;
LOG(WARNING) << "Cannot decide prefetch region for " << buffer;
return op->body;
}
Range res(EvalSet(r, vectorized_).cover_range(none));
......@@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator {
vectorized_.erase(iter_var);
Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region);
Stmt prefetch = Prefetch(buffer, region);
return SeqStmt({prefetch, op->body});
}
return ret;
......@@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector()(std::move(stmt));
}
namespace transform {
Pass InjectPrefetch() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = PrefetchInjector()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch")
.set_body_typed(InjectPrefetch);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -22,21 +22,25 @@ def test_domain_touched():
j = te.var('j')
n = tvm.runtime.convert(100)
m = te.var('m')
a = te.placeholder((n, m), name = 'a')
b = te.placeholder((n, m), name = 'b')
a = tvm.tir.decl_buffer((n, m), name='a')
b = tvm.tir.decl_buffer((n, m), name='b')
ir = tvm.tir.For(
i, 0, n, 0, 0,
tvm.tir.For(j, 0, m, 0, 0,
tvm.tir.Provide(
a.op,
0,
tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
tvm.tir.BufferStore(
a,
tvm.tir.BufferLoad(b, [i - 1, j + 1]) +
tvm.tir.BufferLoad(a, [i - 1, j - 1]),
[i, j]
)
)
)
a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
assert a_domain_r[0].min.value == -1
assert a_domain_r[0].extent.value == 100
assert a_domain_r[1].min.value == -1
......
......@@ -48,9 +48,9 @@ def test_split_uneven_unique_likely():
x, y = c.op.axis
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
stmt = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
assert str(stmt.body.body.body).count("likely") == 1
if __name__ == "__main__":
test_lower_rfactor()
......
......@@ -365,7 +365,7 @@ def test_bind():
a = te.placeholder((8, 4), 'float32')
c = foo(a)
s = te.create_schedule(c.op)
ir = tvm.lower(s, [a, c], simple_mode=True)
ir = tvm.lower(s, [a, c])
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
......@@ -517,7 +517,7 @@ def test_upstream():
c = te.compute((20, ), lambda x: a[x] + b[x])
d = upstream(c)
sch = te.create_schedule([c.op, d.op])
ir = tvm.lower(sch, [a, b, d], simple_mode=True)
ir = tvm.lower(sch, [a, b, d])
func = tvm.build(sch, [a, b, d])
assert(func)
......@@ -730,7 +730,7 @@ def test_schedule():
joo, joi = sch[c].split(jo, 4)
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
......@@ -751,7 +751,7 @@ def test_schedule():
# Test fuse
sch = te.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
......
......@@ -283,7 +283,7 @@ def test_tensor_intrin_scalar_params():
# Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True)
stmt = tvm.lower(s, [A, C])["main"].body
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.value.args) == 5
assert str(stmt.body.body.value.args[3]) == "(i*i)"
......
......@@ -28,6 +28,9 @@ def test_schedule0():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A1], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule1():
......@@ -43,6 +46,10 @@ def test_schedule1():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A1], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule2():
m = te.var('m')
......@@ -57,6 +64,9 @@ def test_schedule2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A2], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule_scan():
......@@ -77,6 +87,7 @@ def test_schedule_scan():
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
def test_inline_multi_reduce():
def argmax_comp(x, y):
idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
......@@ -510,19 +521,19 @@ def test_local_stage_predicate():
return ret
# local vs. threadIdx
s = schedule(tx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# local vs. vthread
s = schedule(vx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# shared vs. blockIdx
s = schedule(by, "shared")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
......@@ -548,7 +559,7 @@ def test_local_stage_predicate2():
s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(iaa, thread_x)
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
lowered_body = tvm.lower(s, [A, C])["main"].body
def collect_visit(stmt, f):
ret = []
......
......@@ -128,7 +128,7 @@ def test_tensor_compute1():
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body, tvm.tir.Evaluate)
def test_tensor_compute2():
......@@ -171,7 +171,7 @@ def test_tensor_compute2():
lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
......
......@@ -24,29 +24,6 @@ gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
other_devices = ["llvm", "ext_dev"]
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, te.tensor.Tensor):
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String("test"))
mod = tvm.IRModule({"test": f})
return mod
# All computations are bound.
# So VerifyMemory pass is expected to succeed.
#
......@@ -61,7 +38,7 @@ def test_verify_memory_all_bind():
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
mod = lower(s, [A, B])
mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices + other_devices:
binded_mod = tvm.tir.transform.Apply(
......@@ -81,7 +58,7 @@ def test_verify_memory_not_bind():
# B is not bound to threads.
s = te.create_schedule(B.op)
mod = lower(s, [A, B])
mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply(
......@@ -111,7 +88,7 @@ def test_verify_memory_partially_bind():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = lower(s, [A, B, C, D])
mod = tvm. lower(s, [A, B, C, D])
for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply(
......
......@@ -194,9 +194,9 @@ def test_stmt_constructor():
assert x.then_case.value.value == 11
assert x.else_case == nop
x = tvm.tir.Prefetch(None, 1, "float32", [])
b = tvm.tir.decl_buffer((1, 2))
x = tvm.tir.Prefetch(b, [])
assert isinstance(x, tvm.tir.Prefetch)
assert x.value_index == 1
if __name__ == "__main__":
......
......@@ -28,7 +28,6 @@ def test_for():
A[j] = A[j] + 2
body = ib.get()
print(body)
assert isinstance(body, tvm.tir.AttrStmt)
body = body.body
assert isinstance(body, tvm.tir.Allocate)
......@@ -59,14 +58,13 @@ def test_if():
assert body.else_case.index.value == 0
def test_prefetch():
A = te.placeholder((10, 20), name="A")
A = tvm.tir.decl_buffer((10, 20), name="A")
ib = tvm.tir.ir_builder.create()
n = te.size_var("n")
with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.tir.Prefetch(
A.op, A.value_index, A.dtype,
tvm.tir.Prefetch(A,
[tvm.ir.Range.make_by_min_extent(i+1, 2),
tvm.ir.Range.make_by_min_extent(0, 20)]))
body = ib.get()
......
......@@ -301,6 +301,10 @@ def test_buffer_load_store():
s = tvm.tir.BufferStore(b, 0.1, [0])
assert isinstance(s, tvm.tir.BufferStore)
s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)],
True, tvm.tir.Evaluate(0))
assert isinstance(s, tvm.tir.BufferRealize)
def test_intimm_cond():
x = tvm.runtime.convert(1)
......
......@@ -26,9 +26,10 @@ def test_copy2d():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
assert dst.strides[0] == l
assert dst.strides[1].value == 1
......@@ -36,7 +37,6 @@ def test_copy2d():
assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
......@@ -51,9 +51,11 @@ def test_copy_pad():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert pad_before[0].value == 1
......@@ -63,7 +65,6 @@ def test_copy_pad():
assert pad_value.value == 1.0
return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
......@@ -75,9 +76,11 @@ def test_single_point_test():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0
......@@ -85,7 +88,6 @@ def test_single_point_test():
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
......@@ -105,11 +107,12 @@ def test_copy_pad_split():
s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
......@@ -121,12 +124,10 @@ def test_copy_pad_split():
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
if __name__ == "__main__":
test_copy2d()
test_copy_pad()
......
......@@ -28,18 +28,16 @@ def test_makeapi():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.tir.decl_buffer(C.shape, C.dtype, name='C')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr({
"target": tvm.target.create("llvm"),
"global_symbol": "main",
}))(mod)
num_unpacked_args = 2
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
"global_symbol": "main",
"target": tvm.target.create("llvm")
}))
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
......
......@@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits):
raise ValueError("args must be Tensor, Buffer or Var")
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
return lower_stmt(arg_list, stmt, target_bits)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
def test_basic():
......
......@@ -30,11 +30,14 @@ def test_flatten2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[Ab, A2b], stmt, {A: Ab, A2: A2b})
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def test_flatten_prefetch():
A = te.placeholder((25, 100, 4), name = 'A')
......@@ -42,8 +45,14 @@ def test_flatten_prefetch():
i = te.size_var('i')
j = te.size_var('j')
region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
stmt = tvm.tir.Prefetch(_A, region)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[_A], stmt, {A: _A})
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert stmt.extent.value == 2
assert isinstance(stmt.body, tvm.tir.For)
......@@ -62,12 +71,15 @@ def test_flatten_storage_align():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert(stmt.body.extents[0].value == 17 * 8)
def test_flatten_double_buffer():
dtype = 'int64'
n = 100
......@@ -87,7 +99,13 @@ def test_flatten_double_buffer():
C[j] = B[j] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C], stmt))
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
......@@ -105,7 +123,7 @@ def test_flatten_double_buffer():
assert count[0] == 4
if __name__ == "__main__":
test_flatten_storage_align()
test_flatten2()
test_flatten_prefetch()
test_flatten_storage_align()
test_flatten_double_buffer()
test_flatten_prefetch()
......@@ -30,11 +30,11 @@ def test_storage_share():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -166,11 +166,11 @@ def test_inplace_rule():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -201,11 +201,10 @@ def test_storage_combine():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -238,11 +237,9 @@ def test_storage_share_gpu():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -398,17 +393,11 @@ def test_inplace_rule3():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0')
B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1')
B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2')
B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3')
B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4')
B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[B0, B1, B2, B3, B4, B5, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
......@@ -547,7 +536,7 @@ def test_large_input():
c = te.compute(shape, lambda i, j: compute(a, b)[i, j])
c = te.compute(shape, lambda i, j: 1 + c[i, j])
s = te.create_schedule(c.op)
stmt = tvm.lower(s, [a, b, c], simple_mode=True)
stmt = tvm.lower(s, [a, b, c])["main"].body
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 268435456
......
......@@ -34,15 +34,15 @@ def test_thread_storage_sync():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
"global_symbol": "test", "target": cuda_target}))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
"global_symbol": "test", "target": cuda_target}))(mod._move())
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
......
......@@ -40,8 +40,6 @@ Before reading this tutorial, we assume readers have already known these topics
take a look at ``python/tvm/build_module.py`` to get some basics.
"""
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
......@@ -57,7 +55,7 @@ b = te.placeholder((n, ), name="b")
c = te.compute((n, ), lambda i: a[i] + b[i], name='c')
sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
ir = tvm.lower(sch, [a, b, c])
print(ir)
######################################################################
......@@ -137,12 +135,8 @@ def vectorize(stmt):
# Glue to Lowering
# ----------------
# So far, we are done with writing this IR transformation pass. What we need to do next is to glue
# this pass to TVM's lower pass. We can first call this function directly as a sanity check.
# this pass to TVM's lower pass.
#
print(vectorize(ir))
#####################################################################
# In TVM, there is a property called ``BuildConfig``. You can use this property to customize your
# own lowering options. In this case, we inject the pass written above into the TVM standard lowering
# pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different
......@@ -160,7 +154,7 @@ print(vectorize(ir))
#
with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
print(tvm.lower(sch, [a, b, c], simple_mode=True))
print(tvm.lower(sch, [a, b, c]))
#####################################################################
# Quick View
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment