Commit 1400edac by Tianqi Chen Committed by GitHub

[IR] Include PrefetchIR (#189)

parent eaf0fde3
Subproject commit efe5b5cc3c89da5d5e39570f6776d39d8acacacc Subproject commit 41fe60a76fe6e5669540acf1ef3595bc38025157
...@@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type"; ...@@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope"; constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */ /*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope"; constexpr const char* reduce_scope = "reduce_scope";
/*!
* \brief Mark of prefetch scope; the attribute value is the prefetch offset.
* Runs a prefetch of the Tensor in the current loop scope.
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*! \brief Mark of scan update scope */ /*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope"; constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */ /*! \brief Mark of scan init scope */
...@@ -371,6 +376,7 @@ using Halide::Internal::Provide; ...@@ -371,6 +376,7 @@ using Halide::Internal::Provide;
using Halide::Internal::Allocate; using Halide::Internal::Allocate;
using Halide::Internal::Free; using Halide::Internal::Free;
using Halide::Internal::Realize; using Halide::Internal::Realize;
using Halide::Internal::Prefetch;
using Halide::Internal::Block; using Halide::Internal::Block;
using Halide::Internal::IfThenElse; using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate; using Halide::Internal::Evaluate;
......
...@@ -17,6 +17,9 @@ namespace ir { ...@@ -17,6 +17,9 @@ namespace ir {
* You can use this as a more powerful Visitor, since it allows you to * You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function. * define function signatures of Visit Function.
* *
* This helps you avoid book-keeping the return value of the Visitor via state,
* which can easily cause bugs when the state is incorrectly maintained.
*
* \code * \code
* // A functor that set variable to b. and calculate results. * // A functor that set variable to b. and calculate results.
* class MyExprFunctor * class MyExprFunctor
...@@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) { virtual R VisitStmtDefault_(const Node* op, Args ...) {
...@@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer); IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide); IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize); IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block); IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate); IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable; return vtable;
......
...@@ -66,6 +66,7 @@ class IRMutator { ...@@ -66,6 +66,7 @@ class IRMutator {
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s); virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s); virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s); virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
......
...@@ -116,6 +116,7 @@ class IRVisitor { ...@@ -116,6 +116,7 @@ class IRVisitor {
virtual void Visit_(const ProducerConsumer* op); virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op); virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op); virtual void Visit_(const Realize* op);
virtual void Visit_(const Prefetch* op);
virtual void Visit_(const Block* op); virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op); virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op); virtual void Visit_(const IntImm* op);
......
...@@ -461,10 +461,16 @@ class IterVarAttrNode : public Node { ...@@ -461,10 +461,16 @@ class IterVarAttrNode : public Node {
IterVarType iter_type{kDataPar}; IterVarType iter_type{kDataPar};
/*! \brief The thread this iter Var binds, can be null */ /*! \brief The thread this iter Var binds, can be null */
IterVar bind_thread; IterVar bind_thread;
/*! \brief List of tensors to be prefetched in this loop */
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
Array<Expr> prefetch_offset;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type); v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread); v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
} }
static constexpr const char* _type_key = "IterVarAttr"; static constexpr const char* _type_key = "IterVarAttr";
......
...@@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src. ...@@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src.
- pass The optimization pass on the IR structure - pass The optimization pass on the IR structure
- codegen The code generator. - codegen The code generator.
- runtime Minimum runtime related codes - runtime Minimum runtime related codes
- contrib Contrib extension libraries
...@@ -212,7 +212,7 @@ void BoundDeducer::Deduce() { ...@@ -212,7 +212,7 @@ void BoundDeducer::Deduce() {
success = false; success = false;
return; return;
} }
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_); Visit(expr_);
......
...@@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage, ...@@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage,
// Mark the iter var in the IR, to remember the point // Mark the iter var in the IR, to remember the point
if (bind_iv->thread_tag.length() == 0) { if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial; ForType for_type = ForType::Serial;
IterVarAttr it_attr;
if (stage->iter_var_attrs.count(iv)) { if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) { it_attr = stage->iter_var_attrs[iv];
}
if (it_attr.defined()) {
switch (it_attr->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break; case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break; case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break; case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break; case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type" default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type << it_attr->iter_type
<< " in the iter_var_attrs"; << " in the iter_var_attrs";
} }
} }
...@@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage, ...@@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op)); LetStmt::make(var, new_value, no_op));
} }
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
  // Reject prefetch requests on a loop whose extent is 1.
  CHECK(!is_one(dom->extent))
      << "Cannot prefetch on trivial loop with extent=1";
  // Every tensor to prefetch must be paired with exactly one offset.
  CHECK_EQ(it_attr->prefetch_data.size(),
           it_attr->prefetch_offset.size());
  // Iterate with a dedicated index `j`. The surrounding code addresses the
  // current nest level as nest[i + 1]; reusing `i` for this loop would
  // shadow that index and attach the prefetch attributes to the wrong
  // (possibly out-of-range) nest level.
  for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
    // Emit one prefetch_scope attribute per tensor: node = tensor,
    // value = its prefetch offset, body = no_op placeholder.
    nest[i + 1].emplace_back(
        AttrStmt::make(it_attr->prefetch_data[j],
                       ir::attr::prefetch_scope,
                       it_attr->prefetch_offset[j], no_op));
  }
}
} else if (bind_iv->thread_tag == "vthread") { } else if (bind_iv->thread_tag == "vthread") {
// virtual thread // virtual thread
// Always restrict threaded IterVar to starts from 0. // Always restrict threaded IterVar to starts from 0.
......
...@@ -13,7 +13,7 @@ namespace tvm { ...@@ -13,7 +13,7 @@ namespace tvm {
namespace ir { namespace ir {
// If expression is touched by var. // If expression is touched by var.
class ExprTouched : public IRVisitor { class ExprTouched final : public IRVisitor {
public: public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched) explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
: touched_var_(touched) {} : touched_var_(touched) {}
......
...@@ -12,7 +12,7 @@ namespace ir { ...@@ -12,7 +12,7 @@ namespace ir {
// inliner to inline a function // inliner to inline a function
// the result may not be SSA, // the result may not be SSA,
// ConvertSSA need to be applied after this pass // ConvertSSA need to be applied after this pass
class IRInline : public IRMutator { class IRInline final : public IRMutator {
public: public:
IRInline(FunctionRef f, Array<Var> args, Expr body) IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {} : f_(f), args_(args), body_(body) {}
......
...@@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { ...@@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
} }
} }
/*!
 * Default mutation of a Prefetch statement.
 *
 * Mutates the min and extent expression of every bound of \p op.
 * If none of them changed, the original statement \p s is returned as is;
 * otherwise a new Prefetch is built with the same func, value_index and
 * type, and the rebuilt bounds.
 */
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
  Halide::Internal::Region mutated_bounds;
  bool any_changed = false;
  for (size_t dim = 0; dim < op->bounds.size(); ++dim) {
    // Keep the order: min is mutated before extent.
    Expr min_before = op->bounds[dim]->min;
    Expr extent_before = op->bounds[dim]->extent;
    Expr min_after = this->Mutate(min_before);
    Expr extent_after = this->Mutate(extent_before);
    if (!(min_after.same_as(min_before) &&
          extent_after.same_as(extent_before))) {
      any_changed = true;
    }
    mutated_bounds.push_back(
        Range::make_by_min_extent(min_after, extent_after));
  }
  // Nothing was rewritten: hand back the original statement.
  if (!any_changed) return s;
  return Prefetch::make(op->func, op->value_index,
                        op->type, mutated_bounds);
}
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt first = this->Mutate(op->first); Stmt first = this->Mutate(op->first);
Stmt rest = this->Mutate(op->rest); Stmt rest = this->Mutate(op->rest);
......
...@@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) { ...@@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) {
} }
void IRVisitor::Visit_(const Realize *op) { void IRVisitor::Visit_(const Realize *op) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent); this->Visit(op->bounds[i]->extent);
...@@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) { ...@@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) {
this->Visit(op->condition); this->Visit(op->condition);
} }
void IRVisitor::Visit_(const Prefetch *op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
}
void IRVisitor::Visit_(const Block *op) { void IRVisitor::Visit_(const Block *op) {
this->Visit(op->first); this->Visit(op->first);
this->Visit(op->rest); this->Visit(op->rest);
......
...@@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) { ...@@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
// Rule: // Rule:
// - the range should not be const // - the range should not be const
// - there exist a condition expression in the scope that use the var // - there exist a condition expression in the scope that use the var
class CandidateSelector : public IRVisitor { class CandidateSelector final : public IRVisitor {
public: public:
using VarIsUsed = bool; using VarIsUsed = bool;
CandidateSelector() {} CandidateSelector() {}
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class ThreadAllreduceBuilder : public IRMutator { class ThreadAllreduceBuilder final : public IRMutator {
public: public:
explicit ThreadAllreduceBuilder(int warp_size) explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {} : warp_size_(warp_size) {}
......
...@@ -31,7 +31,7 @@ using namespace storage; ...@@ -31,7 +31,7 @@ using namespace storage;
// The storage need to be kept alive between allocate and last access. // The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate. // The free point is only inserted at the same scope of allocate.
// //
class StorageAccessPatternFinder : public IRVisitor { class StorageAccessPatternFinder final : public IRVisitor {
public: public:
// Get linear access pattern. // Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) { std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment