Commit b8f0ec50 by Tianqi Chen Committed by GitHub

[LANG/PASS] InjectVirtualThread (#38)

parent 526ff04c
......@@ -49,6 +49,30 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr {
/*!
* \brief Mark scope of iteration variable, used by Schedule.
*/
constexpr const char* scope = "scope";
/*!
* \brief Mark launching extent of thread, used by device API.
*/
constexpr const char* thread_extent = "thread_extent";
/*!
* \brief Mark launching of a virtual thread.
*/
constexpr const char* virtual_thread = "virtual_thread";
/*!
* \brief Mark storage scope of buffers
*/
constexpr const char* storage_scope = "storage_scope";
/*!
* \brief Mark storage scope of realizations
*/
constexpr const char* realize_scope = "realize_scope";
} // namespace attr
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
......
......@@ -63,6 +63,7 @@ class IRMutator {
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
......
......@@ -100,6 +100,7 @@ Stmt Inline(Stmt stmt,
* \param stmt The stmt to be trasnformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
......@@ -108,16 +109,35 @@ Stmt StorageFlatten(Stmt stmt,
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Lift storage allocation to relevant outpost location
*
* Only do this after vectorization and virtual thread injection completes.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt LiftAllocate(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
......@@ -70,6 +70,8 @@ def build(sch,
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
......
......@@ -67,6 +67,8 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread);
} // namespace ir
} // namespace tvm
......@@ -288,7 +288,8 @@ class Canonical::Internal : public IRMutator {
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->type_key == "thread_extent") {
if (op->type_key == attr::thread_extent ||
op->type_key == attr::virtual_thread) {
++level_counter_;
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
......
......@@ -743,7 +743,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
}
void CodeGenC::PrintStmt(const AttrStmt* op) {
if (op->type_key == "scope") {
if (op->type_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
......@@ -756,7 +756,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
stream << ";\n";
}
}
} else if (op->type_key == "storage_scope") {
} else if (op->type_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
......
......@@ -9,6 +9,7 @@
#include <string>
#include "./codegen_cuda.h"
#include "./codegen_stack_vm.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
......@@ -22,6 +23,17 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa);
}
void CodeGenCUDA::PrintStmt(const ir::For* op) {
int ext;
CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) &&
ext <= max_auto_unroll_) {
PrintIndent();
stream << "#pragma unroll\n";
}
CodeGenC::PrintStmt(op);
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
......
......@@ -27,6 +27,7 @@ class CodeGenCUDA : public CodeGenC {
bool output_ssa);
// override behavior
void PrintStmt(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
......@@ -37,6 +38,11 @@ class CodeGenCUDA : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{8};
};
} // namespace codegen
......
......@@ -77,6 +77,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Free);
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
......@@ -212,6 +213,17 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt first = this->Mutate(op->first);
Stmt rest = this->Mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let)
......@@ -370,16 +382,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->Mutate(op->first);
Stmt rest = m->Mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) {
......
/*!
* Copyright (c) 2017 by Contributors
* \file lift_allocate.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::ThreadScope;
class AllocateLifter : public IRMutator {
public:
Stmt Lift(Stmt stmt) {
stmt = this->Mutate(stmt);
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->type_key != attr::virtual_thread)
<< "InjectVirtualThread before LiftStorageAlloc";
if (op->type_key == attr::storage_scope) {
StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
allocs_[sc].emplace_back(
AttrStmt::make(
op->node, attr::storage_scope,
op->value, Evaluate::make(0)));
storage_scope_[op->node.get()] = sc;
return this->Mutate(op->body);
} else if (op->type_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->type_key, op->value, body);
}
}
return stmt;
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const For* op, const Stmt& s) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = storage_scope_.find(op->buffer_var.get());
CHECK(it != storage_scope_.end());
allocs_[it->second].emplace_back(
Allocate::make(
op->buffer_var, op->type, op->extents, op->condition,
Evaluate::make(0)));
return this->Mutate(op->body);
}
private:
// storage scope of internal allocation.
std::unordered_map<const Node*, StorageScope> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt LiftAllocate(Stmt stmt) {
return AllocateLifter().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -6,7 +6,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -61,46 +60,17 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Flatten(Stmt stmt) {
stmt = this->Mutate(stmt);
StorageScope key; key.rank = 0;
if (move_alloc_out_) {
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
}
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == "realize_scope") {
if (op->type_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->type_key == "scope") {
} else if (op->type_key == attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope && move_alloc_out_) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->type_key, op->value, body);
}
}
return stmt;
}
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
return stmt;
}
return IRMutator::Mutate_(op, s);
}
......@@ -140,37 +110,22 @@ class StorageFlattener : public IRMutator {
// deduce current storage scope.
auto it = storage_scope_.find(op->func.get());
CHECK(it != storage_scope_.end());
StorageScope key; key.rank = 0;
const std::string& skey = it->second;
if (skey.length() == 0) {
StorageScope skey;
const std::string& strkey = it->second;
if (strkey.length() == 0) {
if (curr_thread_scope_.size() != 0) {
key.rank = curr_thread_scope_.back().rank + 1;
skey.rank = curr_thread_scope_.back().rank + 1;
}
} else {
key = StorageScope::make(skey);
}
if (move_alloc_out_) {
allocs_[key].push_back(
AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()),
Evaluate::make(0)));
allocs_[key].push_back(
Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true),
Evaluate::make(0)));
return body;
} else {
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
ret = AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()), ret);
return ret;
skey = StorageScope::make(strkey);
}
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
ret = AttrStmt::make(
e.buffer->data, attr::storage_scope,
StringImm::make(skey.to_string()), ret);
return ret;
}
}
......@@ -217,20 +172,16 @@ class StorageFlattener : public IRMutator {
}
}
};
// whether move allocation to the outmost scope as possible.
bool move_alloc_out_{true};
// The buffer assignment map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer) {
stmt = StorageFlattener(extern_buffer).Flatten(stmt);
stmt = StorageFlattener(extern_buffer).Mutate(stmt);
return stmt;
}
......
......@@ -62,7 +62,11 @@ struct ThreadScope {
*/
static ThreadScope make(const std::string& s) {
ThreadScope r;
if (s.compare(0, 9, "blockIdx.") == 0) {
if (s == "vthread") {
// virtual thread at the same level as local
r.rank = 1;
r.dim_index = -1;
} else if (s.compare(0, 9, "blockIdx.") == 0) {
r.rank = 0;
r.dim_index = static_cast<int>(s[9] - 'x');
} else if (s.compare(0, 10, "threadIdx.") == 0) {
......
......@@ -203,18 +203,27 @@ MakeLoopNest(const Stage& sch,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
value_map[iv] = var;
}
if (!reduce_init_loop) {
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
......
import tvm
def test_virtual_thread():
m = tvm.Var('m')
A = tvm.placeholder((m, ), name='A')
A1 = tvm.compute((m,), lambda i: A[i], name='A1')
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
s = tvm.Schedule(A2.op)
vx = tvm.IterVar((0, 2), "vx", thread_tag="vthread")
xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.InjectVirtualThread(stmt)
print(stmt)
if __name__ == "__main__":
test_virtual_thread()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment