Commit 1400edac by Tianqi Chen, committed by GitHub

[IR] Include PrefetchIR (#189)

parent eaf0fde3
Subproject commit efe5b5cc3c89da5d5e39570f6776d39d8acacacc
Subproject commit 41fe60a76fe6e5669540acf1ef3595bc38025157
@@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*!
* \brief Mark of prefetch scope; the value is the prefetch offset.
*  Runs prefetch of the Tensor within the current loop scope.
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
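Downstream, this marker is realized as an `AttrStmt` whose node is the prefetched Tensor and whose value is the offset expression, exactly as the `MakeLoopNest` hunk later in this diff emits it. A minimal sketch of producing such a marker — `MarkPrefetch` and its parameter names are illustrative, not part of this commit:

```cpp
// Sketch: wrap `body` in a prefetch marker for `tensor` at `offset`.
// Mirrors the AttrStmt::make call that MakeLoopNest emits in this diff.
Stmt MarkPrefetch(Tensor tensor, Expr offset, Stmt body) {
  // The Tensor rides along as the attribute node; the offset is the value.
  return AttrStmt::make(tensor, ir::attr::prefetch_scope, offset, body);
}
```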
@@ -371,6 +376,7 @@ using Halide::Internal::Provide;
using Halide::Internal::Allocate;
using Halide::Internal::Free;
using Halide::Internal::Realize;
using Halide::Internal::Prefetch;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
@@ -17,6 +17,9 @@ namespace ir {
* You can use this as a more powerful Visitor, since it allows you to
* define the function signatures of the Visit functions.
*
* This helps you avoid bookkeeping the Visitor's return value via state,
* which can easily cause bugs when the state is maintained incorrectly.
*
* \code
* // A functor that sets the variable to b and calculates the result.
* class MyExprFunctor
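The `\code` example is cut off at the hunk boundary; in the spirit of the comment, a functor that returns results instead of tracking them in member state might look like this sketch (the signature and names here are assumed, not taken from the header):

```cpp
// Sketch: an ExprFunctor whose results flow through return values.
// The extra int argument is the value to substitute for every variable.
class MyExprFunctor
    : public ir::ExprFunctor<int(const Expr&, int)> {
 public:
  int VisitExpr_(const Variable* op, int b) final {
    return b;  // every variable evaluates to b
  }
  int VisitExpr_(const IntImm* op, int b) final {
    return op->value;  // constants evaluate to themselves
  }
  int VisitExpr_(const Add* op, int b) final {
    // Subresults are combined directly from returned values;
    // no state needs to be maintained across visits.
    return VisitExpr(op->a, b) + VisitExpr(op->b, b);
  }
};
```

With this, evaluating `x + 1` with `b = 2` returns 3, and nothing survives between calls that could be maintained incorrectly.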
@@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
@@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
@@ -66,6 +66,7 @@ class IRMutator {
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
@@ -116,6 +116,7 @@ class IRVisitor {
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Prefetch* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
@@ -461,10 +461,16 @@ class IterVarAttrNode : public Node {
IterVarType iter_type{kDataPar};
/*! \brief The thread this IterVar binds to; can be null */
IterVar bind_thread;
/*! \brief List of tensors to be prefetched in this loop */
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
Array<Expr> prefetch_offset;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
}
static constexpr const char* _type_key = "IterVarAttr";
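`prefetch_data[i]` pairs with `prefetch_offset[i]`; the `MakeLoopNest` hunk below CHECKs that the two arrays have equal length. A hedged sketch of recording one request — `MakePrefetchAttr`, `t`, and `offset` are illustrative, and the `std::make_shared` construction is assumed from the node-creation pattern of this era:

```cpp
// Sketch: record that tensor `t` should be prefetched at `offset`
// on the loop this attr is attached to.
std::shared_ptr<IterVarAttrNode> MakePrefetchAttr(Tensor t, Expr offset) {
  auto n = std::make_shared<IterVarAttrNode>();
  n->prefetch_data.push_back(t);         // i-th tensor ...
  n->prefetch_offset.push_back(offset);  // ... pairs with the i-th offset
  return n;
}
```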
@@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src.
- pass The optimization pass on the IR structure
- codegen The code generator.
- runtime Minimum runtime related code
- contrib Contrib extension libraries
@@ -212,7 +212,7 @@ void BoundDeducer::Deduce() {
success = false;
return;
}
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_);
@@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage,
// Mark the iter var in the IR, to remember the point
if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
+IterVarAttr it_attr;
if (stage->iter_var_attrs.count(iv)) {
-switch (stage->iter_var_attrs[iv]->iter_type) {
+it_attr = stage->iter_var_attrs[iv];
+}
+if (it_attr.defined()) {
+switch (it_attr->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type"
-<< stage->iter_var_attrs[iv]->iter_type
+<< it_attr->iter_type
<< " in the iter_var_attrs";
}
}
@@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
CHECK(!is_one(dom->extent))
<< "Cannot prefetch on trivial loop with extent=1";
CHECK_EQ(it_attr->prefetch_data.size(),
it_attr->prefetch_offset.size());
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
nest[i + 1].emplace_back(
AttrStmt::make(it_attr->prefetch_data[j],
ir::attr::prefetch_scope,
it_attr->prefetch_offset[j], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
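Once `MakeLoopNest` has attached these markers, a later pass can recover them by matching the attribute key. A hypothetical collector — not part of this commit — built on the `IRVisitor` hooks from this diff (field names such as `attr_key` follow the `AttrStmt` layout of this era):

```cpp
// Hypothetical helper: gather every prefetch_scope marker in a Stmt tree.
class PrefetchMarkCollector final : public IRVisitor {
 public:
  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == ir::attr::prefetch_scope) {
      // op->node holds the Tensor; op->value holds the offset expression.
      marks_.push_back(op);
    }
    IRVisitor::Visit_(op);  // keep walking into the body
  }
  std::vector<const AttrStmt*> marks_;
};
```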
@@ -13,7 +13,7 @@ namespace tvm {
namespace ir {
// Whether the expression is touched by var.
-class ExprTouched : public IRVisitor {
+class ExprTouched final : public IRVisitor {
public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
: touched_var_(touched) {}
@@ -12,7 +12,7 @@ namespace ir {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA needs to be applied after this pass
-class IRInline : public IRMutator {
+class IRInline final : public IRMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
@@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
}
}
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
Halide::Internal::Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
if (!bounds_changed) {
return s;
} else {
return Prefetch::make(op->func, op->value_index,
op->type, new_bounds);
}
}
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt first = this->Mutate(op->first);
Stmt rest = this->Mutate(op->rest);
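Note the copy-on-write idiom above: when no bound changes, the original `Stmt` is returned and the subtree is shared rather than copied. A subclass can override the same hook to rewrite prefetch regions; for instance, a hypothetical pass that doubles every prefetch extent, reusing only calls that appear in this diff:

```cpp
// Hypothetical: double the extent of every Prefetch region.
class PrefetchWidener final : public IRMutator {
 public:
  Stmt Mutate_(const Prefetch* op, const Stmt& s) final {
    Halide::Internal::Region bounds;
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      // Keep each min, widen each extent.
      bounds.push_back(Range::make_by_min_extent(
          op->bounds[i]->min, op->bounds[i]->extent * 2));
    }
    return Prefetch::make(op->func, op->value_index, op->type, bounds);
  }
};
```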
@@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) {
}
void IRVisitor::Visit_(const Realize *op) {
-// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
@@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) {
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Prefetch *op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
}
void IRVisitor::Visit_(const Block *op) {
this->Visit(op->first);
this->Visit(op->rest);
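Because `Visit_` recurses into each bound's min and extent, expression-level analyses see prefetch regions for free. A tiny hypothetical use of the new hook:

```cpp
// Hypothetical: count the Prefetch statements in a Stmt tree.
class PrefetchCounter final : public IRVisitor {
 public:
  void Visit_(const Prefetch* op) final {
    ++count;
    IRVisitor::Visit_(op);  // still traverse the bound expressions
  }
  int count{0};
};
```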
@@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
// Rule:
// - the range should not be const
// - there exists a condition expression in the scope that uses the var
-class CandidateSelector : public IRVisitor {
+class CandidateSelector final : public IRVisitor {
public:
using VarIsUsed = bool;
CandidateSelector() {}
@@ -14,7 +14,7 @@
namespace tvm {
namespace ir {
-class ThreadAllreduceBuilder : public IRMutator {
+class ThreadAllreduceBuilder final : public IRMutator {
public:
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
@@ -31,7 +31,7 @@ using namespace storage;
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
-class StorageAccessPatternFinder : public IRVisitor {
+class StorageAccessPatternFinder final : public IRVisitor {
public:
// Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {