Commit ea9c1c59 by Tianqi Chen Committed by GitHub

[SCHEDULE] More reliable bound inference on threading. (#84)

parent 3ac94439
...@@ -154,24 +154,9 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -154,24 +154,9 @@ Tensor Schedule::cache_write(const Tensor& tensor,
void RebaseNonZeroMinLoop(const Schedule& sch) { void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map; std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
if (s->attach_type == kScope) { if (s->attach_type == kInlinedAlready) continue;
attach_mark[s->attach_stage.get()] = 1;
}
if (s->op.as<ScanOpNode>()) {
attach_mark[s.get()] = 1;
}
}
for (Stage s : sch->groups) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue;
auto root_iter_vars = s->op->root_iter_vars(); auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) { for (IterVar iv : root_iter_vars) {
...@@ -201,16 +186,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -201,16 +186,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
} }
} }
void SetScanAttach(const Schedule& sch) {  // NOLINT(*)
  // Attach every scan-update stage at the innermost (last) leaf
  // iteration variable of the stage it is attached to, so that bound
  // inference nests the update body under the full loop nest of its parent.
  for (Stage s : sch->stages) {
    if (s->attach_type != kScanUpdate) continue;
    const Stage& scan_stage = s->attach_stage;
    s->attach_ivar =
        scan_stage->leaf_iter_vars[scan_stage->leaf_iter_vars.size() - 1];
  }
}
void InjectInline(ScheduleNode* sch) { void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache(); sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size()); std::vector<Expr> new_body(sch->stages.size());
...@@ -262,9 +237,8 @@ void InjectInline(ScheduleNode* sch) { ...@@ -262,9 +237,8 @@ void InjectInline(ScheduleNode* sch) {
} }
void Schedule::normalize() { void Schedule::normalize() {
RebaseNonZeroMinLoop(*this);
SetScanAttach(*this);
InjectInline(operator->()); InjectInline(operator->());
RebaseNonZeroMinLoop(*this);
} }
// Handle reduction factor. // Handle reduction factor.
......
...@@ -148,7 +148,37 @@ def test_bound_nest_group(): ...@@ -148,7 +148,37 @@ def test_bound_nest_group():
assert bounds[x1.op.axis[0]].extent.value == 1 assert bounds[x1.op.axis[0]].extent.value == 1
assert bounds[x1.op.axis[1]].extent == n assert bounds[x1.op.axis[1]].extent == n
def test_bound_nest_thread():
    """InferBound with nested thread binding.

    A1 -> A2 -> A3 chain where A2/A1 are computed at the thread axis of
    A3; after normalize(), the inferred extents of the attached stages
    should shrink to the per-thread workload.
    """
    m = tvm.Var('m')
    A = tvm.placeholder((m), name='A')
    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
    A2 = tvm.compute((m,), lambda i: A1[i] + 2, name='A2')
    A3 = tvm.compute((m,), lambda i: A2[i] + 3, name='A3')

    s = tvm.Schedule(A3.op)
    # A2 lives in shared memory, A1 in per-thread local storage.
    s[A2].set_scope("shared")
    s[A1].set_scope("local")

    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    bx, tx = s[A3].split(A3.op.axis[0], factor=32)
    s[A3].bind(bx, block_x)
    s[A3].bind(tx, thread_x)

    # Nest A2 (and A1) under the thread axis of A3; rebind A2's single
    # outer iteration onto the same thread axis.
    s[A2].compute_at(s[A3], tx)
    _, xi = s[A2].split(A2.op.axis[0], nparts=1)
    s[A2].bind(xi, thread_x)
    s[A1].compute_at(s[A3], tx)

    s.normalize()
    bounds = tvm.schedule.InferBound(s)
    # Per-thread local A1 computes one element; shared A2 covers the
    # 32-wide thread block; A3 spans the full output.
    assert bounds[A1.op.axis[0]].extent.value == 1
    assert bounds[A2.op.axis[0]].extent.value == 32
    assert bounds[A3.op.axis[0]].extent == m
if __name__ == "__main__": if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
test_bound_nest_group() test_bound_nest_group()
test_bound_group_schedule() test_bound_group_schedule()
test_bound_scan() test_bound_scan()
...@@ -156,5 +186,4 @@ if __name__ == "__main__": ...@@ -156,5 +186,4 @@ if __name__ == "__main__":
test_bound_rfactor() test_bound_rfactor()
test_bound_blur() test_bound_blur()
test_bound_conv1d() test_bound_conv1d()
test_bound1()
test_bound2() test_bound2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment