Unverified Commit f31df01e by Tang, Shizhi Committed by GitHub

fix lower_warp_memory (#5247)

parent 3e8c7beb
...@@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator { ...@@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
} }
protected: protected:
PrimExpr Mutate_(const VarNode* op) { PrimExpr VisitExpr_(const VarNode* op) override {
CHECK(op != buffer_) CHECK(op != buffer_)
<< "Cannot access address of warp memory directly"; << "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op); return StmtExprMutator::VisitExpr_(op);
} }
Stmt VisitStmt_(const StoreNode* op) { Stmt VisitStmt_(const StoreNode* op) override {
if (op->buffer_var.get() == buffer_) { if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group; PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index); std::tie(local_index, group) = SplitIndexByGroup(op->index);
...@@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator { ...@@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
} }
} }
PrimExpr Mutate_(const LoadNode* op) { PrimExpr VisitExpr_(const LoadNode* op) override {
if (op->buffer_var.get() == buffer_) { if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group; PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index); std::tie(local_index, group) = SplitIndexByGroup(op->index);
......
...@@ -16,8 +16,11 @@ ...@@ -16,8 +16,11 @@
# under the License. # under the License.
import tvm import tvm
from tvm import te from tvm import te
from tvm.contrib.nvcc import have_fp16
def test_lower_warp_mem(): import numpy as np
def test_lower_warp_memory_local_scope():
m = 128 m = 128
A = te.placeholder((m,), name='A') A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i] + 3, name='B') B = te.compute((m,), lambda i: A[i] + 3, name='B')
...@@ -44,6 +47,50 @@ def test_lower_warp_mem(): ...@@ -44,6 +47,50 @@ def test_lower_warp_mem():
assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2) assert(fdevice.body.body.body.extents[0].value == 2)
def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
m = 128
A = te.placeholder((m,), name='A', dtype=dtype)
B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
with cuda_target:
s = te.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], 64)
xi0, xi1 = s[B].split(xi, factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.array(list(range(m)), dtype=dtype)
B_np = np.array(
list(range(1, 32)) + [0] +
list(range(33, 64)) + [32] +
list(range(65, 96)) + [64] +
list(range(97, 128)) + [96],
dtype=dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
func(A_nd, B_nd)
tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
check_cuda("float32")
check_cuda("float16")
if __name__ == "__main__": if __name__ == "__main__":
test_lower_warp_mem() test_lower_warp_memory_local_scope()
test_lower_warp_memory_cuda_end_to_end()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment