Commit b8f0ec50 by Tianqi Chen Committed by GitHub

[LANG/PASS] InjectVirtualThread (#38)

parent 526ff04c
......@@ -49,6 +49,30 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
/*!
 * \brief Mark scope of iteration variable, used by Schedule.
 */
constexpr const char* scope = "scope";
/*!
 * \brief Mark launching extent of thread, used by device API.
 */
constexpr const char* thread_extent = "thread_extent";
/*!
 * \brief Mark launching of a virtual thread.
 *
 *  The attribute value holds the extent of the virtual thread;
 *  the attribute is removed by the InjectVirtualThread pass.
 */
constexpr const char* virtual_thread = "virtual_thread";
/*!
 * \brief Mark storage scope of buffers.
 *
 *  The attribute value is a string naming the scope.
 */
constexpr const char* storage_scope = "storage_scope";
/*!
 * \brief Mark storage scope of realizations.
 *
 *  The attribute value is a string naming the scope.
 */
constexpr const char* realize_scope = "realize_scope";
}  // namespace attr
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
......
......@@ -63,6 +63,7 @@ class IRMutator {
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
......
......@@ -100,6 +100,7 @@ Stmt Inline(Stmt stmt,
* \param stmt The stmt to be transformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
......@@ -108,16 +109,35 @@ Stmt StorageFlatten(Stmt stmt,
* \brief unroll the constant loops
* \param stmt The statement to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*!
* \brief vectorize the constant loops
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Lift storage allocations to the relevant outermost location.
*
* Only do this after vectorization and virtual thread injection have completed.
*
* \param stmt The stmt to be transformed.
* \return Transformed stmt.
*/
Stmt LiftAllocate(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
......@@ -70,6 +70,8 @@ def build(sch,
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
......
......@@ -67,6 +67,8 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread);
} // namespace ir
} // namespace tvm
......@@ -288,7 +288,8 @@ class Canonical::Internal : public IRMutator {
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->type_key == "thread_extent") {
if (op->type_key == attr::thread_extent ||
op->type_key == attr::virtual_thread) {
++level_counter_;
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
......
......@@ -743,7 +743,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
}
void CodeGenC::PrintStmt(const AttrStmt* op) {
if (op->type_key == "scope") {
if (op->type_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
......@@ -756,7 +756,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
stream << ";\n";
}
}
} else if (op->type_key == "storage_scope") {
} else if (op->type_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
......
......@@ -9,6 +9,7 @@
#include <string>
#include "./codegen_cuda.h"
#include "./codegen_stack_vm.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
......@@ -22,6 +23,17 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa);
}
// Print a For loop; loops that start at zero and have a constant extent
// no larger than max_auto_unroll_ are prefixed with "#pragma unroll".
void CodeGenCUDA::PrintStmt(const ir::For* op) {
  CHECK(is_zero(op->min));
  int const_extent;
  const bool extent_is_const = arith::GetConstInt(op->extent, &const_extent);
  if (extent_is_const && const_extent <= max_auto_unroll_) {
    // small constant loop: ask the compiler to unroll it.
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  // the loop itself is printed by the generic C code generator.
  CodeGenC::PrintStmt(op);
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
......
......@@ -27,6 +27,7 @@ class CodeGenCUDA : public CodeGenC {
bool output_ssa);
// override behavior
void PrintStmt(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
......@@ -37,6 +38,11 @@ class CodeGenCUDA : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{8};
};
} // namespace codegen
......
/*!
* Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
// If expression is touched by var.
class ExprTouched : public IRVisitor {
public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
: touched_var_(touched) {}
void Visit(const NodeRef& n) final {
// early stopping
if (expr_touched_) return;
IRVisitor::Visit(n);
}
void Visit_(const Load *op) final {
HandleUseVar(op->buffer_var.get());
IRVisitor::Visit_(op);
}
void Visit_(const Variable *op) final {
HandleUseVar(op);
}
void HandleUseVar(const Variable* var) {
auto it = touched_var_.find(var);
if (it != touched_var_.end()) {
expr_touched_ = true;
}
// rember the used vars
// in case the var get touched later in a loop.
if (!expr_touched_) {
used_vars_.push_back(var);
}
}
// the fields.
bool expr_touched_{false};
std::vector<const Variable*> used_vars_;
const std::unordered_set<const Variable*>& touched_var_;
};
// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public IRVisitor {
public:
void Visit_(const LetStmt *op) {
ExprTouched tc(touched_var_);
tc.Visit(op->value);
Record(op->var.get(), tc);
this->Visit(op->body);
}
void Visit_(const Store *op) {
ExprTouched tc(touched_var_);
tc.Visit(op->value);
tc.Visit(op->index);
Record(op->buffer_var.get(), tc);
}
void Visit_(const For *op) {
ExprTouched tc(touched_var_);
tc.Visit(op->min);
tc.Visit(op->extent);
Record(op->loop_var.get(), tc);
this->Visit(op->body);
}
void Visit_(const Allocate *op) {
ExprTouched tc(touched_var_);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc.Visit(op->extents[i]);
}
tc.Visit(op->condition);
if (op->new_expr.defined()) {
tc.Visit(op->new_expr);
}
Record(op->buffer_var.get(), tc);
this->Visit(op->body);
}
void Record(const Variable* var,
const ExprTouched& tc) {
if (touched_var_.count(var)) return;
if (tc.expr_touched_) {
touched_var_.insert(var);
} else {
for (const Variable* r : tc.used_vars_) {
affect_[r].push_back(var);
}
}
}
std::unordered_set<const Variable*>
TouchedVar(const Stmt& stmt,
const Variable* var) {
touched_var_.insert(var);
this->Visit(stmt);
// do a DFS to push affect around dependency.
std::vector<const Variable*> pending(
touched_var_.begin(), touched_var_.end());
while (!pending.empty()) {
const Variable* v = pending.back();
pending.pop_back();
for (const Variable* r : affect_[v]) {
if (!touched_var_.count(r)) {
touched_var_.insert(r);
pending.push_back(r);
}
}
}
return std::move(touched_var_);
}
private:
// Whether variable is touched by the thread variable.
std::unordered_set<const Variable*> touched_var_;
// x -> all the buffers x read from
std::unordered_map<const Variable*,
std::vector<const Variable*> > affect_;
};
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
//
// The mutator walks the body of a virtual_thread attribute. A statement that
// reads a touched variable is wrapped into a virtual thread loop (or unrolled
// copies) by InjectVTLoop; statements that never read one are left as they are.
// Allocations of touched buffers get an extra leading extent of num_threads
// and every load/store on them is re-indexed by the virtual thread variable.
class VTInjector : public IRMutator {
 public:
  using IRMutator::Mutate;
  /*!
   * \brief constructor
   * \param var The virtual thread variable.
   * \param num_threads The number of virtual threads.
   * \param touched_var The variables whose value depends on var.
   */
  VTInjector(Var var,
             int num_threads,
             std::unordered_set<const Variable*> touched_var)
      : var_(var), num_threads_(num_threads),
        touched_var_(std::move(touched_var)) {
  }
  // Inject VTLoop when needed.
  Stmt Mutate(Stmt stmt) final {
    // the flag must have been consumed by whoever saw it raised.
    CHECK(!visit_touched_var_)
        << stmt->type_key() << stmt;
    stmt = IRMutator::Mutate(stmt);
    if (visit_touched_var_) {
      // stmt is already mutated here, so do not mutate it again.
      if (!vt_loop_injected_) return InjectVTLoop(stmt, false);
      visit_touched_var_ = false;
    }
    return stmt;
  }
  // Variable
  Expr Mutate_(const Variable *op, const Expr& e) final {
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return e;
  }
  // Rewrite the index of an access to a touched allocation.
  // strategy 0 interleaves the threads: index * num_threads + var
  // otherwise var is placed on the highest dimension: index + var * alloc_extent
  Expr RewriteIndex(Expr index, Expr alloc_extent) const {
    if (index_rewrite_strategy_ == 0) {
      return index * num_threads_ + var_;
    } else {
      return index + var_ * alloc_extent;
    }
  }
  // Load
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = touched_alloc_.find(op->buffer_var.get());
    if (it != touched_alloc_.end()) {
      return Load::make(op->type, op->buffer_var,
                        RewriteIndex(op->index, it->second));
    } else {
      return expr;
    }
  }
  // Store
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = touched_alloc_.find(op->buffer_var.get());
    if (it != touched_alloc_.end()) {
      return Store::make(op->buffer_var,
                         op->value,
                         RewriteIndex(op->index, it->second));
    } else {
      return stmt;
    }
  }
  // Attribute
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->type_key == attr::scope) {
      // scope attributes are dropped, only their body is kept.
      return Mutate(op->body);
    } else {
      Expr value = Mutate(op->value);
      // Only inject when no loop is active: InjectVTLoop re-mutates s with
      // vt_loop_injected_ set and would otherwise re-enter this branch and
      // fail its own CHECK(!vt_loop_injected_).
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(s, true);
      }
      // consume the flag before visiting the body, as the other handlers do.
      visit_touched_var_ = false;
      Stmt body = Mutate(op->body);
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
        return s;
      } else {
        return AttrStmt::make(op->node, op->type_key, value, body);
      }
    }
  }
  // LetStmt
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    Expr value = this->Mutate(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return LetStmt::make(op->var, value, body);
    }
  }
  // For
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(is_zero(op->min));
    Expr extent = Mutate(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(s, true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    // this loop adds one level on top of whatever the body contains.
    ++max_loop_depth_;
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return For::make(
          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }
  // IfThenElse
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    Expr condition = this->Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->Mutate(op->then_case);
    Stmt else_case;
    // Test the original node: the local else_case is still undefined here,
    // testing it would skip the branch and drop it from the result.
    if (op->else_case.defined()) {
      // measure the depth of each branch separately and keep the maximum.
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->Mutate(op->else_case);
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    } else {
      return IfThenElse::make(condition, then_case, else_case);
    }
  }
  // Block
  Stmt Mutate_(const Block* op, const Stmt& s) final {
    CHECK_EQ(max_loop_depth_, 0);
    Stmt first = this->Mutate(op->first);
    // measure the depth of each half separately and keep the maximum.
    int temp = max_loop_depth_;
    max_loop_depth_ = 0;
    Stmt rest = this->Mutate(op->rest);
    max_loop_depth_ = std::max(max_loop_depth_, temp);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return s;
    } else {
      return Block::make(first, rest);
    }
  }
  // Allocate
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    // allocations with a custom new_expr are always put inside the vthread loop.
    if (op->new_expr.defined() && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    Expr condition = Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    bool changed = false;
    Array<Expr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      Expr new_ext = Mutate(op->extents[i]);
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(s, true);
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;
    Stmt body;
    if (touched_var_.count(op->buffer_var.get())) {
      // place v on highest dimension.
      // stride is the number of elements of the original allocation.
      Expr stride = extents[0];
      for (size_t i = 1; i < extents.size(); ++i) {
        stride = arith::ComputeExpr<Mul>(stride, extents[i]);
      }
      Array<Expr> other;
      other.push_back(num_threads_);
      for (Expr e : extents) {
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
      // must happen before the body is mutated so accesses get re-indexed.
      touched_alloc_[op->buffer_var.get()] = stride;
      // Mutate the body.
      body = Mutate(op->body);
    } else {
      // Mutate the body.
      body = Mutate(op->body);
    }
    if (!changed &&
        body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return s;
    } else {
      return Allocate::make(
          op->buffer_var, op->type,
          extents, condition, body,
          op->new_expr, op->free_function);
    }
  }
  // inject vthread loop
  // before_mutation tells whether stmt still has to be mutated;
  // the mutation then runs with vt_loop_injected_ set.
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->Mutate(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    if (max_loop_depth_ == 0) {
      // do unrolling if it is inside innermost content.
      // one copy of stmt per thread, with var_ replaced by the thread index.
      Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
      for (int i = 1; i < num_threads_; ++i) {
        blk = Block::make(
            blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
      }
      return blk;
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->type);
      stmt = Substitute(stmt, {{var_, idx}});
      return For::make(idx, 0, num_threads_,
                       ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // Index rewriting strategy, see RewriteIndex.
  int index_rewrite_strategy_{1};
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression get touched.
  bool visit_touched_var_{false};
  // the maximum depth of For loops in the statements mutated so far.
  int max_loop_depth_{0};
  // The variables that get touched.
  std::unordered_set<const Variable*> touched_var_;
  // The allocations that get touched -> extent
  std::unordered_map<const Variable*, Expr> touched_alloc_;
};
// Replace every virtual_thread attribute by its body,
// rewritten by VTInjector to run once per virtual thread.
class VirtualThreadInjector : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    // mutate the children first, so nested virtual threads are handled inside out.
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<AttrStmt>();
    if (op->type_key == attr::virtual_thread) {
      IterVar iv(op->node.node_);
      // the number of threads has to be known at this point,
      // fail with a message instead of dereferencing a null pointer.
      const IntImm* extent = op->value.as<IntImm>();
      CHECK(extent != nullptr)
          << "virtual_thread extent must be a constant integer, but get "
          << op->value;
      int nthread = static_cast<int>(extent->value);
      VarTouchedAnalysis vs;
      auto touched = vs.TouchedVar(op->body, iv->var.get());
      VTInjector injecter(iv->var, nthread, touched);
      // the attribute itself is dropped, only the rewritten body remains.
      return injecter.Mutate(op->body);
    } else {
      return stmt;
    }
  }
  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    LOG(FATAL) << "Need to call StorageFlatten first";
    return s;
  }
};
// Entry point of the pass: rewrite all virtual_thread attributes in stmt,
// then run ConvertSSA on the result.
Stmt InjectVirtualThread(Stmt stmt) {
  Stmt injected = VirtualThreadInjector().Mutate(stmt);
  return ConvertSSA(injected);
}
} // namespace ir
} // namespace tvm
......@@ -77,6 +77,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Free);
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
......@@ -212,6 +213,17 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}
// Mutate both halves of a Block; the original statement is returned
// when neither half changed, so unchanged subtrees stay shared.
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
  Stmt new_first = this->Mutate(op->first);
  Stmt new_rest = this->Mutate(op->rest);
  const bool unchanged =
      new_first.same_as(op->first) && new_rest.same_as(op->rest);
  if (unchanged) {
    return s;
  }
  return Block::make(new_first, new_rest);
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let)
......@@ -370,16 +382,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->Mutate(op->first);
Stmt rest = m->Mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) {
......
/*!
* Copyright (c) 2017 by Contributors
* \file lift_allocate.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::ThreadScope;
// Collect storage_scope attributes and Allocate nodes while mutating,
// and re-attach them further out:
// - entries whose storage scope has rank r + 1 are merged below the first
//   enclosing thread_extent attribute whose thread scope has rank r,
// - entries of rank 0 are merged around the whole statement by Lift.
class AllocateLifter : public IRMutator {
 public:
  // Mutate stmt and wrap the result with the collected rank 0 allocations.
  Stmt Lift(Stmt stmt) {
    stmt = this->Mutate(stmt);
    StorageScope key; key.rank = 0;
    stmt = MergeNest(allocs_[key], stmt);
    return stmt;
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    CHECK(op->type_key != attr::virtual_thread)
        << "InjectVirtualThread before LiftStorageAlloc";
    if (op->type_key == attr::storage_scope) {
      const StringImm* scope_str = op->value.as<StringImm>();
      CHECK(scope_str != nullptr)
          << "storage_scope attribute requires a string value";
      StorageScope sc = StorageScope::make(scope_str->value);
      // park the attribute with an empty body, it is merged back later.
      allocs_[sc].emplace_back(
          AttrStmt::make(
              op->node, attr::storage_scope,
              op->value, Evaluate::make(0)));
      // remember the scope so the matching Allocate can be filed under it.
      storage_scope_[op->node.get()] = sc;
      return this->Mutate(op->body);
    } else if (op->type_key == attr::thread_extent) {
      IterVar iv(op->node.node_);
      ThreadScope ts = ThreadScope::make(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = IRMutator::Mutate_(op, s);
      curr_thread_scope_.pop_back();
      op = stmt.as<AttrStmt>();
      // only attach at the first enclosing thread scope of this rank.
      bool first_scope = true;
      for (const ThreadScope& t : curr_thread_scope_) {
        if (t.rank == ts.rank) first_scope = false;
      }
      if (first_scope) {
        StorageScope key;
        key.rank = ts.rank + 1;
        std::vector<Stmt>& vec = allocs_[key];
        if (vec.size() != 0) {
          Stmt body = MergeNest(vec, op->body);
          // the entries are attached now, do not attach them twice.
          vec.clear();
          return AttrStmt::make(
              op->node, op->type_key, op->value, body);
        }
      }
      return stmt;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    auto it = storage_scope_.find(op->buffer_var.get());
    CHECK(it != storage_scope_.end())
        << "Allocate without an enclosing storage_scope attribute";
    // keep new_expr and free_function, so a custom allocation
    // is not silently turned into a default one when it is lifted.
    allocs_[it->second].emplace_back(
        Allocate::make(
            op->buffer_var, op->type, op->extents, op->condition,
            Evaluate::make(0),
            op->new_expr, op->free_function));
    return this->Mutate(op->body);
  }

 private:
  // storage scope of internal allocation.
  std::unordered_map<const Node*, StorageScope> storage_scope_;
  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
  // The allocations by rank
  std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
// Entry point of the pass.
// Lift has to be used instead of Mutate: the mutator removes every
// Allocate from the tree, and only Lift puts the rank 0 allocations back.
Stmt LiftAllocate(Stmt stmt) {
  return AllocateLifter().Lift(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -6,7 +6,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -61,47 +60,18 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Flatten(Stmt stmt) {
stmt = this->Mutate(stmt);
StorageScope key; key.rank = 0;
if (move_alloc_out_) {
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
}
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == "realize_scope") {
if (op->type_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->type_key == "scope") {
} else if (op->type_key == attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope && move_alloc_out_) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->type_key, op->value, body);
}
}
return stmt;
}
}
return IRMutator::Mutate_(op, s);
}
......@@ -140,39 +110,24 @@ class StorageFlattener : public IRMutator {
// deduce current storage scope.
auto it = storage_scope_.find(op->func.get());
CHECK(it != storage_scope_.end());
StorageScope key; key.rank = 0;
const std::string& skey = it->second;
if (skey.length() == 0) {
StorageScope skey;
const std::string& strkey = it->second;
if (strkey.length() == 0) {
if (curr_thread_scope_.size() != 0) {
key.rank = curr_thread_scope_.back().rank + 1;
skey.rank = curr_thread_scope_.back().rank + 1;
}
} else {
key = StorageScope::make(skey);
skey = StorageScope::make(strkey);
}
if (move_alloc_out_) {
allocs_[key].push_back(
AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()),
Evaluate::make(0)));
allocs_[key].push_back(
Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true),
Evaluate::make(0)));
return body;
} else {
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
ret = AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()), ret);
e.buffer->data, attr::storage_scope,
StringImm::make(skey.to_string()), ret);
return ret;
}
}
}
Expr Mutate_(const Call* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
......@@ -217,20 +172,16 @@ class StorageFlattener : public IRMutator {
}
}
};
// whether move allocation to the outmost scope as possible.
bool move_alloc_out_{true};
// The buffer assignment map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer) {
stmt = StorageFlattener(extern_buffer).Flatten(stmt);
stmt = StorageFlattener(extern_buffer).Mutate(stmt);
return stmt;
}
......
......@@ -62,7 +62,11 @@ struct ThreadScope {
*/
static ThreadScope make(const std::string& s) {
ThreadScope r;
if (s.compare(0, 9, "blockIdx.") == 0) {
if (s == "vthread") {
// virtual thread at the same level as local
r.rank = 1;
r.dim_index = -1;
} else if (s.compare(0, 9, "blockIdx.") == 0) {
r.rank = 0;
r.dim_index = static_cast<int>(s[9] - 'x');
} else if (s.compare(0, 10, "threadIdx.") == 0) {
......
......@@ -203,18 +203,27 @@ MakeLoopNest(const Stage& sch,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
value_map[iv] = var;
}
if (!reduce_init_loop) {
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
......
import tvm
def test_virtual_thread():
    """Schedule a two stage compute with a virtual thread, lower it and
    run the InjectVirtualThread pass on the flattened statement."""
    length = tvm.Var('m')
    src = tvm.placeholder((length, ), name='A')
    stage1 = tvm.compute((length,), lambda i: src[i], name='A1')
    stage2 = tvm.compute((length,), lambda i: stage1[i] + 3, name='A2')

    sched = tvm.Schedule(stage2.op)
    # bind the outer part of the axis to a virtual thread of extent 2
    vthread = tvm.IterVar((0, 2), "vx", thread_tag="vthread")
    _, inner = sched[stage2].split(stage2.op.axis[0], outer=vthread)
    outer, _ = sched[stage2].split(inner, 8)
    sched[stage1].compute_at(sched[stage2], outer)

    bounds = tvm.schedule.InferBound(sched)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(sched, bounds)

    # the pass requires flattened storage
    binds = {
        src: tvm.Buffer(src.shape, src.dtype, name='A'),
        stage2: tvm.Buffer(stage2.shape, stage2.dtype, name='A2'),
    }
    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.InjectVirtualThread(stmt)
    print(stmt)


if __name__ == "__main__":
    test_virtual_thread()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment