Unverified Commit c3511c5e by Tianqi Chen Committed by GitHub

[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (#5372)

* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.

te::Tensor is a useful object for tensor expressions, but it brings an
unnecessary reverse dependency into TIR nodes such as Provide and Realize.

This PR is a first step toward removing this dependency. We will use Buffer in all the places
where te::Tensor was used. The rough correspondences are:

- Provide -> BufferStore
- Realize -> BufferRealize
- HalideCall -> BufferLoad.

After this change, we can now use an IRModule of PrimFuncs cleanly to represent TIR
at any point of the optimizations. Buffer will serve as the abstraction in the TIR data
model to represent intermediate storage and its constraints.
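For example, a one-dimensional store that was previously expressed through a Halide-style Call plus Provide can be written with the new buffer-based nodes roughly as follows (an illustrative sketch only; the buffer name A and index variable i are made up):

    import tvm
    from tvm import tir

    A = tir.decl_buffer((16,), "float32", name="A")
    i = tir.Var("i", "int32")
    load = tir.BufferLoad(A, [i])                 # replaces the Halide-style Call
    store = tir.BufferStore(A, load + 1.0, [i])   # replaces Provide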

We still keep Realize/HalideCall and Provide as TIR nodes for now to keep the change minimal.
Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR
generated by TE (which contains these nodes) into TIR.

The TIR optimizations are now mostly migrated to the pass manager.
Follow-up PRs are needed to migrate the remaining few passes.

* Fix dev tutorial
parent a4902e05
...@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, ...@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
/*! /*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement. * \brief Infer a regular domain that covers all the calls or provides within the given statement.
* \param body The given statement. * \param body The given statement.
* \param tensor The name of the calls or provides. * \param buffer The buffer whose accesses are analyzed.
* \param consider_calls If calls (read) are considered. * \param consider_loads If loads are considered.
* \param consider_provides If provides (write) are considered. * \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement. * \return The domain that covers all the calls or provides within the given statement.
*/ */
Domain DomainTouched(Stmt body, Domain DomainTouched(const Stmt& body,
const te::Tensor &tensor, const tir::Buffer& buffer,
bool consider_calls, bool consider_loads,
bool consider_provides); bool consider_stores);
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#define TVM_TE_SCHEDULE_PASS_H_ #define TVM_TE_SCHEDULE_PASS_H_
#include <tvm/te/schedule.h> #include <tvm/te/schedule.h>
#include <tvm/tir/function.h>
namespace tvm { namespace tvm {
namespace te { namespace te {
...@@ -55,6 +56,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch); ...@@ -55,6 +56,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop); Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
/*! /*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
*
* Perform this translation before running any TIR optimizations.
*
* List of actions taken by the function:
* - Remove occurrences of te::Tensor and te::Operation in the IR
* and replace them by corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*
* \param arg_list Array of Tensor/Var/Buffer arguments to the function.
* \param body The body of the function.
* \param bindings Potential Tensor-to-Buffer bindings for the Tensors in the body.
*/
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);
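As a usage sketch from the Python side (mirroring how the driver changes later in this diff call it; the tensors A and B are illustrative):

    import tvm
    from tvm import te
    from tvm.te import schedule

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op).normalize()
    bounds = schedule.InferBound(s)
    stmt = schedule.ScheduleOps(s, bounds)
    # canonicalize the TE-generated body into a PrimFunc
    func = schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)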
/*!
* \brief To automatically inline the element-wise operations. * \brief To automatically inline the element-wise operations.
* *
* \param sch The schedule to be inlined. * \param sch The schedule to be inlined.
......
...@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode { ...@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1, ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */ /*! \brief Extern "C" without side-effect. */
PureExtern = 2, PureExtern = 2,
/*! \brief Halide-style call, evaluates func(args). */ /*!
* \brief Halide-style call, evaluates func(args).
* \note Deprecated, move to BufferLoad in the future.
*/
Halide = 3, Halide = 3,
/*! \brief Intrinsic functions. */ /*! \brief Intrinsic functions. */
Intrinsic = 4, Intrinsic = 4,
...@@ -707,9 +710,15 @@ class CallNode : public PrimExprNode { ...@@ -707,9 +710,15 @@ class CallNode : public PrimExprNode {
Array<PrimExpr> args; Array<PrimExpr> args;
/*! \brief Type of calls. */ /*! \brief Type of calls. */
CallType call_type; CallType call_type;
/*! \brief The function to be called. */ /*!
* \brief The function to be called.
* \note Deprecated, move to BufferLoad in the future.
*/
FunctionRef func; FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */ /*!
* \brief The output value index if func's value is a tuple.
* \note Deprecated, move to BufferLoad in the future.
*/
int value_index{0}; int value_index{0};
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
......
...@@ -165,22 +165,6 @@ Stmt Inline(Stmt stmt, ...@@ -165,22 +165,6 @@ Stmt Inline(Stmt stmt,
PrimExpr body); PrimExpr body);
/*! /*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param stmt The stmt to be transformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<te::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);
/*!
* \brief Try to modify the AST to support TensorCore * \brief Try to modify the AST to support TensorCore
* *
* \param stmt The stmt to be transformed. * \param stmt The stmt to be transformed.
...@@ -203,13 +187,6 @@ Stmt RewriteForTensorCore(Stmt stmt, ...@@ -203,13 +187,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
bool VerifyCompactBuffer(Stmt stmt); bool VerifyCompactBuffer(Stmt stmt);
/*! /*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Decorate the stmt with a device scope, this is helpful for * \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks. * hardware accelerator without thread blocks.
* *
......
...@@ -248,7 +248,6 @@ class StoreNode : public StmtNode { ...@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
* \endcode * \endcode
* \sa BufferLoad * \sa BufferLoad
*/ */
class BufferStore;
class BufferStoreNode : public StmtNode { class BufferStoreNode : public StmtNode {
public: public:
/*! \brief The buffer variable. */ /*! \brief The buffer variable. */
...@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode { ...@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
}; };
/*!
* \brief Managed reference to BufferStoreNode.
* \sa BufferStoreNode
*/
class BufferStore : public Stmt { class BufferStore : public Stmt {
public: public:
TVM_DLL explicit BufferStore(Buffer buffer, TVM_DLL explicit BufferStore(Buffer buffer,
...@@ -290,7 +293,79 @@ class BufferStore : public Stmt { ...@@ -290,7 +293,79 @@ class BufferStore : public Stmt {
}; };
/*! /*!
* \brief Annotate the region where the buffer needs to
* be read and written in the body.
* We only need to allocate space for the corresponding region.
*
* \note There should be at most one BufferRealize for each buffer.
* BufferRealize is not necessary for external buffers,
* since they are assumed to be fully allocated.
*
* \sa BufferLoad, BufferStore
*/
class BufferRealizeNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Buffer buffer;
/*! \brief Bounds to be realized */
Array<Range> bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}
bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
return
equal(buffer, other->buffer) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}
BufferRealizeNode() = default;
BufferRealizeNode(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body)
: buffer(buffer), bounds(bounds),
condition(condition), body(body) {}
static constexpr const char* _type_key = "BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
};
/*!
* \brief Managed reference to BufferRealizeNode.
* \sa BufferRealizeNode
*/
class BufferRealize : public Stmt {
public:
TVM_DLL explicit BufferRealize(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
};
/*!
* \brief Store value into multi-dimensional array defined by func. *
*
* \note Deprecated, move to BufferStore in the future.
*/ */
class ProvideNode : public StmtNode { class ProvideNode : public StmtNode {
public: public:
...@@ -430,6 +505,8 @@ class FreeNode : public StmtNode { ...@@ -430,6 +505,8 @@ class FreeNode : public StmtNode {
/*! /*!
* \brief Annotate the bounds where func need to be written and read in body. * \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions. * We will need to allocate space for the corresponding regions.
*
* \note Deprecated, move to BufferRealize in the future.
*/ */
class RealizeNode : public StmtNode { class RealizeNode : public StmtNode {
public: public:
...@@ -747,51 +824,51 @@ class ForNode : public StmtNode { ...@@ -747,51 +824,51 @@ class ForNode : public StmtNode {
}; };
/*! /*!
* \brief A prefetch hint of func. * \brief A prefetch hint for a buffer.
*/ */
class PrefetchNode : public StmtNode { class PrefetchNode : public StmtNode {
public: public:
/*! \brief The function to be prefetched. */ /*! \brief The buffer to be prefetched. */
FunctionRef func; Buffer buffer;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType dtype;
/*! \brief Bounds to be prefetched. */ /*! \brief Bounds to be prefetched. */
Region bounds; Array<Range> bounds;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("buffer", &buffer);
v->Visit("value_index", &value_index);
v->Visit("dtype", &dtype);
v->Visit("bounds", &bounds); v->Visit("bounds", &bounds);
} }
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return return
equal(func, other->func) && equal(buffer, other->buffer) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds); equal(bounds, other->bounds);
} }
void SHashReduce(SHashReducer hash_reduce) const { void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func); hash_reduce(buffer);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(bounds); hash_reduce(bounds);
} }
TVM_DLL static Stmt make(FunctionRef func, PrefetchNode() = default;
int value_index, PrefetchNode(Buffer buffer, Array<Range> bounds)
DataType dtype, : buffer(buffer), bounds(bounds) {}
Region bounds);
static constexpr const char* _type_key = "Prefetch"; static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
}; };
/*! /*!
* \brief Managed reference to PrefetchNode.
* \sa PrefetchNode
*/
class Prefetch : public Stmt {
public:
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};
/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor. * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/ */
struct TensorKey { struct TensorKey {
......
...@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
...@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
return vtable; return vtable;
} }
}; };
...@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor : ...@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProvideNode* op) override; void VisitStmt_(const ProvideNode* op) override;
...@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator : ...@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override; Stmt VisitStmt_(const ProvideNode* op) override;
......
...@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< ...@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name, const std::string& name,
const tvm::Array<runtime::String>& required); const tvm::Array<runtime::String>& required);
/*!
* \brief Inject prefetch instructions into stmt.
*
* \return The pass.
*/
TVM_DLL Pass InjectPrefetch();
// TODO(tvm-team): consolidate configs to the PassContext
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
*
* \return The Pass
*/
TVM_DLL Pass StorageFlatten(int cache_line_size,
bool create_bound_attribute = false);
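From the Python side these passes are applied to an IRModule of PrimFuncs, e.g. (a sketch; mod is assumed to be the module produced by SchedulePostProcToPrimFunc as in the earlier sketch):

    mod = tvm.tir.transform.InjectPrefetch()(mod)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Simplify()(mod)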
/*! /*!
* \brief Inject copy intrinsics with optional pad. * \brief Inject copy intrinsics with optional pad.
* *
......
...@@ -31,7 +31,6 @@ import numpy as np ...@@ -31,7 +31,6 @@ import numpy as np
import tvm._ffi import tvm._ffi
from tvm import target as _target from tvm import target as _target
from tvm.tir import ir_pass
from tvm.te import schedule from tvm.te import schedule
from tvm.driver import build_module from tvm.driver import build_module
...@@ -46,10 +45,12 @@ def ana_lower(sch, args, ...@@ -46,10 +45,12 @@ def ana_lower(sch, args,
# Phase 0 # Phase 0
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True) stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64) func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
stmt = ir_pass.CanonicalSimplify(stmt) mod = tvm.IRModule.from_expr(func._move())
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
assert simple_mode assert simple_mode
return stmt return mod["main"].body
try: try:
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func( _get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(
......
...@@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None): ...@@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None):
def form_body(sch): def form_body(sch):
"""According to the given schedule, form the raw body """According to the given schedule, form a function.
Parameters Parameters
---------- ----------
sch : tvm.te.schedule.Schedule sch : tvm.te.schedule.Schedule
...@@ -99,13 +100,31 @@ def form_body(sch): ...@@ -99,13 +100,31 @@ def form_body(sch):
sch = sch.normalize() sch = sch.normalize()
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt return stmt
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
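For illustration, a custom phase callback still receives and returns a Stmt, and the adapter above turns a list of such callbacks into a single PrimFunc pass (sketch; my_phase1 is a made-up identity callback):

    def my_phase1(stmt):
        # identity transform, for illustration only
        return stmt

    custom_phase1 = _wrap_as_prim_func_pass([my_phase1], "Custom-Phase1")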
def lower(sch, def lower(sch,
args, args,
name="default_function", name="main",
binds=None, binds=None,
simple_mode=False): simple_mode=False):
"""Lowering step before build into target. """Lowering step before build into target.
...@@ -154,56 +173,57 @@ def lower(sch, ...@@ -154,56 +173,57 @@ def lower(sch,
compact = ir_pass.VerifyCompactBuffer(stmt) compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds) binds, arg_list = get_binds(args, compact, binds)
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
# Start the new style pass manager.
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
# Phase 1 # Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) pass_list = [
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) tvm.tir.transform.InjectPrefetch(),
stmt = ir_pass.NarrowDataType(stmt, 32) tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
stmt = ir_pass.CanonicalSimplify(stmt) tvm.tir.transform.NarrowDataType(32),
for f in lower_phase1: tvm.tir.transform.Simplify(),
stmt = f(stmt) _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
]
# Phase 2 # Phase 2
if not simple_mode: if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) pass_list += [(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))]
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt) pass_list += [
else: tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
stmt = ir_pass.VectorizeLoop(stmt) tvm.tir.transform.InjectVirtualThread(),
stmt = ir_pass.InjectVirtualThread(stmt) tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) tvm.tir.transform.StorageRewrite(),
stmt = ir_pass.StorageRewrite(stmt) tvm.tir.transform.UnrollLoop(
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step, cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent, cfg.auto_unroll_max_extent,
cfg.unroll_explicit) cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
for f in lower_phase2: ]
stmt = f(stmt)
# Phase 3 # Phase 3
stmt = ir_pass.Simplify(stmt) pass_list += [
stmt = ir_pass.RemoveNoOp(stmt) tvm.tir.transform.Simplify(),
if not cfg.disable_select_rewriting: tvm.tir.transform.RemoveNoOp(),
stmt = ir_pass.RewriteUnsafeSelect(stmt) ]
for f in lower_phase3: if not cfg.disable_select_rewriting:
stmt = f(stmt) pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
# Instrument BoundCheckers # Instrument BoundCheckers
if cfg.instrument_bound_checkers: if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt) pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
if simple_mode: optimize = tvm.transform.Sequential(pass_list)
return stmt mod = optimize(mod)
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
if cfg.restricted_func:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return mod return mod
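A small usage sketch of the rewritten lower() (tensor definitions are illustrative): it now always returns an IRModule whose "main" PrimFunc carries the buffer_map, even in simple_mode:

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B], simple_mode=True)
    print(mod["main"].body)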
......
...@@ -157,11 +157,6 @@ class Sequential(Pass): ...@@ -157,11 +157,6 @@ class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be """A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class. executed sequentially using this class.
Some typical usage of the sequential pass are:
1. Users provide a list of passes for optimization.
2. Only an optimization level is provided so that the backend system has
to glob all passes at this level and below to perform the optimizations.
Note that users can also provide a series of passes that they don't want to Note that users can also provide a series of passes that they don't want to
apply when running a sequential pass. Pass dependency will be resolved in apply when running a sequential pass. Pass dependency will be resolved in
the backend as well. the backend as well.
...@@ -173,6 +168,9 @@ class Sequential(Pass): ...@@ -173,6 +168,9 @@ class Sequential(Pass):
opt_level : Optional[int] opt_level : Optional[int]
The optimization level of this sequential pass. The optimization level of this sequential pass.
The opt_level of a default sequential pass is set to 0.
Note that some of the passes within the Sequential may still not be executed
if their opt_level is higher than the provided opt_level.
name : Optional[str] name : Optional[str]
The name of the sequential pass. The name of the sequential pass.
......
...@@ -28,7 +28,8 @@ from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let ...@@ -28,7 +28,8 @@ from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, For from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt
from .stmt import Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc from .function import PrimFunc
......
...@@ -161,6 +161,29 @@ class BufferStore(Stmt): ...@@ -161,6 +161,29 @@ class BufferStore(Stmt):
@tvm._ffi.register_object @tvm._ffi.register_object
class BufferRealize(Stmt):
"""Buffer realize node.
Parameters
----------
buffer : Buffer
The buffer.
bounds : List[Range]
The bounds to be realized.
condition : PrimExpr
The realize condition.
body : Stmt
The body of the statement.
"""
def __init__(self, buffer, bounds, condition, body):
self.__init_handle_by_constructor__(
_ffi_api.BufferRealize, buffer, bounds, condition, body)
@tvm._ffi.register_object
class Provide(Stmt): class Provide(Stmt):
"""Provide node. """Provide node.
...@@ -348,21 +371,15 @@ class Prefetch(Stmt): ...@@ -348,21 +371,15 @@ class Prefetch(Stmt):
Parameters Parameters
---------- ----------
func : Operation buffer : Buffer
The operation to create the function. The buffer to be prefetched.
value_index : int
The output value index
dtype : str
The data type to be prefetched.
bounds : list of Range bounds : list of Range
The bounds to be prefetched. The bounds to be prefetched.
""" """
def __init__(self, func, value_index, dtype, bounds): def __init__(self, buffer, bounds):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_ffi_api.Prefetch, func, value_index, dtype, bounds) _ffi_api.Prefetch, buffer, bounds)
def stmt_seq(*args): def stmt_seq(*args):
......
...@@ -60,6 +60,38 @@ def Filter(fcond): ...@@ -60,6 +60,38 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
def InjectPrefetch():
"""Inject prefetch instructions into stmt.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectPrefetch()
def StorageFlatten(cache_line_size, create_bound_attribute=False):
"""Flatten the multi-dimensional read/write to 1D.
Parameters
----------
cache_line_size: int
The size of CPU cache line.
create_bound_attribute:
Whether to create bound attributes.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute)
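These wrappers compose with the rest of the phase passes through tvm.transform.Sequential, as the rewritten driver does (sketch; mod is an IRModule of PrimFuncs):

    seq = tvm.transform.Sequential([
        tvm.tir.transform.InjectPrefetch(),
        tvm.tir.transform.StorageFlatten(64, False),
        tvm.tir.transform.Simplify(),
    ])
    mod = seq(mod)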
def InjectCopyIntrin(pragma_key, fintrin): def InjectCopyIntrin(pragma_key, fintrin):
"""Inject virtual thread loops. """Inject virtual thread loops.
......
...@@ -36,10 +36,14 @@ namespace arith { ...@@ -36,10 +36,14 @@ namespace arith {
using namespace tir; using namespace tir;
// Find Read region of the tensor in the stmt. // Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public StmtExprVisitor { class BufferTouchedDomain final : public StmtExprVisitor {
public: public:
FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool consider_provides) BufferTouchedDomain(const Buffer &buffer,
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} bool consider_loads,
bool consider_stores)
: buffer_(buffer),
consider_loads_(consider_loads),
consider_stores_(consider_stores) {}
Domain Find(const Stmt& stmt) { Domain Find(const Stmt& stmt) {
operator()(stmt); operator()(stmt);
...@@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor {
} }
} }
void VisitExpr_(const CallNode* op) final { void VisitExpr_(const BufferLoadNode* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func) if (consider_loads_ && buffer_.same_as(op->buffer)) {
&& tensor_->value_index == op->value_index) { Touch(op->indices);
Touch(op->args);
} }
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
void VisitStmt_(const ProvideNode* op) final { void VisitStmt_(const BufferStoreNode* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func) if (consider_stores_ && buffer_.same_as(op->buffer)) {
&& tensor_->value_index == op->value_index) { Touch(op->indices);
Touch(op->args);
} }
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
...@@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor {
} }
} }
const te::Tensor &tensor_; const Buffer &buffer_;
bool consider_calls_, consider_provides_; bool consider_loads_, consider_stores_;
std::vector<std::vector<IntSet> > bounds_; std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const VarNode*, IntSet> dom_map_; std::unordered_map<const VarNode*, IntSet> dom_map_;
}; };
Domain DomainTouched(Stmt stmt, Domain DomainTouched(const Stmt& stmt,
const te::Tensor &tensor, const Buffer& buffer,
bool consider_calls, bool consider_loads,
bool consider_provides) { bool consider_stores) {
return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt); return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
} }
TVM_REGISTER_GLOBAL("arith.DomainTouched") TVM_REGISTER_GLOBAL("arith.DomainTouched")
......
...@@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) { ...@@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) {
} }
IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
tir::Stmt stmt,
const std::string& name,
const BuildConfig& config) {
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
for (auto var : out_arg_list) {
if (auto* n = var.as<tir::VarNode>()) {
params.push_back(GetRef<tir::Var>(n));
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
}
auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}
IRModule lower(te::Schedule sch, IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args, const Array<te::Tensor>& args,
const std::string& name, const std::string& name,
...@@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch, ...@@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch,
sch = sch.normalize(); sch = sch.normalize();
// Phase 0 // Before TIR transformation.
auto bounds = te::InferBound(sch); auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false); auto stmt = te::ScheduleOps(sch, bounds, false);
stmt = tir::InjectPrefetch(stmt);
bool compact = tir::VerifyCompactBuffer(stmt); bool compact = tir::VerifyCompactBuffer(stmt);
Map<te::Tensor, tir::Buffer> out_binds; Map<te::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
// Phase 1 // build the function
stmt = tir::StorageFlatten(stmt, out_binds, 64, tir::PrimFunc f = te::SchedulePostProcToPrimFunc(
config->instrument_bound_checkers); out_arg_list, std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
// convert to IRModule. auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
auto mod = BuildIRModule(out_arg_list, stmt, name, config);
auto pass_list = Array<tvm::transform::Pass>(); auto pass_list = Array<tvm::transform::Pass>();
// Phase 0
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(
tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
// Phase 1
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop)); pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize)); pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file schedule_postproc_to_primfunc.cc
*
* \brief Translate the function body generated by ScheduleOps,
* which uses te-related dialects that embed Tensor
* into the Stmts, into a PrimFunc.
*
* Perform this translation before running any TIR optimizations.
*
* Rationale: The body generated by ScheduleOps is not
* a formal PrimFunc and cannot be used for further optimization.
* This function canonicalizes that body and creates a formal PrimFunc.
*
* List of actions taken by the function:
* - Remove occurrences of te::Tensor and te::Operation in the IR
* and replace them by corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/te/operation.h>
#include <utility>
#include <unordered_map>
namespace tvm {
namespace te {
// create a buffer for tensor.
Buffer CreateBufferFor(const Tensor& tensor) {
std::string name = tensor->op->name;
if (tensor->op->num_outputs() != 1) {
name += ".v" + std::to_string(tensor->value_index);
}
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name);
return buffer;
}
// A remapper that maps tensor to buffer
class TensorToBufferMapper : public StmtExprMutator {
public:
explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
: buffer_map_(buffer_map) {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
// TODO(tvm-team): remove realize_scope, turn the info into
// Buffer's scope field in this pass.
if (op->attr_key == tir::attr::realize_scope ||
op->attr_key == tir::attr::double_buffer_scope) {
Stmt body = op->body;
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
body = AttrStmtNode::make(
buffer, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
Tensor tensor = Downcast<Tensor>(tuple[1]);
return AttrStmtNode::make(
Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)},
op->attr_key, op->value, op->body);
} else if (op->attr_key == tir::attr::buffer_dim_align||
op->attr_key == tir::attr::prefetch_scope) {
Tensor tensor = Downcast<Tensor>(op->node);
Buffer buffer = GetOrAllocBuffer(tensor);
return AttrStmtNode::make(
buffer, op->attr_key, op->value, op->body);
} else {
return ret;
}
}
Stmt VisitStmt_(const RealizeNode* op) final {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetOrAllocBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<RealizeNode>();
return BufferRealize(buffer, op->bounds, op->condition, op->body);
}
Stmt VisitStmt_(const ProvideNode* op) final {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<ProvideNode>();
return BufferStore(buffer, op->value, op->args);
}
PrimExpr VisitExpr_(const CallNode* op) final {
auto ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<CallNode>();
if (op->call_type == CallNode::Halide) {
Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
Buffer buffer = GetBuffer(tensor);
return tir::BufferLoad(buffer, op->args);
} else {
return ret;
}
}
private:
Buffer GetOrAllocBuffer(const Tensor& tensor) {
return GetBuffer(tensor, true);
}
Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
auto it = buffer_map_.find(tensor);
if (it != buffer_map_.end()) return it->second;
CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
auto buffer = CreateBufferFor(tensor);
buffer_map_[tensor] = buffer;
return buffer;
}
// maps tensor to buffer.
std::unordered_map<Tensor, Buffer> buffer_map_;
};
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
std::unordered_map<Tensor, Buffer> extern_buffer;
if (extern_buffer_opt.defined()) {
auto v = extern_buffer_opt.value();
extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
}
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
for (auto var : arg_list) {
if (auto* n = var.as<tir::VarNode>()) {
params.push_back(GetRef<tir::Var>(n));
} else if (auto* n = var.as<te::TensorNode>()) {
te::Tensor tensor = GetRef<te::Tensor>(n);
CHECK(!extern_buffer.count(tensor));
tir::Buffer buffer = CreateBufferFor(tensor);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
extern_buffer[tensor] = buffer;
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
}
body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
return tir::PrimFunc(params, body, VoidType(), buffer_map);
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
.set_body_typed(SchedulePostProcToPrimFunc);
} // namespace te
} // namespace tvm
...@@ -646,6 +646,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -646,6 +646,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferLoadNode*>(node.get());
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get()); auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = "; p->stream << "(let " << op->var << " = ";
......
...@@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize") ...@@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make); .set_body_typed(RealizeNode::make);
Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
for (size_t i = 0; i < bounds.size(); ++i) { data_ = make_object<PrefetchNode>(buffer, bounds);
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
CHECK(bounds[i]->min.dtype().is_scalar());
CHECK(bounds[i]->extent.dtype().is_scalar());
}
ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
node->func = std::move(func);
node->value_index = value_index;
node->dtype = dtype;
node->bounds = std::move(bounds);
return Stmt(node);
} }
TVM_REGISTER_GLOBAL("tir.Prefetch") TVM_REGISTER_GLOBAL("tir.Prefetch")
.set_body_typed(PrefetchNode::make); .set_body_typed([](Buffer buffer, Array<Range> bounds) {
return Prefetch(buffer, bounds);
});
SeqStmt::SeqStmt(Array<Stmt> seq) { SeqStmt::SeqStmt(Array<Stmt> seq) {
...@@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore") ...@@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore")
TVM_REGISTER_NODE_TYPE(BufferStoreNode); TVM_REGISTER_NODE_TYPE(BufferStoreNode);
BufferRealize::BufferRealize(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body) {
data_ = make_object<BufferRealizeNode>(
buffer, bounds, condition, body);
}
TVM_REGISTER_GLOBAL("tir.BufferRealize")
.set_body_typed([](Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body) {
return BufferRealize(buffer, bounds, condition, body);
});
TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
// Printers // Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -433,6 +442,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -433,6 +442,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) p->stream << ", ";
}
p->stream << "]";
p->stream << " = ";
p->Print(op->value);
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AllocateNode*>(node.get()); auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent(); p->PrintIndent();
...@@ -459,6 +483,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -459,6 +483,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferRealizeNode*>(node.get());
p->PrintIndent();
p->stream << "buffer_realize " << op->buffer->name << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
p->stream << ", ";
p->Print(op->bounds[i]->extent);
p->stream << "]";
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (!is_one(op->condition)) {
p->stream << " if ";
p->Print(op->condition);
}
p->stream << " {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RealizeNode*>(node.get()); auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent(); p->PrintIndent();
...@@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get()); auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "("; p->stream << "prefetch " << op->buffer << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) { for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "["; p->stream << "[";
p->Print(op->bounds[i]->min); p->Print(op->bounds[i]->min);
...@@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
if (i < op->bounds.size() - 1) p->stream << ", "; if (i < op->bounds.size() - 1) p->stream << ", ";
} }
p->stream << ")"; p->stream << ")";
if (op->func->num_outputs() != 1) {
p->stream << ".value[" << op->value_index << "]";
}
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
...@@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { ...@@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) {
} }
void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
this->VisitExpr(op->value);
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
} }
void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
});
this->VisitExpr(op->condition);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition); this->VisitExpr(op->condition);
this->VisitStmt(op->then_case); this->VisitStmt(op->then_case);
...@@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { ...@@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
} }
Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
PrimExpr value = this->VisitExpr(op->value);
Array<PrimExpr> indices = Internal::Mutate(this, op->indices); Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
if (indices.same_as(op->indices)) {
if (value.same_as(op->value) &&
indices.same_as(op->indices)) {
return GetRef<Stmt>(op); return GetRef<Stmt>(op);
} else { } else {
auto n = CopyOnWrite(op); auto n = CopyOnWrite(op);
n->value = std::move(value);
n->indices = std::move(indices); n->indices = std::move(indices);
return Stmt(n); return Stmt(n);
} }
} }
Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
PrimExpr condition = this->VisitExpr(op->condition);
Stmt body = this->VisitStmt(op->body);
if (bounds.same_as(op->bounds) &&
condition.same_as(op->condition) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->bounds = std::move(bounds);
n->condition = std::move(condition);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<PrimExpr> args = Internal::Mutate(this, op->args); Array<PrimExpr> args = Internal::Mutate(this, op->args);
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
......
...@@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute") ...@@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
} }
}); });
TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
*ret = StorageFlatten(args[0], args[1], args[2]);
} else {
*ret = StorageFlatten(args[0], args[1], args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed .set_body_typed
([](const Stmt& stmt, ([](const Stmt& stmt,
...@@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA); ...@@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA); REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline); REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform); REGISTER_PASS(IRTransform);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(VerifyCompactBuffer);
......
...@@ -21,9 +21,12 @@ ...@@ -21,9 +21,12 @@
* \file inject_prefetch.cc * \file inject_prefetch.cc
*/ */
// Inject prefetch op in HalideIR // Inject prefetch op in HalideIR
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/transform.h>
#include <tvm/arith/bound.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <unordered_set> #include <unordered_set>
...@@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator { ...@@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator {
Stmt ret = StmtMutator::VisitStmt_(op); Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>(); op = ret.as<AttrStmtNode>();
if (op && op->attr_key == attr::prefetch_scope) { if (op && op->attr_key == attr::prefetch_scope) {
te::Tensor ts = Downcast<te::Tensor>(op->node); Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U); CHECK_NE(loop_nest_.size(), 0U);
Domain domain = DomainTouched(op->body, ts, true, false); Domain domain = DomainTouched(op->body, buffer, true, false);
Region region; Region region;
auto iter_var = loop_nest_.back().get(); auto iter_var = loop_nest_.back().get();
...@@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator { ...@@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator {
for (Range r : domain) { for (Range r : domain) {
if (!r.defined()) { if (!r.defined()) {
LOG(WARNING) << "Cannot decide prefetch region for " << ts; LOG(WARNING) << "Cannot decide prefetch region for " << buffer;
return op->body; return op->body;
} }
Range res(EvalSet(r, vectorized_).cover_range(none)); Range res(EvalSet(r, vectorized_).cover_range(none));
...@@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator { ...@@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator {
vectorized_.erase(iter_var); vectorized_.erase(iter_var);
Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region); Stmt prefetch = Prefetch(buffer, region);
return SeqStmt({prefetch, op->body}); return SeqStmt({prefetch, op->body});
} }
return ret; return ret;
...@@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) { ...@@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector()(std::move(stmt)); return PrefetchInjector()(std::move(stmt));
} }
namespace transform {
Pass InjectPrefetch() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = PrefetchInjector()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch")
.set_body_typed(InjectPrefetch);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -19,22 +19,24 @@ ...@@ -19,22 +19,24 @@
/*! /*!
* \file storage_flatten.cc * \file storage_flatten.cc
* \brief Flattens storage from multi-dimensional array to 1D buffer access
*/ */
// Flattens storage from multi-dimensional array to 1D // The pass definition originates from Halide pipeline.
// buffer access as in Halide pipeline.
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/transform.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h> #include <tvm/target/target_info.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <unordered_map> #include <unordered_map>
#include "ir_util.h" #include "../pass/ir_util.h"
#include "arg_binder.h" #include "../pass/arg_binder.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
#include "../../arith/ir_visitor_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
...@@ -49,16 +51,17 @@ using intrinsic::tvm_address_of; ...@@ -49,16 +51,17 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public StmtExprMutator { class StorageFlattener : public StmtExprMutator {
public: public:
explicit StorageFlattener(Map<te::Tensor, Buffer> extern_buffer, explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map,
int cache_line_size, bool create_bound_attributes, int cache_line_size,
IRVisitorWithAnalyzer* bounded_analyzer) bool create_bound_attributes,
: bounded_analyzer_(bounded_analyzer), IRVisitorWithAnalyzer* bound_analyzer)
: bound_analyzer_(bound_analyzer),
create_bound_attributes_(create_bound_attributes) { create_bound_attributes_(create_bound_attributes) {
for (auto kv : extern_buffer) { for (auto kv : extern_buffer_map) {
BufferEntry e; BufferEntry e;
e.buffer = kv.second; e.buffer = kv.second;
e.external = true; e.external = true;
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; buf_map_[kv.second] = e;
} }
cache_line_size_ = cache_line_size; cache_line_size_ = cache_line_size;
} }
...@@ -82,17 +85,14 @@ class StorageFlattener : public StmtExprMutator { ...@@ -82,17 +85,14 @@ class StorageFlattener : public StmtExprMutator {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value; storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
return this->VisitStmt(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope && } else if (op->attr_key == attr::double_buffer_scope &&
op->node->IsInstance<te::OperationNode>()) { op->node->IsInstance<tir::BufferNode>()) {
auto func = Downcast<te::Operation>(op->node); auto buffer = Downcast<tir::Buffer>(op->node);
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
for (int i = 0; i < func->num_outputs(); ++i) { auto it = buf_map_.find(buffer);
TensorKey key{func, i};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end()) CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f; << "Cannot find allocated buffer for " << buffer;
body = AttrStmtNode::make( body = AttrStmtNode::make(
it->second.buffer->data, op->attr_key, op->value, body); it->second.buffer->data, op->attr_key, op->value, std::move(body));
}
return body; return body;
} else if (op->attr_key == attr::thread_extent) { } else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
...@@ -104,11 +104,10 @@ class StorageFlattener : public StmtExprMutator { ...@@ -104,11 +104,10 @@ class StorageFlattener : public StmtExprMutator {
} else if (op->attr_key == attr::buffer_bind_scope) { } else if (op->attr_key == attr::buffer_bind_scope) {
return HandleBufferBindScope(op); return HandleBufferBindScope(op);
} else if (op->attr_key == attr::buffer_dim_align) { } else if (op->attr_key == attr::buffer_dim_align) {
auto tensor = Downcast<te::Tensor>(op->node); auto buffer = Downcast<tir::Buffer>(op->node);
const CallNode* tuple = op->value.as<CallNode>(); const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index}; auto& vinfo = dim_align_[buffer];
auto& vinfo = dim_align_[key];
int dim = tuple->args[0].as<IntImmNode>()->value; int dim = tuple->args[0].as<IntImmNode>()->value;
if (static_cast<size_t>(dim) >= vinfo.size()) { if (static_cast<size_t>(dim) >= vinfo.size()) {
vinfo.resize(dim + 1); vinfo.resize(dim + 1);
...@@ -122,18 +121,21 @@ class StorageFlattener : public StmtExprMutator { ...@@ -122,18 +121,21 @@ class StorageFlattener : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
} }
Stmt VisitStmt_(const ProvideNode* op) final { Stmt VisitStmt_(const BufferStoreNode* op) final {
if (create_bound_attributes_) if (create_bound_attributes_) shape_collector_.clear();
shape_collector_.clear();
Stmt stmt = StmtExprMutator::VisitStmt_(op); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ProvideNode>(); op = stmt.as<BufferStoreNode>();
TensorKey key{op->func, op->value_index};
const auto& key = op->buffer;
auto it = buf_map_.find(key); auto it = buf_map_.find(key);
CHECK(it != buf_map_.end()) CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f; << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second; const BufferEntry& e = it->second;
CHECK(!e.released) CHECK(!e.released)
<< "Read a buffer that is already out of scope"; << "Read a buffer that is already out of scope";
if (is_opengl_) { if (is_opengl_) {
return EvaluateNode::make(CallNode::make( return EvaluateNode::make(CallNode::make(
DataType(), DataType(),
...@@ -141,7 +143,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -141,7 +143,7 @@ class StorageFlattener : public StmtExprMutator {
{e.buffer->data, op->value}, {e.buffer->data, op->value},
CallNode::Intrinsic)); CallNode::Intrinsic));
} else { } else {
Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value); Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back( shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape)); std::make_pair(e.buffer->data, e.buffer->shape));
...@@ -158,8 +160,9 @@ class StorageFlattener : public StmtExprMutator { ...@@ -158,8 +160,9 @@ class StorageFlattener : public StmtExprMutator {
} }
} }
Stmt VisitStmt_(const RealizeNode* op) final { Stmt VisitStmt_(const BufferRealizeNode* op) final {
TensorKey key{op->func, op->value_index}; const auto& key = op->buffer;
if (buf_map_.count(key)) { if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external); CHECK(buf_map_.at(key).external);
return this->VisitStmt(op->body); return this->VisitStmt(op->body);
...@@ -172,10 +175,9 @@ class StorageFlattener : public StmtExprMutator { ...@@ -172,10 +175,9 @@ class StorageFlattener : public StmtExprMutator {
shape.push_back(r->extent); shape.push_back(r->extent);
} }
// deduce current storage scope. // deduce current storage scope.
auto it = storage_scope_.find(op->func.get()); auto it = storage_scope_.find(op->buffer.get());
CHECK(it != storage_scope_.end()) CHECK(it != storage_scope_.end())
<< "Cannot find storage scope of " << op->func << "Cannot find storage scope of " << op->buffer;
<< " value_index=" << op->value_index;
StorageScope skey; StorageScope skey;
const std::string& strkey = it->second; const std::string& strkey = it->second;
if (strkey.length() == 0) { if (strkey.length() == 0) {
...@@ -188,13 +190,14 @@ class StorageFlattener : public StmtExprMutator { ...@@ -188,13 +190,14 @@ class StorageFlattener : public StmtExprMutator {
} }
// use small alignment for small arrays // use small alignment for small arrays
auto dtype = op->buffer->dtype;
int32_t const_size = AllocateNode::constant_allocation_size(shape); int32_t const_size = AllocateNode::constant_allocation_size(shape);
int align = GetTempAllocaAlignment(op->dtype, const_size); int align = GetTempAllocaAlignment(dtype, const_size);
if (skey.tag.length() != 0) { if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string()); MemoryInfo info = GetMemoryInfo(skey.to_string());
if (info.defined()) { if (info.defined()) {
align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits(); align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits();
CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits) CHECK_LE(const_size * dtype.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << skey.to_string(); << "Allocation exceed bound of memory tag " << skey.to_string();
} }
} }
...@@ -210,7 +213,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -210,7 +213,7 @@ class StorageFlattener : public StmtExprMutator {
PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = tir::Simplify(stride); stride = bound_analyzer_->Simplify(stride);
} }
rstrides.push_back(stride); rstrides.push_back(stride);
stride = stride * shape[dim]; stride = stride * shape[dim];
...@@ -219,9 +222,9 @@ class StorageFlattener : public StmtExprMutator { ...@@ -219,9 +222,9 @@ class StorageFlattener : public StmtExprMutator {
} }
e.buffer = BufferNode::make( e.buffer = BufferNode::make(
Var(key.GetName(), DataType::Handle()), Var(op->buffer->data->name_hint, DataType::Handle()),
op->dtype, shape, strides, PrimExpr(), op->buffer->dtype, shape, strides, PrimExpr(),
key.GetName(), skey.to_string(), op->buffer->name, skey.to_string(),
align, 0, kDefault); align, 0, kDefault);
buf_map_[key] = e; buf_map_[key] = e;
...@@ -285,14 +288,15 @@ class StorageFlattener : public StmtExprMutator { ...@@ -285,14 +288,15 @@ class StorageFlattener : public StmtExprMutator {
} }
} }
PrimExpr VisitExpr_(const CallNode* op) final { PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op); PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>(); op = expr.as<BufferLoadNode>();
if (op != nullptr && op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index}; const auto& key = op->buffer;
auto it = buf_map_.find(key); auto it = buf_map_.find(key);
CHECK(it != buf_map_.end()) CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f; << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second; const BufferEntry& e = it->second;
CHECK(!e.released) CHECK(!e.released)
<< "Read a buffer that is already out of scope"; << "Read a buffer that is already out of scope";
...@@ -301,20 +305,19 @@ class StorageFlattener : public StmtExprMutator { ...@@ -301,20 +305,19 @@ class StorageFlattener : public StmtExprMutator {
shape_collector_.push_back( shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape)); std::make_pair(e.buffer->data, e.buffer->shape));
} }
return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype); return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype);
} else {
return expr;
}
} }
Stmt VisitStmt_(const PrefetchNode *op) final { Stmt VisitStmt_(const PrefetchNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<PrefetchNode>(); op = stmt.as<PrefetchNode>();
CHECK(op != nullptr); CHECK(op != nullptr);
TensorKey key{op->func, op->value_index};
const auto& key = op->buffer;
auto it = buf_map_.find(key); auto it = buf_map_.find(key);
CHECK(it != buf_map_.end()) CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f; << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second; const BufferEntry& e = it->second;
CHECK(!e.released) CHECK(!e.released)
...@@ -340,7 +343,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -340,7 +343,7 @@ class StorageFlattener : public StmtExprMutator {
for (int i = op->bounds.size() - 1; i > starts; --i) { for (int i = op->bounds.size() - 1; i > starts; --i) {
args.push_back(op->bounds[i]->min); args.push_back(op->bounds[i]->min);
} }
auto &func_name = op->func->func_name(); auto &func_name = op->buffer->name;
vars.push_back(Var( vars.push_back(Var(
"prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
args.push_back(op->bounds[starts]->min + stride * vars.back()); args.push_back(op->bounds[starts]->min + stride * vars.back());
...@@ -358,7 +361,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -358,7 +361,7 @@ class StorageFlattener : public StmtExprMutator {
PrimExpr address = CallNode::make( PrimExpr address = CallNode::make(
DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
PrimExpr prefetch = CallNode::make( PrimExpr prefetch = CallNode::make(
op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
stmt = EvaluateNode::make(prefetch); stmt = EvaluateNode::make(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
...@@ -367,6 +370,26 @@ class StorageFlattener : public StmtExprMutator { ...@@ -367,6 +370,26 @@ class StorageFlattener : public StmtExprMutator {
return stmt; return stmt;
} }
PrimExpr VisitExpr_(const CallNode* op) final {
CHECK(op->call_type != CallNode::Halide)
<< "Cannot handle Halide calls "
<< " please run SchedulePostProcToPrimFunc first";
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const ProvideNode* op) final {
LOG(FATAL) << "Cannot handle Provide "
<< " please run SchedulePostProcToPrimFunc first";
return Stmt();
}
Stmt VisitStmt_(const RealizeNode* op) final {
LOG(FATAL) << "Cannot handle Realize "
<< " please run SchedulePostProcToPrimFunc first";
return Stmt();
}
private: private:
// The specific tensor data layout is not determined before // The specific tensor data layout is not determined before
// StorageFlatten pass. We use buffer_bind_scope // StorageFlatten pass. We use buffer_bind_scope
...@@ -406,14 +429,16 @@ class StorageFlattener : public StmtExprMutator { ...@@ -406,14 +429,16 @@ class StorageFlattener : public StmtExprMutator {
Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node); Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
CHECK_EQ(arr.size(), 2U); CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>(); const BufferNode* buffer = arr[0].as<BufferNode>();
const te::TensorNode* tensor = arr[1].as<te::TensorNode>(); const BufferNode* target = arr[1].as<BufferNode>();
const CallNode* tuple = op->value.as<CallNode>(); const CallNode* tuple = op->value.as<CallNode>();
CHECK(buffer && tensor); CHECK(buffer && target);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index}; auto key = GetRef<Buffer>(target);
CHECK(buf_map_.count(key))
<< "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; auto it = buf_map_.find(key);
const BufferEntry& be = buf_map_.at(key); CHECK(it != buf_map_.end())
<< "Cannot find buffer of " << key;
const BufferEntry& be = it->second;
CHECK(!be.released); CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
Array<PrimExpr> begins, extents; Array<PrimExpr> begins, extents;
...@@ -426,7 +451,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -426,7 +451,7 @@ class StorageFlattener : public StmtExprMutator {
} else { } else {
for (size_t i = 0; i < tuple->args.size(); i += 2) { for (size_t i = 0; i < tuple->args.size(); i += 2) {
begins.push_back(tuple->args[i]); begins.push_back(tuple->args[i]);
auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]); auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]);
extents.push_back(new_extent); extents.push_back(new_extent);
} }
} }
...@@ -451,6 +476,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -451,6 +476,7 @@ class StorageFlattener : public StmtExprMutator {
} }
return body; return body;
} }
// The buffer entry in the flatten map // The buffer entry in the flatten map
struct DimAlignInfo { struct DimAlignInfo {
int align_factor{0}; int align_factor{0};
...@@ -509,9 +535,10 @@ class StorageFlattener : public StmtExprMutator { ...@@ -509,9 +535,10 @@ class StorageFlattener : public StmtExprMutator {
// Variable remap // Variable remap
std::unordered_map<const VarNode*, PrimExpr> var_remap_; std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Buffer map // Buffer map
std::unordered_map<TensorKey, BufferEntry> buf_map_; std::unordered_map<Buffer, BufferEntry, ObjectHash, ObjectEqual> buf_map_;
// Dimension alignment // Dimension alignment
std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_; std::unordered_map<Buffer, std::vector<DimAlignInfo>,
ObjectHash, ObjectEqual> dim_align_;
// Storage scope // Storage scope
std::unordered_map<const Object*, std::string> storage_scope_; std::unordered_map<const Object*, std::string> storage_scope_;
// The current thread scope. // The current thread scope.
...@@ -520,7 +547,7 @@ class StorageFlattener : public StmtExprMutator { ...@@ -520,7 +547,7 @@ class StorageFlattener : public StmtExprMutator {
std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_; std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
// bounds populator. We really need the analyzer from it. // bounds populator. We really need the analyzer from it.
// However // However
IRVisitorWithAnalyzer* bounded_analyzer_; IRVisitorWithAnalyzer* bound_analyzer_;
// The size of cacheline // The size of cacheline
int cache_line_size_; int cache_line_size_;
// The current stage is an OpenGL shader. // The current stage is an OpenGL shader.
...@@ -529,15 +556,37 @@ class StorageFlattener : public StmtExprMutator { ...@@ -529,15 +556,37 @@ class StorageFlattener : public StmtExprMutator {
bool create_bound_attributes_{false}; bool create_bound_attributes_{false};
}; };
Stmt StorageFlatten(Stmt stmt, Map<te::Tensor, Buffer> extern_buffer, PrimFunc StorageFlatten(PrimFunc func,
int cache_line_size, bool create_bound_attributes) { int cache_line_size,
IRVisitorWithAnalyzer bounded_analyzer; bool create_bound_attributes) {
bounded_analyzer(stmt); auto fptr = func.CopyOnWrite();
stmt =
StorageFlattener(extern_buffer, cache_line_size, IRVisitorWithAnalyzer bound_analyzer;
create_bound_attributes, &bounded_analyzer)(std::move(stmt)); bound_analyzer(fptr->body);
return stmt; fptr->body = StorageFlattener(fptr->buffer_map,
cache_line_size,
create_bound_attributes,
&bound_analyzer)(std::move(fptr->body));
return func;
}
namespace transform {
// TODO(tvm-team): consolidate configs to the PassContext
Pass StorageFlatten(int cache_line_size,
bool create_bound_attributes) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return StorageFlatten(
std::move(f), cache_line_size, create_bound_attributes);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten")
.set_body_typed(StorageFlatten);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
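Taken together, this hunk replaces the old Stmt-level `StorageFlatten` entry point with a PrimFunc pass registered under `tir.transform.StorageFlatten`. A minimal end-to-end sketch of the new flow, assuming a toy elementwise compute (the placeholder names are illustrative; the same call sequence appears verbatim in the updated tests below):

```python
import tvm
from tvm import te

# Toy schedule (illustrative only).
A = te.placeholder((10,), name="A")
B = te.compute((10,), lambda i: A[i] + 1, name="B")
s = te.create_schedule(B.op)

# TE-side lowering steps.
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)

# Canonicalize the TE-produced Stmt into a PrimFunc, then run the
# pass-manager version of StorageFlatten on an IRModule.
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
flat_body = mod["main"].body  # flattened TIR body of the "main" PrimFunc
```

Note that `SchedulePostProcToPrimFunc` must run before this pass; the new `CallNode`/`ProvideNode`/`RealizeNode` visitors above fail loudly if Halide-style nodes are still present.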
...@@ -22,21 +22,25 @@ def test_domain_touched(): ...@@ -22,21 +22,25 @@ def test_domain_touched():
j = te.var('j') j = te.var('j')
n = tvm.runtime.convert(100) n = tvm.runtime.convert(100)
m = te.var('m') m = te.var('m')
a = te.placeholder((n, m), name = 'a')
b = te.placeholder((n, m), name = 'b') a = tvm.tir.decl_buffer((n, m), name='a')
b = tvm.tir.decl_buffer((n, m), name='b')
ir = tvm.tir.For( ir = tvm.tir.For(
i, 0, n, 0, 0, i, 0, n, 0, 0,
tvm.tir.For(j, 0, m, 0, 0, tvm.tir.For(j, 0, m, 0, 0,
tvm.tir.Provide( tvm.tir.BufferStore(
a.op, a,
0, tvm.tir.BufferLoad(b, [i - 1, j + 1]) +
tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
[i, j] [i, j]
) )
) )
) )
a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
assert a_domain_r[0].min.value == -1 assert a_domain_r[0].min.value == -1
assert a_domain_r[0].extent.value == 100 assert a_domain_r[0].extent.value == 100
assert a_domain_r[1].min.value == -1 assert a_domain_r[1].min.value == -1
......
...@@ -48,9 +48,9 @@ def test_split_uneven_unique_likely(): ...@@ -48,9 +48,9 @@ def test_split_uneven_unique_likely():
x, y = c.op.axis x, y = c.op.axis
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5) xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c], simple_mode=True) stmt = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
assert str(stmt.body.body.body).count("likely") == 1
if __name__ == "__main__": if __name__ == "__main__":
test_lower_rfactor() test_lower_rfactor()
......
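The test changes above reflect that `tvm.lower` is no longer called with `simple_mode=True` here and instead returns an `IRModule`; a small sketch of the new access pattern, with `sch`, `a`, `b`, `c` standing for any schedule and its arguments:

```python
# Before: stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
# After: index the "main" PrimFunc of the returned IRModule.
mod = tvm.lower(sch, [a, b, c])
stmt = mod["main"].body
```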
...@@ -365,7 +365,7 @@ def test_bind(): ...@@ -365,7 +365,7 @@ def test_bind():
a = te.placeholder((8, 4), 'float32') a = te.placeholder((8, 4), 'float32')
c = foo(a) c = foo(a)
s = te.create_schedule(c.op) s = te.create_schedule(c.op)
ir = tvm.lower(s, [a, c], simple_mode=True) ir = tvm.lower(s, [a, c])
func, ins, outs = run_and_check(foo, [a], target='cuda') func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda') run_and_check(func, ins, outs=outs, target='cuda')
...@@ -517,7 +517,7 @@ def test_upstream(): ...@@ -517,7 +517,7 @@ def test_upstream():
c = te.compute((20, ), lambda x: a[x] + b[x]) c = te.compute((20, ), lambda x: a[x] + b[x])
d = upstream(c) d = upstream(c)
sch = te.create_schedule([c.op, d.op]) sch = te.create_schedule([c.op, d.op])
ir = tvm.lower(sch, [a, b, d], simple_mode=True) ir = tvm.lower(sch, [a, b, d])
func = tvm.build(sch, [a, b, d]) func = tvm.build(sch, [a, b, d])
assert(func) assert(func)
...@@ -730,7 +730,7 @@ def test_schedule(): ...@@ -730,7 +730,7 @@ def test_schedule():
joo, joi = sch[c].split(jo, 4) joo, joi = sch[c].split(jo, 4)
sch[c].vectorize(ji) sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji) sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt) assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body ir = ir.body
assert isinstance(ir, tvm.tir.For) assert isinstance(ir, tvm.tir.For)
...@@ -751,7 +751,7 @@ def test_schedule(): ...@@ -751,7 +751,7 @@ def test_schedule():
# Test fuse # Test fuse
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1]) sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt) assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body ir = ir.body
assert isinstance(ir, tvm.tir.For) assert isinstance(ir, tvm.tir.For)
......
...@@ -283,7 +283,7 @@ def test_tensor_intrin_scalar_params(): ...@@ -283,7 +283,7 @@ def test_tensor_intrin_scalar_params():
# Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = te.create_schedule(C.op) s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True) stmt = tvm.lower(s, [A, C])["main"].body
assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.value.args) == 5 assert len(stmt.body.body.value.args) == 5
assert str(stmt.body.body.value.args[3]) == "(i*i)" assert str(stmt.body.body.value.args[3]) == "(i*i)"
......
...@@ -28,6 +28,9 @@ def test_schedule0(): ...@@ -28,6 +28,9 @@ def test_schedule0():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A1], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule1(): def test_schedule1():
...@@ -43,6 +46,10 @@ def test_schedule1(): ...@@ -43,6 +46,10 @@ def test_schedule1():
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A1], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule2(): def test_schedule2():
m = te.var('m') m = te.var('m')
...@@ -57,6 +64,9 @@ def test_schedule2(): ...@@ -57,6 +64,9 @@ def test_schedule2():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[A, A2], stmt, None)
assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule_scan(): def test_schedule_scan():
...@@ -77,6 +87,7 @@ def test_schedule_scan(): ...@@ -77,6 +87,7 @@ def test_schedule_scan():
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
def test_inline_multi_reduce(): def test_inline_multi_reduce():
def argmax_comp(x, y): def argmax_comp(x, y):
idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
...@@ -510,19 +521,19 @@ def test_local_stage_predicate(): ...@@ -510,19 +521,19 @@ def test_local_stage_predicate():
return ret return ret
# local vs. threadIdx # local vs. threadIdx
s = schedule(tx, "local") s = schedule(tx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any( assert (not any(
collect_visit(lowered_body, collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse)))) lambda x: isinstance(x, tvm.tir.IfThenElse))))
# local vs. vthread # local vs. vthread
s = schedule(vx, "local") s = schedule(vx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any( assert (not any(
collect_visit(lowered_body, collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse)))) lambda x: isinstance(x, tvm.tir.IfThenElse))))
# shared vs. blockIdx # shared vs. blockIdx
s = schedule(by, "shared") s = schedule(by, "shared")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any( assert (not any(
collect_visit(lowered_body, collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse)))) lambda x: isinstance(x, tvm.tir.IfThenElse))))
...@@ -548,7 +559,7 @@ def test_local_stage_predicate2(): ...@@ -548,7 +559,7 @@ def test_local_stage_predicate2():
s[AA].compute_at(s[C], ooc) s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32) oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(iaa, thread_x) s[AA].bind(iaa, thread_x)
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body lowered_body = tvm.lower(s, [A, C])["main"].body
def collect_visit(stmt, f): def collect_visit(stmt, f):
ret = [] ret = []
......
...@@ -128,7 +128,7 @@ def test_tensor_compute1(): ...@@ -128,7 +128,7 @@ def test_tensor_compute1():
lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
s = te.create_schedule(C.op) s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True) stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body, tvm.tir.Evaluate) assert isinstance(stmt.body, tvm.tir.Evaluate)
def test_tensor_compute2(): def test_tensor_compute2():
...@@ -171,7 +171,7 @@ def test_tensor_compute2(): ...@@ -171,7 +171,7 @@ def test_tensor_compute2():
lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
s = te.create_schedule(C.op) s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True) stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body.body[0], tvm.tir.Evaluate) assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate) assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
......
...@@ -24,29 +24,6 @@ gpu_devices = ["cuda", "opencl", "metal", "vulkan"] ...@@ -24,29 +24,6 @@ gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
other_devices = ["llvm", "ext_dev"] other_devices = ["llvm", "ext_dev"]
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, te.tensor.Tensor):
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String("test"))
mod = tvm.IRModule({"test": f})
return mod
# All computations are bound. # All computations are bound.
# So VerifyMemory pass is expected to succeed. # So VerifyMemory pass is expected to succeed.
# #
...@@ -61,7 +38,7 @@ def test_verify_memory_all_bind(): ...@@ -61,7 +38,7 @@ def test_verify_memory_all_bind():
s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x"))
mod = lower(s, [A, B]) mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices + other_devices: for dev_type in gpu_devices + other_devices:
binded_mod = tvm.tir.transform.Apply( binded_mod = tvm.tir.transform.Apply(
...@@ -81,7 +58,7 @@ def test_verify_memory_not_bind(): ...@@ -81,7 +58,7 @@ def test_verify_memory_not_bind():
# B is not bound to threads. # B is not bound to threads.
s = te.create_schedule(B.op) s = te.create_schedule(B.op)
mod = lower(s, [A, B]) mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices: for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply( binded_mod = tvm.tir.transform.Apply(
...@@ -111,7 +88,7 @@ def test_verify_memory_partially_bind(): ...@@ -111,7 +88,7 @@ def test_verify_memory_partially_bind():
s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = lower(s, [A, B, C, D]) mod = tvm.lower(s, [A, B, C, D])
for dev_type in gpu_devices: for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply( binded_mod = tvm.tir.transform.Apply(
......
...@@ -194,9 +194,9 @@ def test_stmt_constructor(): ...@@ -194,9 +194,9 @@ def test_stmt_constructor():
assert x.then_case.value.value == 11 assert x.then_case.value.value == 11
assert x.else_case == nop assert x.else_case == nop
x = tvm.tir.Prefetch(None, 1, "float32", []) b = tvm.tir.decl_buffer((1, 2))
x = tvm.tir.Prefetch(b, [])
assert isinstance(x, tvm.tir.Prefetch) assert isinstance(x, tvm.tir.Prefetch)
assert x.value_index == 1
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -28,7 +28,6 @@ def test_for(): ...@@ -28,7 +28,6 @@ def test_for():
A[j] = A[j] + 2 A[j] = A[j] + 2
body = ib.get() body = ib.get()
print(body)
assert isinstance(body, tvm.tir.AttrStmt) assert isinstance(body, tvm.tir.AttrStmt)
body = body.body body = body.body
assert isinstance(body, tvm.tir.Allocate) assert isinstance(body, tvm.tir.Allocate)
...@@ -59,14 +58,13 @@ def test_if(): ...@@ -59,14 +58,13 @@ def test_if():
assert body.else_case.index.value == 0 assert body.else_case.index.value == 0
def test_prefetch(): def test_prefetch():
A = te.placeholder((10, 20), name="A") A = tvm.tir.decl_buffer((10, 20), name="A")
ib = tvm.tir.ir_builder.create() ib = tvm.tir.ir_builder.create()
n = te.size_var("n") n = te.size_var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
ib.emit( ib.emit(
tvm.tir.Prefetch( tvm.tir.Prefetch(A,
A.op, A.value_index, A.dtype,
[tvm.ir.Range.make_by_min_extent(i+1, 2), [tvm.ir.Range.make_by_min_extent(i+1, 2),
tvm.ir.Range.make_by_min_extent(0, 20)])) tvm.ir.Range.make_by_min_extent(0, 20)]))
body = ib.get() body = ib.get()
......
...@@ -301,6 +301,10 @@ def test_buffer_load_store(): ...@@ -301,6 +301,10 @@ def test_buffer_load_store():
s = tvm.tir.BufferStore(b, 0.1, [0]) s = tvm.tir.BufferStore(b, 0.1, [0])
assert isinstance(s, tvm.tir.BufferStore) assert isinstance(s, tvm.tir.BufferStore)
s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)],
True, tvm.tir.Evaluate(0))
assert isinstance(s, tvm.tir.BufferRealize)
def test_intimm_cond(): def test_intimm_cond():
x = tvm.runtime.convert(1) x = tvm.runtime.convert(1)
......
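For reference, the buffer-centric node constructors exercised by the constructor tests above, gathered in one place; a small sketch that only reuses the signatures shown in these hunks (`decl_buffer`, `BufferLoad`, `BufferStore`, `BufferRealize`, `Prefetch`):

```python
b = tvm.tir.decl_buffer((1, 2), "float32", name="b")

load = tvm.tir.BufferLoad(b, [0, 1])            # read  b[0, 1]
store = tvm.tir.BufferStore(b, 0.1, [0, 1])     # write b[0, 1]
realize = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1), tvm.ir.Range(0, 2)],
                                True, tvm.tir.Evaluate(0))
prefetch = tvm.tir.Prefetch(b, [])              # empty bounds list, as in test_stmt_constructor
```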
...@@ -26,9 +26,10 @@ def test_copy2d(): ...@@ -26,9 +26,10 @@ def test_copy2d():
s[B].pragma(B.op.axis[0], "memcpy") s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') mod = tvm.IRModule.from_expr(func)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert dst.strides[0] == l assert dst.strides[0] == l
assert dst.strides[1].value == 1 assert dst.strides[1].value == 1
...@@ -36,7 +37,6 @@ def test_copy2d(): ...@@ -36,7 +37,6 @@ def test_copy2d():
assert tuple(src.shape) == (m, l) assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
...@@ -51,9 +51,11 @@ def test_copy_pad(): ...@@ -51,9 +51,11 @@ def test_copy_pad():
s[B].pragma(B.op.axis[0], "memcpy") s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert pad_before[0].value == 1 assert pad_before[0].value == 1
...@@ -63,7 +65,6 @@ def test_copy_pad(): ...@@ -63,7 +65,6 @@ def test_copy_pad():
assert pad_value.value == 1.0 assert pad_value.value == 1.0
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
...@@ -75,9 +76,11 @@ def test_single_point_test(): ...@@ -75,9 +76,11 @@ def test_single_point_test():
s[B].pragma(B.op.axis[0], "memcpy") s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0
...@@ -85,7 +88,6 @@ def test_single_point_test(): ...@@ -85,7 +88,6 @@ def test_single_point_test():
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
...@@ -105,11 +107,12 @@ def test_copy_pad_split(): ...@@ -105,11 +107,12 @@ def test_copy_pad_split():
s[Apad].pragma(s[Apad].op.axis[0], "memcpy") s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) mod = tvm.IRModule.from_expr(func)
stmt = tvm.tir.ir_pass.Simplify(stmt) mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) mod = tvm.tir.transform.Simplify()(mod._move())
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0) assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
...@@ -121,12 +124,10 @@ def test_copy_pad_split(): ...@@ -121,12 +124,10 @@ def test_copy_pad_split():
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
if __name__ == "__main__": if __name__ == "__main__":
test_copy2d() test_copy2d()
test_copy_pad() test_copy_pad()
......
...@@ -28,18 +28,16 @@ def test_makeapi(): ...@@ -28,18 +28,16 @@ def test_makeapi():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') mod = tvm.IRModule.from_expr(func)
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') mod = tvm.tir.transform.StorageFlatten(64)(mod)
Cb = tvm.tir.decl_buffer(C.shape, C.dtype, name='C') mod = tvm.tir.transform.Apply(
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) lambda f: f.with_attr({
"target": tvm.target.create("llvm"),
"global_symbol": "main",
}))(mod)
num_unpacked_args = 2 num_unpacked_args = 2
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
"global_symbol": "main",
"target": tvm.target.create("llvm")
}))
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7) assert(len(f.params) == 7)
......
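The same hunk shows the `tvm.tir.transform.Apply` helper used throughout this change to attach per-function attributes before later transforms; a small sketch, assuming `mod` already holds a single flattened PrimFunc and using the attribute values from the test:

```python
mod = tvm.tir.transform.Apply(
    lambda f: f.with_attr({
        "target": tvm.target.create("llvm"),
        "global_symbol": "main",
    }))(mod)

# With target and symbol attached, MakePackedAPI can wrap the function.
f = tvm.tir.transform.MakePackedAPI(2)(mod)["main"]
```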
...@@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits): ...@@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits):
raise ValueError("args must be Tensor, Buffer or Var") raise ValueError("args must be Tensor, Buffer or Var")
bounds = te.schedule.InferBound(sch) bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds) stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
return lower_stmt(arg_list, stmt, target_bits) func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
def test_basic(): def test_basic():
......
...@@ -30,11 +30,14 @@ def test_flatten2(): ...@@ -30,11 +30,14 @@ def test_flatten2():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt) func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[Ab, A2b], stmt, {A: Ab, A2: A2b})
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def test_flatten_prefetch(): def test_flatten_prefetch():
A = te.placeholder((25, 100, 4), name = 'A') A = te.placeholder((25, 100, 4), name = 'A')
...@@ -42,8 +45,14 @@ def test_flatten_prefetch(): ...@@ -42,8 +45,14 @@ def test_flatten_prefetch():
i = te.size_var('i') i = te.size_var('i')
j = te.size_var('j') j = te.size_var('j')
region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region) stmt = tvm.tir.Prefetch(_A, region)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(
[_A], stmt, {A: _A})
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert stmt.extent.value == 2 assert stmt.extent.value == 2
assert isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body, tvm.tir.For)
...@@ -62,12 +71,15 @@ def test_flatten_storage_align(): ...@@ -62,12 +71,15 @@ def test_flatten_storage_align():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert(stmt.body.extents[0].value == 17 * 8) assert(stmt.body.extents[0].value == 17 * 8)
def test_flatten_double_buffer(): def test_flatten_double_buffer():
dtype = 'int64' dtype = 'int64'
n = 100 n = 100
...@@ -87,7 +99,13 @@ def test_flatten_double_buffer(): ...@@ -87,7 +99,13 @@ def test_flatten_double_buffer():
C[j] = B[j] + 1 C[j] = B[j] + 1
stmt = ib.get() stmt = ib.get()
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C], stmt))
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate) assert isinstance(stmt.body.body, tvm.tir.Allocate)
...@@ -105,7 +123,7 @@ def test_flatten_double_buffer(): ...@@ -105,7 +123,7 @@ def test_flatten_double_buffer():
assert count[0] == 4 assert count[0] == 4
if __name__ == "__main__": if __name__ == "__main__":
test_flatten_storage_align()
test_flatten2() test_flatten2()
test_flatten_prefetch() test_flatten_storage_align()
test_flatten_double_buffer() test_flatten_double_buffer()
test_flatten_prefetch()
...@@ -30,11 +30,11 @@ def test_storage_share(): ...@@ -30,11 +30,11 @@ def test_storage_share():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -166,11 +166,11 @@ def test_inplace_rule(): ...@@ -166,11 +166,11 @@ def test_inplace_rule():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -201,11 +201,10 @@ def test_storage_combine(): ...@@ -201,11 +201,10 @@ def test_storage_combine():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') mod = tvm.IRModule.from_expr(func)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -238,11 +237,9 @@ def test_storage_share_gpu(): ...@@ -238,11 +237,9 @@ def test_storage_share_gpu():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None)
Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B') mod = tvm.IRModule.from_expr(func)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64) mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): ...@@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt)) func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -398,17 +393,11 @@ def test_inplace_rule3(): ...@@ -398,17 +393,11 @@ def test_inplace_rule3():
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0') func = tvm.te.schedule.SchedulePostProcToPrimFunc(
B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1') [B0, B1, B2, B3, B4, B5, B], stmt, None)
B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2') mod = tvm.IRModule.from_expr(func)
B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3') mod = tvm.tir.transform.StorageFlatten(64)(mod)
B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4')
B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body stmt = mod["main"].body
...@@ -547,7 +536,7 @@ def test_large_input(): ...@@ -547,7 +536,7 @@ def test_large_input():
c = te.compute(shape, lambda i, j: compute(a, b)[i, j]) c = te.compute(shape, lambda i, j: compute(a, b)[i, j])
c = te.compute(shape, lambda i, j: 1 + c[i, j]) c = te.compute(shape, lambda i, j: 1 + c[i, j])
s = te.create_schedule(c.op) s = te.create_schedule(c.op)
stmt = tvm.lower(s, [a, b, c], simple_mode=True) stmt = tvm.lower(s, [a, b, c])["main"].body
def verify(n): def verify(n):
if isinstance(n, tvm.tir.Allocate): if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 268435456 assert n.extents[0].value == 268435456
......
...@@ -34,15 +34,15 @@ def test_thread_storage_sync(): ...@@ -34,15 +34,15 @@ def test_thread_storage_sync():
bounds = tvm.te.schedule.InferBound(s) bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
cuda_target = tvm.target.create("cuda") cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr( mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({ "global_symbol": "test", "target": cuda_target}))(mod._move())
"global_symbol": "test", "target": cuda_target}))
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice) mod = tvm.IRModule.from_expr(fdevice)
......
...@@ -40,8 +40,6 @@ Before reading this tutorial, we assume readers have already known these topics ...@@ -40,8 +40,6 @@ Before reading this tutorial, we assume readers have already known these topics
take a look at ``python/tvm/build_module.py`` to get some basics. take a look at ``python/tvm/build_module.py`` to get some basics.
""" """
from __future__ import absolute_import, print_function
import tvm import tvm
from tvm import te from tvm import te
import numpy as np import numpy as np
...@@ -57,7 +55,7 @@ b = te.placeholder((n, ), name="b") ...@@ -57,7 +55,7 @@ b = te.placeholder((n, ), name="b")
c = te.compute((n, ), lambda i: a[i] + b[i], name='c') c = te.compute((n, ), lambda i: a[i] + b[i], name='c')
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c])
print(ir) print(ir)
###################################################################### ######################################################################
...@@ -137,12 +135,8 @@ def vectorize(stmt): ...@@ -137,12 +135,8 @@ def vectorize(stmt):
# Glue to Lowering # Glue to Lowering
# ---------------- # ----------------
# So far, we are done with writing this IR transformation pass. What we need to do next is to glue # So far, we are done with writing this IR transformation pass. What we need to do next is to glue
# this pass to TVM's lower pass. We can first call this function directly as a sanity check. # this pass to TVM's lower pass.
# #
print(vectorize(ir))
#####################################################################
# In TVM, there is a property called ``BuildConfig``. You can use this property to customize your # In TVM, there is a property called ``BuildConfig``. You can use this property to customize your
# own lowering options. In this case, we inject the pass written above into the TVM standard lowering # own lowering options. In this case, we inject the pass written above into the TVM standard lowering
# pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different # pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different
...@@ -160,7 +154,7 @@ print(vectorize(ir)) ...@@ -160,7 +154,7 @@ print(vectorize(ir))
# #
with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg: with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
print(tvm.lower(sch, [a, b, c], simple_mode=True)) print(tvm.lower(sch, [a, b, c]))
##################################################################### #####################################################################
# Quick View # Quick View
......