Commit 4d2fc952 by Tianqi Chen Committed by GitHub

[PASS] Fix vthread when extern access touching (#636)

parent b07ceff5
......@@ -15,11 +15,12 @@ namespace ir {
// If expression is touched by var.
class ExprTouched final : public IRVisitor {
public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
: touched_var_(touched) {}
explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
void Visit(const NodeRef& n) final {
// early stopping
if (expr_touched_) return;
if (expr_touched_ && !check_write_) return;
IRVisitor::Visit(n);
}
void Visit_(const Load *op) final {
......@@ -29,6 +30,24 @@ class ExprTouched final : public IRVisitor {
void Visit_(const Variable *op) final {
HandleUseVar(op);
}
void Visit_(const Call *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask;
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
const Variable* buffer_var = op->args[1].as<Variable>();
CHECK(buffer_var);
// read
if (rw_mask & 1) {
HandleUseVar(buffer_var);
}
if (rw_mask & 2) {
HandleWriteVar(buffer_var);
}
this->Visit(op->args[2]);
} else {
IRVisitor::Visit_(op);
}
}
void HandleUseVar(const Variable* var) {
auto it = touched_var_.find(var);
if (it != touched_var_.end()) {
......@@ -40,36 +59,49 @@ class ExprTouched final : public IRVisitor {
used_vars_.push_back(var);
}
}
void HandleWriteVar(const Variable* var) {
write_vars_.push_back(var);
}
// the fields.
bool expr_touched_{false};
std::vector<const Variable*> used_vars_;
std::vector<const Variable*> write_vars_;
const std::unordered_set<const Variable*>& touched_var_;
bool check_write_;
};
// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public IRVisitor {
public:
void Visit_(const LetStmt *op) {
ExprTouched tc(touched_var_);
ExprTouched tc(touched_var_, false);
tc.Visit(op->value);
Record(op->var.get(), tc);
this->Visit(op->body);
}
void Visit_(const Store *op) {
ExprTouched tc(touched_var_);
ExprTouched tc(touched_var_, false);
tc.Visit(op->value);
tc.Visit(op->index);
Record(op->buffer_var.get(), tc);
}
void Visit_(const For *op) {
ExprTouched tc(touched_var_);
ExprTouched tc(touched_var_, false);
tc.Visit(op->min);
tc.Visit(op->extent);
Record(op->loop_var.get(), tc);
this->Visit(op->body);
}
// external function call
void Visit_(const Evaluate *op) {
ExprTouched tc(touched_var_, true);
tc.Visit(op->value);
for (const Variable* var : tc.write_vars_) {
Record(var, tc);
}
}
void Visit_(const Allocate *op) {
ExprTouched tc(touched_var_);
ExprTouched tc(touched_var_, false);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc.Visit(op->extents[i]);
}
......@@ -87,7 +119,9 @@ class VarTouchedAnalysis : public IRVisitor {
touched_var_.insert(var);
} else {
for (const Variable* r : tc.used_vars_) {
affect_[r].push_back(var);
if (r != var) {
affect_[r].push_back(var);
}
}
}
}
......
......@@ -28,5 +28,39 @@ def test_vthread():
stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
assert len(stmt.body.body.extents) == 3
def test_vthread_extern():
dtype = 'int64'
n = 100
m = 4
nthread = 2
def get_vthread(name):
tx = tvm.thread_axis(name)
ty = tvm.thread_axis(name)
ib = tvm.ir_builder.create()
with ib.for_range(0, n) as i:
ib.scope_attr(tx, "virtual_thread", nthread)
ib.scope_attr(ty, "virtual_thread", nthread)
A = ib.allocate("float32", m, name="A", scope="shared")
B = ib.allocate("float32", m, name="B", scope="shared")
C = ib.allocate("float32", m, name="C", scope="shared")
cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
A[tx] = tx + 1.0
B[ty] = ty + 1.0
ib.emit(tvm.call_extern("int32", "Run",
abuffer.access_ptr("r"),
bbuffer.access_ptr("r"),
cbuffer.access_ptr("rw")))
return ib.get()
stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
assert stmt.body.body.extents[0].value == 2
assert stmt.body.body.body.body.body.body.extents[0].value == 2
assert len(stmt.body.body.body.body.body.body.extents) == 3
if __name__ == "__main__":
test_vthread_extern()
test_vthread()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment