Commit ba7b9ddd by Tianqi Chen Committed by GitHub

[PASS] Enable StorageRewrite before virtual thread lowering (#880)

* [PASS] Enable StorageRewrite before virtual thread lowering

* update

* fix testcase
parent 0a410a39
...@@ -154,6 +154,8 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -154,6 +154,8 @@ class LinearAccessPatternFinder final : public IRVisitor {
in_thread_env_ = false; in_thread_env_ = false;
} else if (op->attr_key == attr::extern_scope) { } else if (op->attr_key == attr::extern_scope) {
VisitNewScope(op); VisitNewScope(op);
} else if (op->attr_key == attr::virtual_thread) {
VisitNewScope(op);
} else if (op->attr_key == attr::storage_scope) { } else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
alloc_info_[buf].storage_scope = alloc_info_[buf].storage_scope =
...@@ -395,11 +397,10 @@ class StoragePlanRewriter : public IRMutator { ...@@ -395,11 +397,10 @@ class StoragePlanRewriter : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before StoragePlan";
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent || } else if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread ||
op->attr_key == attr::pragma_scope) { op->attr_key == attr::pragma_scope) {
// remake all the allocation at the attach scope. // remake all the allocation at the attach scope.
if (attach_map_.count(op)) { if (attach_map_.count(op)) {
...@@ -481,12 +482,14 @@ class StoragePlanRewriter : public IRMutator { ...@@ -481,12 +482,14 @@ class StoragePlanRewriter : public IRMutator {
Stmt body) { Stmt body) {
std::vector<Stmt> nest; std::vector<Stmt> nest;
for (StorageEntry* e : svec) { for (StorageEntry* e : svec) {
if (e->new_alloc.defined()) {
nest.emplace_back(AttrStmt::make( nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope, e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()), StringImm::make(e->scope.to_string()),
Evaluate::make(0))); Evaluate::make(0)));
nest.push_back(e->new_alloc); nest.push_back(e->new_alloc);
} }
}
return MergeNest(nest, body); return MergeNest(nest, body);
} }
// Remap the index // Remap the index
...@@ -716,7 +719,8 @@ class StoragePlanRewriter : public IRMutator { ...@@ -716,7 +719,8 @@ class StoragePlanRewriter : public IRMutator {
if (s.stmt->is_type<AttrStmt>()) { if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt); const auto* op = static_cast<const AttrStmt*>(s.stmt);
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope) { op->attr_key == attr::pragma_scope ||
op->attr_key == attr::virtual_thread) {
PlanNewScope(op); PlanNewScope(op);
} else { } else {
CHECK(op->attr_key == attr::extern_scope); CHECK(op->attr_key == attr::extern_scope);
......
...@@ -77,6 +77,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -77,6 +77,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype) out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1): for _ in range(1):
foo(data_tvm, out_tvm) foo(data_tvm, out_tvm)
if type == "argmax" or type == "argmin":
out_tvm_indices = out_tvm.asnumpy()
if keepdims:
out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
if axis is None:
out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
else:
other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
out_tvm_val = in_npy_map[sel_indices]
if type == "argmax":
np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
elif type == "argmin":
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
else:
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3) np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]: for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
check_device(device) check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment