Commit ba7b9ddd by Tianqi Chen Committed by GitHub

[PASS] Enable StorageRewrite before virtual thread lowering (#880)

* [PASS] Enable StorageRewrite before virtual thread lowering

* update

* fix testcase
parent 0a410a39
......@@ -154,6 +154,8 @@ class LinearAccessPatternFinder final : public IRVisitor {
in_thread_env_ = false;
} else if (op->attr_key == attr::extern_scope) {
VisitNewScope(op);
} else if (op->attr_key == attr::virtual_thread) {
VisitNewScope(op);
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
alloc_info_[buf].storage_scope =
......@@ -395,11 +397,10 @@ class StoragePlanRewriter : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before StoragePlan";
if (op->attr_key == attr::storage_scope) {
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread ||
op->attr_key == attr::pragma_scope) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
......@@ -481,11 +482,13 @@ class StoragePlanRewriter : public IRMutator {
Stmt body) {
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
if (e->new_alloc.defined()) {
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
}
}
return MergeNest(nest, body);
}
......@@ -716,7 +719,8 @@ class StoragePlanRewriter : public IRMutator {
if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt);
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope) {
op->attr_key == attr::pragma_scope ||
op->attr_key == attr::virtual_thread) {
PlanNewScope(op);
} else {
CHECK(op->attr_key == attr::extern_scope);
......
......@@ -77,7 +77,22 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
foo(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
if type == "argmax" or type == "argmin":
out_tvm_indices = out_tvm.asnumpy()
if keepdims:
out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
if axis is None:
out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
else:
other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
out_tvm_val = in_npy_map[sel_indices]
if type == "argmax":
np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
elif type == "argmin":
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
else:
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment