Commit 3ac94439 by ziheng Committed by Tianqi Chen

[PASS] Support for partition loops with thread_axis (#81)

* [PASS] Support for partition loops with thread_axis

* Add check for AttrStmt.attr_key
parent 0876e9e9
......@@ -39,27 +39,43 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
class PartitionFinder : public IRVisitor {
public:
explicit PartitionFinder(VarExpr loop_var,
explicit PartitionFinder(VarExpr current_var,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
: current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
for (const auto& kv : dom_map) out_vars_.insert(kv.first);
}
void Visit_(const For* op) {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
hint_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
const Variable* var = op->loop_var.get();
hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
IRVisitor::Visit_(op);
relax_map_.erase(op->loop_var.get());
hint_map_.erase(op->loop_var.get());
relax_map_.erase(var);
hint_map_.erase(var);
}
void Visit_(const AttrStmt* op) {
// handle thread_axis
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
IRVisitor::Visit_(op);
relax_map_.erase(var);
hint_map_.erase(var);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const IfThenElse* op) {
if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({target_var_.get()}))) {
IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_);
if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) {
IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
partitions[op->condition.get()] = Partition{op->condition, interval};
} else {
IRVisitor::Visit_(op);
......@@ -69,7 +85,7 @@ class PartitionFinder : public IRVisitor {
std::unordered_map<const Node*, Partition> partitions;
private:
VarExpr target_var_;
VarExpr current_var_;
std::unordered_set<const Variable*> out_vars_;
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
......
......@@ -53,8 +53,28 @@ def test_multi_if():
assert('if' not in str(stmt.body.first))
print(stmt)
def test_thread_axis():
    """Run LoopPartition on a schedule with a split axis bound to threadIdx.x.

    Asserts that the text of ``stmt_.body.body.body.first`` contains no 'if'
    after the pass, then prints the resulting statement.
    """
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    B = tvm.compute((rows, cols), lambda i, j: A[i, j] + 3, name='B')

    sched = tvm.Schedule(B.op)
    sched[B].set_scope("shared")

    # Split the first axis by a factor of 32, then split the inner part into
    # num_thread pieces and bind the outer piece to threadIdx.x.
    num_thread = 16
    outer, inner = sched[B].split(B.op.axis[0], 32)
    inner_outer, inner_inner = sched[B].split(inner, nparts=num_thread)
    sched[B].bind(inner_outer, tvm.thread_axis("threadIdx.x"))

    bounds = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, bounds)
    stmt_ = tvm.ir_pass.LoopPartition(stmt)
    assert('if' not in str(stmt_.body.body.body.first))
    print(stmt_)
if __name__ == "__main__":
    # Run every loop-partition test in order when executed as a script.
    for run_case in (test_basic, test_multi_loop, test_multi_if, test_thread_axis):
        run_case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment