Commit 5072efae by Tianqi Chen Committed by GitHub

[PASS] Improve vthread injection. (#411)

parent b0d9f299
......@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else";
*/
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*!
* \brief Return a unique context id, used for hint of workspace separation.
* Different context id ganrantees not having overlapping workspace.
*/
constexpr const char* tvm_context_id = "tvm_context_id";
/*!
* \brief tvm_tuple is not an actual function and cannot codegen.
* It is used to represent tuple structure in value field of AttrStmt,
* for the sake of giving hint to optimization.
......
......@@ -106,7 +106,8 @@ MakeLoopNest(const Stage& stage,
it_attr->prefetch_offset[j], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread") {
} else if (bind_iv->thread_tag == "vthread" ||
bind_iv->thread_tag == "cthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
......
......@@ -69,7 +69,7 @@ class DoubleBufferInjector : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.size = arith::ComputeReduce<Mul>(op->extents);
it->second.stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
......@@ -126,10 +126,10 @@ class DoubleBufferInjector : public IRMutator {
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.size.defined());
CHECK(e.stride.defined());
return Store::make(op->buffer_var,
op->value,
e.switch_write_var * e.size + op->index,
e.switch_write_var * e.stride + op->index,
op->predicate);
} else {
return stmt;
......@@ -142,11 +142,11 @@ class DoubleBufferInjector : public IRMutator {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.size.defined());
CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined());
return Load::make(op->type,
op->buffer_var,
e.switch_read_var * e.size + op->index,
e.switch_read_var * e.stride + op->index,
op->predicate);
} else {
return expr;
......@@ -194,7 +194,7 @@ class DoubleBufferInjector : public IRMutator {
// Storage entry for those who need double buffering.
struct StorageEntry {
// The size of the buffer
Expr size;
Expr stride;
// The loop we need
const For* loop{nullptr};
// The switch variable.
......
......@@ -130,22 +130,29 @@ class VTInjector : public IRMutator {
// constructor
VTInjector(Var var,
int num_threads,
std::unordered_set<const Variable*> touched_var)
: var_(var), num_threads_(num_threads), touched_var_(touched_var) {
const std::unordered_set<const Variable*>& touched_var,
bool allow_share)
: var_(var), num_threads_(num_threads),
touched_var_(touched_var), allow_share_(allow_share) {
}
// Inject VTLoop when needed.
Stmt Mutate(Stmt stmt) final {
CHECK(!visit_touched_var_)
<< stmt->type_key() << stmt;
stmt = IRMutator::Mutate(stmt);
if (visit_touched_var_) {
if (!vt_loop_injected_) return InjectVTLoop(stmt, false);
if (visit_touched_var_ || trigger_base_inject_) {
if (!vt_loop_injected_) {
return InjectVTLoop(stmt, false);
}
visit_touched_var_ = false;
trigger_base_inject_ = false;
}
return stmt;
}
// Variable
Expr Mutate_(const Variable *op, const Expr& e) final {
CHECK(!alloc_remap_.count(op))
<< "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
visit_touched_var_ = true;
}
......@@ -161,8 +168,8 @@ class VTInjector : public IRMutator {
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
}
auto it = touched_alloc_.find(op->buffer_var.get());
if (it != touched_alloc_.end()) {
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
return Load::make(op->type, op->buffer_var,
RewriteIndex(op->index, it->second),
op->predicate);
......@@ -170,6 +177,34 @@ class VTInjector : public IRMutator {
return expr;
}
}
// Expression.
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
visit_touched_var_ = true;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
Expr stride = arith::ComputeExpr<Div>(
it->second, make_const(offset.type(), dtype.lanes()));
offset = stride * var_ + offset;
return Call::make(
op->type, op->name,
{op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return allow_share_ ? e : var_;
} else {
return IRMutator::Mutate_(op, e);
}
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
trigger_base_inject_ = !allow_share_;
return IRMutator::Mutate_(op, s);
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
......@@ -177,8 +212,9 @@ class VTInjector : public IRMutator {
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
}
auto it = touched_alloc_.find(op->buffer_var.get());
if (it != touched_alloc_.end()) {
trigger_base_inject_ = !allow_share_;
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
return Store::make(op->buffer_var,
op->value,
RewriteIndex(op->index, it->second),
......@@ -190,7 +226,10 @@ class VTInjector : public IRMutator {
// Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Expr value = Mutate(op->value);
if (visit_touched_var_) {
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
} else if (!allow_share_ && !vt_loop_injected_ &&
op->attr_key == attr::coproc_uop_scope) {
return InjectVTLoop(s, true);
} else {
Stmt body = Mutate(op->body);
......@@ -299,24 +338,19 @@ class VTInjector : public IRMutator {
visit_touched_var_ = false;
Stmt body;
if (touched_var_.count(op->buffer_var.get())) {
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
Expr stride = extents[0];
for (size_t i = 1; i < extents.size(); ++i) {
stride = arith::ComputeExpr<Mul>(stride, extents[i]);
}
if (op->type.lanes() != 0) {
stride = stride * op->type.lanes();
}
Expr stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
Array<Expr> other;
other.push_back(num_threads_);
other.push_back(make_const(op->extents[0].type(), num_threads_));
for (Expr e : extents) {
other.push_back(e);
}
extents = other;
changed = true;
// mark this buffer get touched.
touched_alloc_[op->buffer_var.get()] = stride;
alloc_remap_[op->buffer_var.get()] = stride;
// Mutate the body.
body = Mutate(op->body);
} else {
......@@ -340,6 +374,7 @@ class VTInjector : public IRMutator {
CHECK(!vt_loop_injected_);
// reset the flags
visit_touched_var_ = false;
trigger_base_inject_ = false;
vt_loop_injected_ = true;
if (before_mutation) {
stmt = this->Mutate(stmt);
......@@ -359,7 +394,8 @@ class VTInjector : public IRMutator {
// insert a for loop
Var idx(var_->name_hint + ".s", var_->type);
stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, num_threads_,
return For::make(idx, make_zero(idx.type()),
make_const(idx.type(), num_threads_),
ForType::Serial, DeviceAPI::None, stmt);
}
}
......@@ -373,12 +409,16 @@ class VTInjector : public IRMutator {
bool vt_loop_injected_{false};
// whether current expression get touched.
bool visit_touched_var_{false};
// Trigger base stmt
bool trigger_base_inject_{false};
// the counter of loops in after mutation.
int max_loop_depth_{0};
// The variables that get touched.
std::unordered_set<const Variable*> touched_var_;
const std::unordered_set<const Variable*>& touched_var_;
// Whether allow shareding.
bool allow_share_;
// The allocations that get touched -> extent
std::unordered_map<const Variable*, Expr> touched_alloc_;
std::unordered_map<const Variable*, Expr> alloc_remap_;
};
......@@ -389,10 +429,11 @@ class VirtualThreadInjector : public IRMutator {
op = stmt.as<AttrStmt>();
if (op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
bool allow_share = iv->thread_tag == "vthread";
int nthread = static_cast<int>(op->value.as<IntImm>()->value);
VarTouchedAnalysis vs;
auto touched = vs.TouchedVar(op->body, iv->var.get());
VTInjector injecter(iv->var, nthread, touched);
VTInjector injecter(iv->var, nthread, touched, allow_share);
return injecter.Mutate(op->body);
} else {
return stmt;
......
......@@ -140,6 +140,8 @@ class BuiltinLower : public IRMutator {
return MakeShape(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return make_zero(op->type);
} else {
return IRMutator::Mutate_(op, e);
}
......
import tvm
def test_vthread():
dtype = 'int64'
n = 100
m = 4
nthread = 2
def get_vthread(name):
tx = tvm.thread_axis(name)
ty = tvm.thread_axis(name)
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
with ib.for_range(0, n) as i:
ib.scope_attr(tx, "virtual_thread", nthread)
ib.scope_attr(ty, "virtual_thread", nthread)
B = ib.allocate("float32", m, name="B", scope="shared")
B[i] = A[i * nthread + tx]
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
ib.emit(tvm.call_extern("int32", "Run",
bbuffer.access_ptr("r"),
tvm.call_pure_intrin("int32", "tvm_context_id")))
C[i * nthread + tx] = B[i] + 1
return ib.get()
stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
assert stmt.body.body.extents[0].value == 2
stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
assert len(stmt.body.body.extents) == 3
if __name__ == "__main__":
test_vthread()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment