Commit 9ec40edd by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix cross thread schedule after refactor (#85)

parent ea9c1c59
...@@ -194,29 +194,25 @@ Stmt Substitute(Stmt s, ...@@ -194,29 +194,25 @@ Stmt Substitute(Stmt s,
// Cross Thread reduction marker. // Cross Thread reduction marker.
bool IsCrossThreadReduction(const ComputeOpNode* self, bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) { const Stage& stage) {
std::unordered_set<IterVar> rebase_thread;
for (IterVarRelation rel : stage->relations) {
if (const RebaseNode* s = rel.as<RebaseNode>()) {
if (s->parent->iter_type == kCommReduce &&
s->rebased->iter_type == kThreadIndex) {
rebase_thread.insert(s->rebased);
}
}
}
if (rebase_thread.size() == 0) return false;
// Verify correctness of leaf nest. // Verify correctness of leaf nest.
bool reduce_start = false; int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) { for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) { if (iv->iter_type == kCommReduce) {
LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce"; auto it = stage->iter_var_attrs.find(iv);
} else if (rebase_thread.count(iv)) { if (it != stage->iter_var_attrs.end() &&
reduce_start = true; (*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else { } else {
CHECK(!reduce_start) CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis"; << "Cross thread reduce cannot swap with normal data axis";
} }
} }
return true; CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
} }
Stmt MakeCrossThreadReduction( Stmt MakeCrossThreadReduction(
...@@ -246,12 +242,14 @@ Stmt MakeCrossThreadReduction( ...@@ -246,12 +242,14 @@ Stmt MakeCrossThreadReduction(
freduce_args.push_back(cond); freduce_args.push_back(cond);
std::vector<Expr> thread_head_check; std::vector<Expr> thread_head_check;
for (IterVarRelation rel : stage->relations) { for (IterVar iv : stage->leaf_iter_vars) {
if (const RebaseNode* s = rel.as<RebaseNode>()) { if (iv->iter_type == kCommReduce) {
if (s->parent->iter_type == kCommReduce && auto it = stage->iter_var_attrs.find(iv);
s->rebased->iter_type == kThreadIndex) { if (it != stage->iter_var_attrs.end() &&
freduce_args.push_back(s->rebased->var); (*it).second->bind_thread.defined()) {
thread_head_check.push_back(s->rebased->var == 0); IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var);
thread_head_check.push_back(tv->var == 0);
} }
} }
} }
......
...@@ -99,13 +99,14 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -99,13 +99,14 @@ class ThreadAllreduceBuilder : public IRMutator {
cond, value, Reduce::InitValue(op_code, value.type())); cond, value, Reduce::InitValue(op_code, value.type()));
} }
std::unordered_set<const Variable*> reduce_index_; std::unordered_set<const Variable*> reduce_set;
for (size_t i = 3; i < call->args.size(); ++i) { for (size_t i = 3; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>(); const Variable* v = call->args[i].as<Variable>();
CHECK(v); CHECK(v);
reduce_index_.insert(v); reduce_set.insert(v);
} }
size_t nmatch = 0; size_t nmatch = 0;
std::unordered_set<const Variable*> visited;
std::vector<ThreadEntry> vred, vpar; std::vector<ThreadEntry> vred, vpar;
for (const AttrStmt* attr : thread_extents_) { for (const AttrStmt* attr : thread_extents_) {
ThreadEntry e; ThreadEntry e;
...@@ -118,15 +119,18 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -118,15 +119,18 @@ class ThreadAllreduceBuilder : public IRMutator {
CHECK_GE(e.scope.dim_index, 0) CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction"; << "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) { if (e.scope.rank == 1) {
if (reduce_index_.count(iv->var.get())) { if (!visited.count(iv->var.get())) {
vred.push_back(e); visited.insert(iv->var.get());
++nmatch; if (reduce_set.count(iv->var.get())) {
} else { vred.push_back(e);
vpar.push_back(e); ++nmatch;
} else {
vpar.push_back(e);
}
} }
} }
} }
CHECK_EQ(nmatch, reduce_index_.size()) CHECK_EQ(nmatch, reduce_set.size())
<< "Not all reduce index are presented in the context"; << "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end()); std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end()); std::sort(vpar.begin(), vpar.end());
......
...@@ -128,7 +128,11 @@ void InferRootBound(const Stage& stage, ...@@ -128,7 +128,11 @@ void InferRootBound(const Stage& stage,
CHECK(is_zero(vrange->min)) CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, " << "InferBound requires every leaf iter var's min equals 0, "
<< " call schedule.normalize to achieve this. "; << " call schedule.normalize to achieve this. ";
up_state[iv] = IntSet::single_point(iv->var); if (ctx.bind_map.count(iv)) {
up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
} else {
up_state[iv] = IntSet::single_point(iv->var);
}
} else { } else {
up_state[iv] = IntSet::range(vrange); up_state[iv] = IntSet::range(vrange);
} }
......
...@@ -161,6 +161,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -161,6 +161,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) { for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv); size_t idx = FindNodeRef(leaf_vars, iv);
auto it = s->iter_var_attrs.find(iv);
// don;t need to rebase path that are binded.
if (it != s->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
continue;
}
if (idx < leaf_vars->data.size()) { if (idx < leaf_vars->data.size()) {
// insert rebase // insert rebase
IterVar rebased = IterVarNode::make( IterVar rebased = IterVarNode::make(
...@@ -364,6 +370,10 @@ Tensor Schedule::rfactor(const Tensor& tensor, ...@@ -364,6 +370,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
stages->data.insert(stages->data.begin() + stage_pos, stages->data.insert(stages->data.begin() + stage_pos,
factor_stage.node_); factor_stage.node_);
(*this)->stage_map.Set(factor_op, factor_stage); (*this)->stage_map.Set(factor_op, factor_stage);
factor_stage->group = reduce_stage->group;
if (factor_stage->group.defined()) {
++factor_stage->group->num_child_stages;
}
// Replace the old reduction. // Replace the old reduction.
IterVar repl_red_axis = reduce_axis( IterVar repl_red_axis = reduce_axis(
dom_map.at(axis), axis->var->name_hint + ".v"); dom_map.at(axis), axis->var->name_hint + ".v");
......
...@@ -90,10 +90,11 @@ def test_rfactor_threads(): ...@@ -90,10 +90,11 @@ def test_rfactor_threads():
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread) ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
bx, tx = s[B].split(s[B].op.axis[0], factor=nthread) bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x")) s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.y")) s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x")) tx = s[B].op.reduce_axis[0]
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx) s[BF].compute_at(s[B], tx)
# one line to build the function. # one line to build the function.
...@@ -124,6 +125,6 @@ def test_rfactor_threads(): ...@@ -124,6 +125,6 @@ def test_rfactor_threads():
check_target("opencl") check_target("opencl")
if __name__ == "__main__": if __name__ == "__main__":
test_rfactor()
test_rfactor_threads() test_rfactor_threads()
test_rfactor()
test_sum() test_sum()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment