Unverified Commit 3d09e64d by yongfeng-nv Committed by GitHub

Create loops according to storage scope and thread hierarchies (#5190)

* Set IterVar index to 0 for local thread bound IterVars.

* Lint fix

* Use rank instead of scope name to predicate.  Add tests.

* Handle cases other than local/threadIdx.

* Turn warp to the old behavior.

* Modify test to cover global/blockIdx.

* Fix a typo.

* Update test_te_schedule_ops.py with more testing coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
parent a4321e03
......@@ -29,6 +29,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace te {
......@@ -162,7 +163,13 @@ MakeLoopNest(const Stage& stage,
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = var;
runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
if (stage->scope == "" || stage->scope == "warp" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
value_map[iv] = var;
} else {
value_map[iv] = dom->min;
}
}
}
// annotate the extent of the IterVar
......
......@@ -482,6 +482,92 @@ def test_schedule_compute_inline():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
def test_local_stage_predicate():
m = 1
n = 3
p = 2
A = tvm.te.placeholder((m, n, p), name='A')
B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
by = tvm.te.thread_axis("blockIdx.y")
tx = tvm.te.thread_axis("threadIdx.x")
vx = tvm.te.thread_axis("vthread")
def schedule(thread_tag, mem_scope) :
s = tvm.te.create_schedule(C.op)
s[B].compute_at(s[C], s[C].op.axis[0])
s[B].set_scope(mem_scope)
bno, bni = s[B].split(s[B].op.axis[1], n)
bx = tvm.te.thread_axis("blockIdx.x")
s[C].bind(s[C].op.axis[0], bx)
s[C].bind(s[C].op.axis[1], thread_tag)
s[B].bind(bni, thread_tag)
return s
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
# local vs. threadIdx
s = schedule(tx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# local vs. vthread
s = schedule(vx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# shared vs. blockIdx
s = schedule(by, "shared")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
def test_local_stage_predicate2():
A = tvm.te.placeholder((128, ), name="A")
B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
s = tvm.te.create_schedule(C.op)
AA = s.cache_read(A, "local", [B])
s[B].set_scope("shared")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
oc, ic = s[C].split(s[C].op.axis[0], factor=64)
ooc, ioc = s[C].split(oc, factor=2)
oic, iic = s[C].split(ic, factor=32)
s[C].bind(ooc, block_x)
s[C].bind(iic, thread_x)
s[B].compute_at(s[C], ioc)
ob, ib = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(ib, thread_x)
s[AA].compute_root()
s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(iaa, thread_x)
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
def visit_stmt(op):
print(op)
if (isinstance(op, tvm.tir.Allocate)):
return op.extents[0].value == 97
return False
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
assert (any(collect_visit(lowered_body, visit_stmt)))
if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
......@@ -506,3 +592,5 @@ if __name__ == "__main__":
test_schedule_tensor_compute3()
test_reduction_and_dummy_fuse_split()
test_schedule_compute_inline()
test_local_stage_predicate()
test_local_stage_predicate2()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment