Commit 5072efae by Tianqi Chen Committed by GitHub

[PASS] Improve vthread injection. (#411)

parent b0d9f299
...@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else"; ...@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else";
*/ */
constexpr const char* tvm_access_ptr = "tvm_access_ptr"; constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*! /*!
* \brief Return a unique context id, used for hint of workspace separation.
* Different context ids guarantee non-overlapping workspaces.
*/
constexpr const char* tvm_context_id = "tvm_context_id";
/*!
* \brief tvm_tuple is not an actual function and cannot codegen. * \brief tvm_tuple is not an actual function and cannot codegen.
* It is used to represent tuple structure in value field of AttrStmt, * It is used to represent tuple structure in value field of AttrStmt,
* for the sake of giving hint to optimization. * for the sake of giving hint to optimization.
......
...@@ -106,7 +106,8 @@ MakeLoopNest(const Stage& stage, ...@@ -106,7 +106,8 @@ MakeLoopNest(const Stage& stage,
it_attr->prefetch_offset[j], no_op)); it_attr->prefetch_offset[j], no_op));
} }
} }
} else if (bind_iv->thread_tag == "vthread") { } else if (bind_iv->thread_tag == "vthread" ||
bind_iv->thread_tag == "cthread") {
// virtual thread // virtual thread
// Always restrict threaded IterVar to starts from 0. // Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min)); CHECK(is_zero(dom->min));
......
...@@ -69,7 +69,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -69,7 +69,7 @@ class DoubleBufferInjector : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
it->second.size = arith::ComputeReduce<Mul>(op->extents); it->second.stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)}; Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
...@@ -126,10 +126,10 @@ class DoubleBufferInjector : public IRMutator { ...@@ -126,10 +126,10 @@ class DoubleBufferInjector : public IRMutator {
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second; const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_); CHECK(in_double_buffer_scope_);
CHECK(e.size.defined()); CHECK(e.stride.defined());
return Store::make(op->buffer_var, return Store::make(op->buffer_var,
op->value, op->value,
e.switch_write_var * e.size + op->index, e.switch_write_var * e.stride + op->index,
op->predicate); op->predicate);
} else { } else {
return stmt; return stmt;
...@@ -142,11 +142,11 @@ class DoubleBufferInjector : public IRMutator { ...@@ -142,11 +142,11 @@ class DoubleBufferInjector : public IRMutator {
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second; const StorageEntry& e = it->second;
CHECK(e.size.defined()); CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined()); CHECK(e.switch_read_var.defined());
return Load::make(op->type, return Load::make(op->type,
op->buffer_var, op->buffer_var,
e.switch_read_var * e.size + op->index, e.switch_read_var * e.stride + op->index,
op->predicate); op->predicate);
} else { } else {
return expr; return expr;
...@@ -194,7 +194,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -194,7 +194,7 @@ class DoubleBufferInjector : public IRMutator {
// Storage entry for those who need double buffering. // Storage entry for those who need double buffering.
struct StorageEntry { struct StorageEntry {
// The size of the buffer // The size of the buffer
Expr size; Expr stride;
// The loop we need // The loop we need
const For* loop{nullptr}; const For* loop{nullptr};
// The switch variable. // The switch variable.
......
...@@ -130,22 +130,29 @@ class VTInjector : public IRMutator { ...@@ -130,22 +130,29 @@ class VTInjector : public IRMutator {
// constructor // constructor
VTInjector(Var var, VTInjector(Var var,
int num_threads, int num_threads,
std::unordered_set<const Variable*> touched_var) const std::unordered_set<const Variable*>& touched_var,
: var_(var), num_threads_(num_threads), touched_var_(touched_var) { bool allow_share)
: var_(var), num_threads_(num_threads),
touched_var_(touched_var), allow_share_(allow_share) {
} }
// Inject VTLoop when needed. // Inject VTLoop when needed.
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(!visit_touched_var_) CHECK(!visit_touched_var_)
<< stmt->type_key() << stmt; << stmt->type_key() << stmt;
stmt = IRMutator::Mutate(stmt); stmt = IRMutator::Mutate(stmt);
if (visit_touched_var_) { if (visit_touched_var_ || trigger_base_inject_) {
if (!vt_loop_injected_) return InjectVTLoop(stmt, false); if (!vt_loop_injected_) {
return InjectVTLoop(stmt, false);
}
visit_touched_var_ = false; visit_touched_var_ = false;
trigger_base_inject_ = false;
} }
return stmt; return stmt;
} }
// Variable // Variable
Expr Mutate_(const Variable *op, const Expr& e) final { Expr Mutate_(const Variable *op, const Expr& e) final {
CHECK(!alloc_remap_.count(op))
<< "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) { if (touched_var_.count(op)) {
visit_touched_var_ = true; visit_touched_var_ = true;
} }
...@@ -161,8 +168,8 @@ class VTInjector : public IRMutator { ...@@ -161,8 +168,8 @@ class VTInjector : public IRMutator {
if (touched_var_.count(op->buffer_var.get())) { if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true; visit_touched_var_ = true;
} }
auto it = touched_alloc_.find(op->buffer_var.get()); auto it = alloc_remap_.find(op->buffer_var.get());
if (it != touched_alloc_.end()) { if (it != alloc_remap_.end()) {
return Load::make(op->type, op->buffer_var, return Load::make(op->type, op->buffer_var,
RewriteIndex(op->index, it->second), RewriteIndex(op->index, it->second),
op->predicate); op->predicate);
...@@ -170,6 +177,34 @@ class VTInjector : public IRMutator { ...@@ -170,6 +177,34 @@ class VTInjector : public IRMutator {
return expr; return expr;
} }
} }
// Expression.
/*!
 * \brief Rewrite intrinsic calls that depend on the virtual thread.
 *
 *  - tvm_context_id: returned unchanged when sharing is allowed,
 *    otherwise replaced by the virtual thread variable.
 *  - tvm_access_ptr on a buffer recorded in alloc_remap_: the offset
 *    argument is shifted by (recorded stride / lanes) * var_.
 *  - Any other call, or tvm_access_ptr on a buffer that is not in
 *    alloc_remap_, is forwarded to the default mutator.
 *
 * \param op The call node being visited.
 * \param e The expression that wraps op.
 * \return The rewritten expression.
 */
Expr Mutate_(const Call* op, const Expr& e) final {
  if (op->is_intrinsic(intrinsic::tvm_context_id)) {
    if (allow_share_) return e;
    return var_;
  }
  if (!op->is_intrinsic(intrinsic::tvm_access_ptr)) {
    return IRMutator::Mutate_(op, e);
  }
  // tvm_access_ptr is expected to carry exactly five arguments; the ones
  // used here are [0] (its type gives the lanes), [1] buffer variable,
  // [2] offset and [3] extent. Argument [4] is passed through untouched.
  CHECK_EQ(op->args.size(), 5U);
  Type elem_type = op->args[0].type();
  const Variable* buffer_var = op->args[1].as<Variable>();
  auto remap = alloc_remap_.find(buffer_var);
  if (remap == alloc_remap_.end()) {
    return IRMutator::Mutate_(op, e);
  }
  // The pointer now depends on the virtual thread index.
  visit_touched_var_ = true;
  Expr base_offset = Mutate(op->args[2]);
  Expr extent = Mutate(op->args[3]);
  // The recorded stride is divided by the lanes of the access type
  // before being scaled by the thread variable.
  Expr thread_stride = arith::ComputeExpr<Div>(
      remap->second, make_const(base_offset.type(), elem_type.lanes()));
  Expr shifted_offset = thread_stride * var_ + base_offset;
  return Call::make(
      op->type, op->name,
      {op->args[0], op->args[1], shifted_offset, extent, op->args[4]},
      op->call_type);
}
/*!
 * \brief Visit an Evaluate statement.
 *
 * Sets trigger_base_inject_ to the negation of allow_share_ before
 * delegating to the default mutator, so that when sharing is not
 * allowed the statement requests virtual thread loop injection.
 *
 * \param op The Evaluate node being visited.
 * \param s The statement that wraps op.
 * \return The mutated statement.
 */
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
  const bool sharing_disallowed = !allow_share_;
  trigger_base_inject_ = sharing_disallowed;
  return IRMutator::Mutate_(op, s);
}
// Store // Store
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
...@@ -177,8 +212,9 @@ class VTInjector : public IRMutator { ...@@ -177,8 +212,9 @@ class VTInjector : public IRMutator {
if (touched_var_.count(op->buffer_var.get())) { if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true; visit_touched_var_ = true;
} }
auto it = touched_alloc_.find(op->buffer_var.get()); trigger_base_inject_ = !allow_share_;
if (it != touched_alloc_.end()) { auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
return Store::make(op->buffer_var, return Store::make(op->buffer_var,
op->value, op->value,
RewriteIndex(op->index, it->second), RewriteIndex(op->index, it->second),
...@@ -190,7 +226,10 @@ class VTInjector : public IRMutator { ...@@ -190,7 +226,10 @@ class VTInjector : public IRMutator {
// Attribute // Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Expr value = Mutate(op->value); Expr value = Mutate(op->value);
if (visit_touched_var_) { if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
} else if (!allow_share_ && !vt_loop_injected_ &&
op->attr_key == attr::coproc_uop_scope) {
return InjectVTLoop(s, true); return InjectVTLoop(s, true);
} else { } else {
Stmt body = Mutate(op->body); Stmt body = Mutate(op->body);
...@@ -299,24 +338,19 @@ class VTInjector : public IRMutator { ...@@ -299,24 +338,19 @@ class VTInjector : public IRMutator {
visit_touched_var_ = false; visit_touched_var_ = false;
Stmt body; Stmt body;
if (touched_var_.count(op->buffer_var.get())) { // always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension. // place v on highest dimension.
Expr stride = extents[0]; Expr stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
for (size_t i = 1; i < extents.size(); ++i) {
stride = arith::ComputeExpr<Mul>(stride, extents[i]);
}
if (op->type.lanes() != 0) {
stride = stride * op->type.lanes();
}
Array<Expr> other; Array<Expr> other;
other.push_back(num_threads_); other.push_back(make_const(op->extents[0].type(), num_threads_));
for (Expr e : extents) { for (Expr e : extents) {
other.push_back(e); other.push_back(e);
} }
extents = other; extents = other;
changed = true; changed = true;
// mark this buffer get touched. // mark this buffer get touched.
touched_alloc_[op->buffer_var.get()] = stride; alloc_remap_[op->buffer_var.get()] = stride;
// Mutate the body. // Mutate the body.
body = Mutate(op->body); body = Mutate(op->body);
} else { } else {
...@@ -340,6 +374,7 @@ class VTInjector : public IRMutator { ...@@ -340,6 +374,7 @@ class VTInjector : public IRMutator {
CHECK(!vt_loop_injected_); CHECK(!vt_loop_injected_);
// reset the flags // reset the flags
visit_touched_var_ = false; visit_touched_var_ = false;
trigger_base_inject_ = false;
vt_loop_injected_ = true; vt_loop_injected_ = true;
if (before_mutation) { if (before_mutation) {
stmt = this->Mutate(stmt); stmt = this->Mutate(stmt);
...@@ -359,7 +394,8 @@ class VTInjector : public IRMutator { ...@@ -359,7 +394,8 @@ class VTInjector : public IRMutator {
// insert a for loop // insert a for loop
Var idx(var_->name_hint + ".s", var_->type); Var idx(var_->name_hint + ".s", var_->type);
stmt = Substitute(stmt, {{var_, idx}}); stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, num_threads_, return For::make(idx, make_zero(idx.type()),
make_const(idx.type(), num_threads_),
ForType::Serial, DeviceAPI::None, stmt); ForType::Serial, DeviceAPI::None, stmt);
} }
} }
...@@ -373,12 +409,16 @@ class VTInjector : public IRMutator { ...@@ -373,12 +409,16 @@ class VTInjector : public IRMutator {
bool vt_loop_injected_{false}; bool vt_loop_injected_{false};
// whether current expression get touched. // whether current expression get touched.
bool visit_touched_var_{false}; bool visit_touched_var_{false};
// Whether the base statement should trigger virtual thread loop injection.
bool trigger_base_inject_{false};
// the counter of loops in after mutation. // the counter of loops in after mutation.
int max_loop_depth_{0}; int max_loop_depth_{0};
// The variables that get touched. // The variables that get touched.
std::unordered_set<const Variable*> touched_var_; const std::unordered_set<const Variable*>& touched_var_;
// Whether to allow sharing.
bool allow_share_;
// The allocations that get touched -> extent // The allocations that get touched -> extent
std::unordered_map<const Variable*, Expr> touched_alloc_; std::unordered_map<const Variable*, Expr> alloc_remap_;
}; };
...@@ -389,10 +429,11 @@ class VirtualThreadInjector : public IRMutator { ...@@ -389,10 +429,11 @@ class VirtualThreadInjector : public IRMutator {
op = stmt.as<AttrStmt>(); op = stmt.as<AttrStmt>();
if (op->attr_key == attr::virtual_thread) { if (op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
bool allow_share = iv->thread_tag == "vthread";
int nthread = static_cast<int>(op->value.as<IntImm>()->value); int nthread = static_cast<int>(op->value.as<IntImm>()->value);
VarTouchedAnalysis vs; VarTouchedAnalysis vs;
auto touched = vs.TouchedVar(op->body, iv->var.get()); auto touched = vs.TouchedVar(op->body, iv->var.get());
VTInjector injecter(iv->var, nthread, touched); VTInjector injecter(iv->var, nthread, touched, allow_share);
return injecter.Mutate(op->body); return injecter.Mutate(op->body);
} else { } else {
return stmt; return stmt;
......
...@@ -140,6 +140,8 @@ class BuiltinLower : public IRMutator { ...@@ -140,6 +140,8 @@ class BuiltinLower : public IRMutator {
return MakeShape(op, e); return MakeShape(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e); return MakeArray(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return make_zero(op->type);
} else { } else {
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
} }
......
import tvm
def test_vthread():
    """Check InjectVirtualThread for both the "vthread" and "cthread" tags.

    Builds a small IR with two nested virtual threads over a shared
    allocation ``B``, an extern call that takes an access pointer to ``B``
    and a ``tvm_context_id`` intrinsic, then runs the pass for each tag and
    checks the extents of the resulting allocation.
    """
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        """Build the test statement using thread axes tagged ``name``."""
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        ib = tvm.ir_builder.create()
        A = ib.pointer("float32", name="A")
        C = ib.pointer("float32", name="C")
        with ib.for_range(0, n) as i:
            ib.scope_attr(tx, "virtual_thread", nthread)
            ib.scope_attr(ty, "virtual_thread", nthread)
            B = ib.allocate("float32", m, name="B", scope="shared")
            B[i] = A[i * nthread + tx]
            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
            ib.emit(tvm.call_extern("int32", "Run",
                                    bbuffer.access_ptr("r"),
                                    tvm.call_pure_intrin("int32", "tvm_context_id")))
            C[i * nthread + tx] = B[i] + 1
        return ib.get()

    # "vthread": the first extent of the rewritten allocation is 2.
    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    assert stmt.body.body.extents[0].value == 2
    # "cthread": the rewritten allocation has three extents.
    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
    assert len(stmt.body.body.extents) == 3
# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
    test_vthread()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment