Commit 01cbc61a authored by Jian Weng, committed by Tianqi Chen

[API] Prefetch schedule supported (#258)

* prefetch interface added

* prefetch python comments modified. prefetch info data structure maintained.

* start injecting prefetches. first step (domain touch) implemented.

* domain touch tested.

* Prefetch ir_mutator and ir_visitor dispatch registered.

* modify domain touched from passing a func_ref to passing a tensor

* modify domain touched from passing a func_ref to passing a tensor

* modify Tensor copy to Tensor ref

* temp commit for rebase

* debug info removed, typo fixed, ready to rebase

* prefetch flatten test added

* roll back builtin functions to side effect functions

* lint error fixed!

* add cache line size to storage flatten argument

* add forgotten modifications

* change code style to dmlc-like; get rid of can_prove, use manual computation instead

* python lint error fixed

* modify intrinsic name to pass tests

* [TEST] get rid of str(), replace them by accessing attributes

* change map to list comprehension

* redundant numpy import removed
parent 7b6427e3
...@@ -12,6 +12,9 @@
 #include "./expr.h"
 namespace tvm {
+class Tensor;
 /*! \brief namespace of arithmetic */
 namespace arith {
 /*!
...@@ -255,6 +258,16 @@ IntSet DeduceBound(Expr v, Expr cond,
                    const std::unordered_map<const Variable*, IntSet>& relax_map);
 /*!
+ * \brief Infer a regular domain that covers all the calls or provides within the given statement.
+ * \param body The given statement.
+ * \param tensor The tensor whose calls or provides are considered.
+ * \param consider_calls If calls (reads) are considered.
+ * \param consider_provides If provides (writes) are considered.
+ * \return The domain that covers all the calls or provides within the given statement.
+ */
+Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
+
+/*!
  * \brief Evaluate the expression with modular analysis
  * \param e The expression to be evaluated.
  * \param mod_map Map of modular statistics of known variables.
......
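The new DomainTouched analysis is also exposed to Python through the arith.DomainTouched registration further down in this commit. A minimal sketch of querying it, condensed from the unit test added below (the tvm.make constructor signatures are those of this TVM 0.x snapshot, so treat the exact arguments as an assumption):

import tvm

# Region of 'a' read inside a 1-D loop: a[i - 1] for i in [0, 100).
i = tvm.var('i')
n = tvm.convert(100)
a = tvm.placeholder((n,), name='a')
body = tvm.make.For(
    i, 0, n, 0, 0,
    tvm.make.Provide(a.op, 0,
                     tvm.make.Call(a.dtype, 'a', [i - 1], 3, a.op, 0),
                     [i]))
reads = tvm.arith.DomainTouched(body, a, True, False)
print(reads[0].min, reads[0].extent)  # expected: -1, 100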
...@@ -169,10 +169,12 @@ Stmt Inline(Stmt stmt,
  * \param stmt The stmt to be transformed.
  * \param extern_buffer Map specifies external
  *  buffer assignment of input and outputs.
+ * \param cache_line_size The size of a CPU cache line, in bytes.
  * \return Transformed stmt.
  */
 Stmt StorageFlatten(Stmt stmt,
-                    Map<Tensor, Buffer> extern_buffer);
+                    Map<Tensor, Buffer> extern_buffer,
+                    int cache_line_size);
 /*!
  * \brief Remove No Op from the Stmt.
...@@ -223,6 +225,13 @@ Stmt VectorizeLoop(Stmt stmt);
 Stmt InjectVirtualThread(Stmt stmt);
 /*!
+ * \brief Inject prefetch instructions into stmt.
+ * \param stmt The statement to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt InjectPrefetch(Stmt stmt);
+
+/*!
  * \brief Rewrite storage allocation pattern.
  *  Moves the allocation to outer most possible scope.
  *  Trying to share space between allocations to make
......
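For reference, a minimal sketch of the extended StorageFlatten signature, condensed from the test_flatten_prefetch unit test added later in this commit (the placeholder shape and the 64-byte cache line are illustrative):

import tvm

# Lower a hand-built Prefetch node; the new third argument of
# StorageFlatten is the cache line size in bytes.
A = tvm.placeholder((25, 100, 4), name='A')
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
i, j = tvm.var('i'), tvm.var('j')
region = [tvm.make.range_by_min_extent(*b) for b in [(i, 2), (j, 8), (0, 4)]]
stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab}, 64)
print(stmt)  # a serial loop nest issuing one prefetch call per cache line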
...@@ -181,6 +181,14 @@ class Stage : public NodeRef {
   */
  Stage& parallel(IterVar var);   // NOLINT(*)
  /*!
+   * \brief Fetch data in advance.
+   * \param domain the tensor to be prefetched
+   * \param var the iteration point at which to apply prefetching
+   * \param offset the number of iterations to be fetched in advance
+   * \return reference to self
+   */
+  Stage& prefetch(const Tensor &domain, IterVar var, Expr offset);  // NOLINT(*)
+  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
......
...@@ -185,7 +185,8 @@ def lower(sch,
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    stmt = ir_pass.StorageFlatten(stmt, binds)
+    stmt = ir_pass.InjectPrefetch(stmt)
+    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = ir_pass.CanonicalSimplify(stmt)
     if not simple_mode:
         stmt = ir_pass.LoopPartition(stmt)
......
...@@ -422,4 +422,18 @@ class Stage(NodeBase):
         """
         _api_internal._StageParallel(self, var)

+    def prefetch(self, tensor, var, offset):
+        """Prefetch the specified tensor.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be prefetched
+        var : IterVar
+            The loop point at which the prefetching is applied
+        offset : Expr
+            The number of iterations to be prefetched before actual execution
+        """
+        _api_internal._StagePrefetch(self, tensor, var, offset)
+
 _init_api("tvm.schedule")
...@@ -6,7 +6,7 @@
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/api_registry.h>
-#include <tvm/arithmetic.h>
+#include <tvm/tensor.h>
 namespace tvm {
 namespace arith {
...@@ -38,6 +38,13 @@ TVM_REGISTER_API("arith.DeduceBound")
                         args[3].operator Map<Var, IntSet>());
   });

+TVM_REGISTER_API("arith.DomainTouched")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = DomainTouched(args[0], args[1], args[2], args[3]);
+  });
+
 TVM_REGISTER_API("_IntervalSetGetMin")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = args[0].operator IntSet().min();
......
...@@ -358,6 +358,12 @@ TVM_REGISTER_API("_StageParallel")
     .parallel(args[1]);
   });

+TVM_REGISTER_API("_StagePrefetch")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    args[0].operator Stage()
+        .prefetch(args[1], args[2], args[3]);
+  });
+
 TVM_REGISTER_API("_ScheduleNormalize")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Schedule()
......
...@@ -86,7 +86,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
-REGISTER_PASS2(StorageFlatten);
+REGISTER_PASS3(StorageFlatten);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS4(UnrollLoop);
 REGISTER_PASS2(StorageSync);
...@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
 REGISTER_PASS1(SplitHostDevice);
 REGISTER_PASS1(StorageRewrite);
 REGISTER_PASS1(InjectVirtualThread);
+REGISTER_PASS1(InjectPrefetch);
 REGISTER_PASS1(LoopPartition);
 REGISTER_PASS1(RemoveNoOp);
 REGISTER_PASS2(SplitPipeline);
......
/*!
 * Copyright (c) 2017 by Contributors
 * \file bound_deducer.cc
 * \brief Utility to find the domain of a tensor touched by a statement.
 */
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include <unordered_map>

namespace tvm {
namespace arith {

using namespace ir;

// Find the region of the tensor touched (read and/or written) in the stmt.
class FuncTouchedDomain final : public IRVisitor {
 public:
  FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
      : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}

  Domain Find(const Stmt& stmt) {
    this->Visit(stmt);
    Domain ret;
    Range none;
    for (size_t i = 0; i < bounds_.size(); ++i) {
      ret.push_back(arith::Union(bounds_[i]).cover_range(none));
    }
    return ret;
  }

  void Visit_(const For *op) final {
    const Variable* var = op->loop_var.get();
    dom_map_[var] = IntSet::range(
        Range::make_by_min_extent(op->min, op->extent));
    IRVisitor::Visit_(op);
    dom_map_.erase(var);
  }

  void Visit_(const LetStmt* op) final {
    dom_map_[op->var.get()] =
        arith::EvalSet(op->value, dom_map_);
    IRVisitor::Visit_(op);
    dom_map_.erase(op->var.get());
  }

  /* TODO: Thread extent unit test not generated. */
  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::thread_extent) {
      const IterVarNode* thread_axis = op->node.as<IterVarNode>();
      CHECK(thread_axis);
      const Variable* var = thread_axis->var.get();
      dom_map_[var] = IntSet::range(Range(make_zero(op->value.type()), op->value));
      IRVisitor::Visit_(op);
      dom_map_.erase(var);
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const Call* op) final {
    if (consider_calls_ && tensor_->op.same_as(op->func)
        && tensor_->value_index == op->value_index) {
      Touch(op->args);
    }
    IRVisitor::Visit_(op);
  }

  void Visit_(const Provide* op) final {
    if (consider_provides_ && tensor_->op.same_as(op->func)
        && tensor_->value_index == op->value_index) {
      Touch(op->args);
    }
    IRVisitor::Visit_(op);
  }

 private:
  void Touch(const Array<Expr>& args) {
    if (args.size() > bounds_.size()) {
      bounds_.resize(args.size());
    }
    for (size_t i = 0; i < args.size(); ++i) {
      bounds_[i].emplace_back(EvalSet(args[i], dom_map_));
    }
  }

  const Tensor &tensor_;
  bool consider_calls_, consider_provides_;
  std::vector<std::vector<IntSet> > bounds_;
  std::unordered_map<const Variable*, IntSet> dom_map_;
};

Domain DomainTouched(Stmt stmt,
                     const Tensor &tensor,
                     bool consider_calls,
                     bool consider_provides) {
  return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
}

}  // namespace arith
}  // namespace tvm
...@@ -807,7 +807,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
     llvm::Function* f = llvm::Intrinsic::getDeclaration(
         module_.get(), id, arg_types);
     return builder_->CreateCall(f, arg_values);
-  } else if (op->is_intrinsic("llvm_buildin")) {
+  } else if (op->is_intrinsic("llvm_builtin")) {
     std::vector<llvm::Value*> arg_values;
     for (size_t i = 1; i < op->args.size(); ++i) {
       llvm::Value* v = MakeValue(op->args[i]);
......
...@@ -28,9 +28,12 @@ inline void DispatchLLVMBuildin(const TVMArgs& targs, TVMRetValue* rv) {
     cargs.push_back(arg);
   }
   *rv = Call::make(
-      call->type, "llvm_buildin", cargs, Call::Intrinsic);
+      call->type, "llvm_builtin", cargs, Call::Intrinsic);
 }

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
+.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>);
+
 template<unsigned id>
 inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
...@@ -46,8 +49,20 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
       call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
 }

-TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.__buildin_prefetch")
-.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>);
+template<unsigned id>
+inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
+  Expr e = targs[0];
+  const Call* call = e.as<Call>();
+  CHECK(call != nullptr);
+  Array<Expr> cargs;
+  // intrin id.
+  cargs.push_back(UIntImm::make(UInt(32), id));
+  for (Expr arg : call->args) {
+    cargs.push_back(arg);
+  }
+  *rv = Call::make(
+      call->type, "llvm_intrin", cargs, Call::Intrinsic);
+}

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
......
...@@ -95,11 +95,11 @@ MakeLoopNest(const Stage& stage,
             << "Cannot prefetch on trivial loop with extent=1";
         CHECK_EQ(it_attr->prefetch_data.size(),
                  it_attr->prefetch_offset.size());
-        for (size_t i = 0; i < it_attr->prefetch_data.size(); ++i) {
+        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
           nest[i + 1].emplace_back(
-              AttrStmt::make(it_attr->prefetch_data[i],
+              AttrStmt::make(it_attr->prefetch_data[j],
                              ir::attr::prefetch_scope,
-                             it_attr->prefetch_offset[i], no_op));
+                             it_attr->prefetch_offset[j], no_op));
         }
       }
     } else if (bind_iv->thread_tag == "vthread") {
......
/*!
 * Copyright (c) 2017 by Contributors
 * \file inject_prefetch.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_set>

namespace tvm {
namespace ir {

using arith::IntSet;
using arith::DomainTouched;
using Halide::Internal::Region;

class PrefetchInjector : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    Stmt ret = IRMutator::Mutate_(op, s);
    op = ret.as<AttrStmt>();
    if (op && op->attr_key == attr::prefetch_scope) {
      Tensor ts(op->node.node_);
      CHECK_NE(loop_nest_.size(), 0U);
      Domain domain = DomainTouched(op->body, ts, true, false);
      Region region;

      auto iter_var = loop_nest_.back().get();
      vectorized_[iter_var] = IntSet::single_point(loop_nest_.back() + op->value);

      for (Range r : domain) {
        if (!r.defined()) {
          LOG(WARNING) << "Cannot decide prefetch region for " << ts;
          return op->body;
        }
        Range res(EvalSet(r, vectorized_).cover_range(none));
        region.push_back(Range::make_by_min_extent(res->min, res->extent));
      }

      vectorized_.erase(iter_var);

      Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region);
      return Block::make(prefetch, op->body);
    }
    return ret;
  }

  Stmt Mutate_(const For* op, const Stmt& s) final {
    auto &var = op->loop_var;
    loop_nest_.push_back(var);
    if (op->for_type == ForType::Vectorized) {
      vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
    }
    Stmt ret = IRMutator::Mutate_(op, s);
    if (op->for_type == ForType::Vectorized) {
      vectorized_.erase(var.get());
    }
    loop_nest_.pop_back();
    return ret;
  }

 private:
  std::vector<VarExpr> loop_nest_;
  std::unordered_map<const Variable *, IntSet> vectorized_;
  static const Range none;
};

const Range PrefetchInjector::none;

Stmt InjectPrefetch(Stmt stmt) {
  return PrefetchInjector().Mutate(stmt);
}

}  // namespace ir
}  // namespace tvm
...@@ -265,7 +265,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
 .DISPATCH_TO_MUTATE_STMT(Provide)
 .DISPATCH_TO_MUTATE_STMT(Realize)
 .DISPATCH_TO_MUTATE_STMT(Block)
-.DISPATCH_TO_MUTATE_STMT(Evaluate);
+.DISPATCH_TO_MUTATE_STMT(Evaluate)
+.DISPATCH_TO_MUTATE_STMT(Prefetch);

 // Mutate Expr
......
...@@ -255,7 +255,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .DISPATCH_TO_VISIT(IntImm)
 .DISPATCH_TO_VISIT(UIntImm)
 .DISPATCH_TO_VISIT(FloatImm)
-.DISPATCH_TO_VISIT(StringImm);
+.DISPATCH_TO_VISIT(StringImm)
+.DISPATCH_TO_VISIT(Prefetch);

 }  // namespace ir
 }  // namespace tvm
...@@ -5,6 +5,7 @@
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/ir_mutator.h>
+#include <tvm/ir_operator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/buffer.h>
 #include <tvm/runtime/device_api.h>
...@@ -20,16 +21,18 @@ namespace ir {
 using Halide::Internal::Region;
 using runtime::StorageScope;
 using runtime::ThreadScope;
+using intrinsic::tvm_address_of;

 class StorageFlattener : public IRMutator {
  public:
-  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
+  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer, int cache_line_size) {
     for (auto kv : extern_buffer) {
       BufferEntry e;
       e.buffer = kv.second;
       e.external = true;
       buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
     }
+    cache_line_size_ = cache_line_size;
   }

   Stmt Mutate_(const Store* op, const Stmt& s) final {
     Stmt stmt = IRMutator::Mutate_(op, s);
...@@ -169,6 +172,62 @@ class StorageFlattener : public IRMutator {
     }
   }

+  Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Prefetch>();
+    CHECK(op != nullptr);
+    TensorKey key{op->func, op->value_index};
+    auto it = buf_map_.find(key);
+    CHECK(it != buf_map_.end())
+        << "Cannot find allocated buffer for " << key.f;
+    const BufferEntry& e = it->second;
+    CHECK(!e.released)
+        << "Read a buffer that is already out of scope";
+    CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    int block_size = 1,
+        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
+        shape = 0;
+
+    int starts = op->bounds.size() - 1;
+    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
+           && elem_cnt >= block_size * shape) {
+      block_size *= shape;
+      starts--;
+    }
+    Expr stride(elem_cnt / block_size);
+
+    Array<Expr> args;
+    std::vector<VarExpr> vars;
+
+    for (int i = op->bounds.size() - 1; i > starts; --i) {
+      args.push_back(op->bounds[i]->min);
+    }
+    auto &func_name = op->func->func_name();
+    vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(starts), Int(32)));
+    args.push_back(op->bounds[starts]->min + stride * vars.back());
+    for (int i = starts - 1; i >= 0; --i) {
+      vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(i), Int(32)));
+      args.push_back(vars.back() + op->bounds[i]->min);
+    }
+
+    for (int i = starts; i >= 0; --i) {
+      if (i < starts) {
+        stmt = For::make(
+            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
+      } else {
+        Expr load = e.buffer.MakeLoad(e.RelIndex(args));
+        Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
+        Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
+        stmt = Evaluate::make(prefetch);
+        Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
+        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
+      }
+    }
+    return stmt;
+  }
+
  private:
   // Start bind
   Stmt HandleBufferBindScope(const AttrStmt* op) {
...@@ -252,11 +311,14 @@ class StorageFlattener : public IRMutator {
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
+  // The size of a cache line, in bytes.
+  int cache_line_size_;
 };

 Stmt StorageFlatten(Stmt stmt,
-                    Map<Tensor, Buffer> extern_buffer) {
-  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
+                    Map<Tensor, Buffer> extern_buffer,
+                    int cache_line_size) {
+  stmt = StorageFlattener(extern_buffer, cache_line_size).Mutate(stmt);
   return stmt;
 }
......
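To make the cache-line folding in the new Prefetch mutator concrete, here is the arithmetic for the (25, 100, 4) float32 case exercised by test_flatten_prefetch below; this is only a worked transliteration of the C++ above, not additional behavior:

# Worked example of StorageFlattener::Mutate_(const Prefetch*) with a 64-byte line.
cache_line_size = 64
dtype_bytes = 4                      # float32
shape = [25, 100, 4]                 # buffer shape
extents = [2, 8, 4]                  # prefetch extents per dimension

elem_cnt = cache_line_size // dtype_bytes     # 16 elements per cache line
block_size, starts = 1, len(shape) - 1
while starts > 0 and elem_cnt >= block_size * shape[starts]:
    block_size *= shape[starts]               # fold dim 2 (4 elements) into the line
    starts -= 1                               # stops at starts == 1, since 4 * 100 > 16
stride = elem_cnt // block_size               # 4 iterations of dim 1 per cache line

outer_extent = extents[0]                             # 2
inner_extent = (extents[starts] - 1) // stride + 1    # ceil(8 / 4) == 2
print(starts, block_size, stride, outer_extent, inner_extent)  # 1 4 4 2 2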
...@@ -337,6 +337,24 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*)
   return *this;
 }

+Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
+  StageNode *self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  FindLeafVar(all_vars, leaf_vars, var);
+  auto it = self->iter_var_attrs.find(var);
+  std::shared_ptr<IterVarAttrNode> n;
+  if (it != self->iter_var_attrs.end()) {
+    n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
+  } else {
+    n = std::make_shared<IterVarAttrNode>();
+  }
+  n->prefetch_data.push_back(tensor);
+  n->prefetch_offset.push_back(offset);
+  self->iter_var_attrs.Set(var, IterVarAttr(n));
+  return *this;
+}
+
 Stage CopyStage(const Stage& s) {
   std::shared_ptr<StageNode> n =
       std::make_shared<StageNode>(*s.operator->());
......
...@@ -13,7 +13,7 @@ def lower(s, args, name="mydot"):
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 16)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
......
import tvm


def test_domain_touched():
    i = tvm.var('i')
    j = tvm.var('j')
    n = tvm.convert(100)
    m = tvm.var('m')

    a = tvm.placeholder((n, m), name='a')
    b = tvm.placeholder((n, m), name='b')

    ir = tvm.make.For(
        i, 0, n, 0, 0,
        tvm.make.For(j, 0, m, 0, 0,
                     tvm.make.Provide(
                         a.op,
                         0,
                         tvm.make.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
                         tvm.make.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
                         [i, j])))

    a_domain_r = tvm.arith.DomainTouched(ir, a, True, False)
    assert a_domain_r[0].min.value == -1
    assert a_domain_r[0].extent.value == 100
    assert a_domain_r[1].min.value == -1
    assert a_domain_r[1].extent.name == 'm'

    a_domain_w = tvm.arith.DomainTouched(ir, a, False, True)
    assert a_domain_w[0].min.value == 0
    assert a_domain_w[0].extent.value == 100
    assert a_domain_w[1].min.value == 0
    assert a_domain_w[1].extent.name == 'm'

    a_domain_rw = tvm.arith.DomainTouched(ir, a, True, True)
    assert a_domain_rw[0].min.value == -1
    assert a_domain_rw[0].extent.value == 101
    assert a_domain_rw[1].min.value == -1
    assert isinstance(a_domain_rw[1].extent, tvm.expr.Add)
    assert a_domain_rw[1].extent.a.name == 'm'
    assert a_domain_rw[1].extent.b.value == 1

    b_domain_r = tvm.arith.DomainTouched(ir, b, True, False)
    assert b_domain_r
    assert b_domain_r[0].min.value == -1
    assert b_domain_r[0].extent.value == 100
    assert b_domain_r[1].min.value == 1
    assert b_domain_r[1].extent.name == 'm'

    b_domain_w = tvm.arith.DomainTouched(ir, b, False, True)
    assert isinstance(b_domain_w, tvm.container.Array)
    assert len(b_domain_w) == 0


if __name__ == "__main__":
    test_domain_touched()
...@@ -28,7 +28,7 @@ def test_add_pipeline():
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
     Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
     stmt = tvm.ir_pass.LoopPartition(stmt)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
     fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
......
...@@ -12,7 +12,7 @@ def test_llvm_intrin():
         ]
     ib.emit(tvm.make.Evaluate(
         tvm.make.Call(
-            "int32", "__buildin_prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
+            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
     body = ib.get()
     func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
     fcode = tvm.build(func, None, "llvm")
......
...@@ -20,7 +20,7 @@ def lower(sch, args):
     bounds = tvm.schedule.InferBound(sch)
     stmt = tvm.schedule.ScheduleOps(sch, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
......
...@@ -15,7 +15,7 @@ def test_makeapi():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
     Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
     num_unpacked_args = 2
     f = tvm.ir_pass.MakeAPI(
......
...@@ -12,7 +12,7 @@ def lower(s, args):
     s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     return stmt
......
...@@ -16,8 +16,22 @@ def test_flatten2():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)

+def test_flatten_prefetch():
+    A = tvm.placeholder((25, 100, 4), name='A')
+    _A = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    i = tvm.var('i')
+    j = tvm.var('j')
+    region = [tvm.make.range_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
+    stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert stmt.extent.value == 2
+    assert isinstance(stmt.body, tvm.stmt.For)
+    assert stmt.body.extent.value == 2
+
 if __name__ == "__main__":
     test_flatten2()
+    test_flatten_prefetch()
...@@ -15,7 +15,7 @@ def test_storage_share():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     stmt = tvm.ir_pass.StorageRewrite(stmt)
...@@ -51,7 +51,7 @@ def test_storage_share_gpu():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
     Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     stmt = tvm.ir_pass.StorageRewrite(stmt)
......
...@@ -19,7 +19,7 @@ def test_storage_sync():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
     f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
     flist = tvm.ir_pass.SplitHostDevice(f)
     f = flist[1]
......
...@@ -19,7 +19,7 @@ def test_virtual_thread():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)
     stmt = tvm.ir_pass.InjectVirtualThread(stmt)
     print(stmt)
......
...@@ -14,7 +14,7 @@ def lower(s, args, name):
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     stmt = tvm.ir_pass.SplitPipeline(stmt, True)
......