Commit 9ec40edd by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix cross thread schedule after refactor (#85)

parent ea9c1c59
......@@ -194,29 +194,25 @@ Stmt Substitute(Stmt s,
// Cross Thread reduction marker.
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
std::unordered_set<IterVar> rebase_thread;
for (IterVarRelation rel : stage->relations) {
if (const RebaseNode* s = rel.as<RebaseNode>()) {
if (s->parent->iter_type == kCommReduce &&
s->rebased->iter_type == kThreadIndex) {
rebase_thread.insert(s->rebased);
}
}
}
if (rebase_thread.size() == 0) return false;
// Verify correctness of leaf nest.
bool reduce_start = false;
int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce";
} else if (rebase_thread.count(iv)) {
reduce_start = true;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK(!reduce_start)
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
return true;
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
}
Stmt MakeCrossThreadReduction(
......@@ -246,12 +242,14 @@ Stmt MakeCrossThreadReduction(
freduce_args.push_back(cond);
std::vector<Expr> thread_head_check;
for (IterVarRelation rel : stage->relations) {
if (const RebaseNode* s = rel.as<RebaseNode>()) {
if (s->parent->iter_type == kCommReduce &&
s->rebased->iter_type == kThreadIndex) {
freduce_args.push_back(s->rebased->var);
thread_head_check.push_back(s->rebased->var == 0);
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var);
thread_head_check.push_back(tv->var == 0);
}
}
}
......
......@@ -99,13 +99,14 @@ class ThreadAllreduceBuilder : public IRMutator {
cond, value, Reduce::InitValue(op_code, value.type()));
}
std::unordered_set<const Variable*> reduce_index_;
std::unordered_set<const Variable*> reduce_set;
for (size_t i = 3; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>();
CHECK(v);
reduce_index_.insert(v);
reduce_set.insert(v);
}
size_t nmatch = 0;
std::unordered_set<const Variable*> visited;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmt* attr : thread_extents_) {
ThreadEntry e;
......@@ -118,15 +119,18 @@ class ThreadAllreduceBuilder : public IRMutator {
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
if (reduce_index_.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
} else {
vpar.push_back(e);
if (!visited.count(iv->var.get())) {
visited.insert(iv->var.get());
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
} else {
vpar.push_back(e);
}
}
}
}
CHECK_EQ(nmatch, reduce_index_.size())
CHECK_EQ(nmatch, reduce_set.size())
<< "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
......
......@@ -128,7 +128,11 @@ void InferRootBound(const Stage& stage,
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< " call schedule.normalize to achieve this. ";
up_state[iv] = IntSet::single_point(iv->var);
if (ctx.bind_map.count(iv)) {
up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
} else {
up_state[iv] = IntSet::single_point(iv->var);
}
} else {
up_state[iv] = IntSet::range(vrange);
}
......
......@@ -161,6 +161,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv);
auto it = s->iter_var_attrs.find(iv);
// don't need to rebase paths that are bound to a thread.
if (it != s->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
continue;
}
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased = IterVarNode::make(
......@@ -364,6 +370,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
stages->data.insert(stages->data.begin() + stage_pos,
factor_stage.node_);
(*this)->stage_map.Set(factor_op, factor_stage);
factor_stage->group = reduce_stage->group;
if (factor_stage->group.defined()) {
++factor_stage->group->num_child_stages;
}
// Replace the old reduction.
IterVar repl_red_axis = reduce_axis(
dom_map.at(axis), axis->var->name_hint + ".v");
......
......@@ -90,10 +90,11 @@ def test_rfactor_threads():
s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
tx = s[B].op.reduce_axis[0]
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx)
# one line to build the function.
......@@ -124,6 +125,6 @@ def test_rfactor_threads():
check_target("opencl")
if __name__ == "__main__":
test_rfactor()
test_rfactor_threads()
test_rfactor()
test_sum()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment