Commit 03b09f74 by Tianqi Chen Committed by GitHub

[PASS] Improve SSA conversion, add forbid list in loop-par (#142)

parent 867ad378
......@@ -90,10 +90,10 @@ def lower(sch,
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
......
......@@ -52,7 +52,7 @@ class CandidateSelector : public IRVisitor {
const Variable* var = op->loop_var.get();
record_.insert({var, false});
IRVisitor::Visit_(op);
if (record_.at(var)) {
if (record_.at(var) && !no_split_) {
candidates.insert(op);
}
record_.erase(var);
......@@ -70,7 +70,7 @@ class CandidateSelector : public IRVisitor {
if ((scope.rank == 0) && !is_const(op->value)) {
record_.insert({var.get(), false});
IRVisitor::Visit_(op);
if (record_.at(var.get())) {
if (record_.at(var.get()) && !no_split_) {
candidates.insert(op);
}
record_.erase(var.get());
......@@ -80,11 +80,25 @@ class CandidateSelector : public IRVisitor {
IRVisitor::Visit_(op);
}
void Visit_(const Block* op) {
bool temp = no_split_;
this->Visit(op->first);
// erase the no split state of first when visit rest.
std::swap(temp, no_split_);
this->Visit(op->rest);
// restore the no split flag.
no_split_ = no_split_ || temp;
}
void Visit_(const Call* op) {
if (op->is_intrinsic(Call::likely)) {
in_likely_ = true;
IRVisitor::Visit_(op);
in_likely_ = false;
} else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
// no split if the body contains allreduce.
no_split_ = true;
return;
} else {
IRVisitor::Visit_(op);
}
......@@ -100,6 +114,7 @@ class CandidateSelector : public IRVisitor {
private:
bool in_likely_;
bool no_split_{false};
std::unordered_map<const Variable*, VarIsUsed> record_;
};
......
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
*
* SSA requires each varaible to be only defined once.
* \file ssa.cc
*/
#include <tvm/ir.h>
......@@ -14,138 +16,155 @@
namespace tvm {
namespace ir {
namespace {
// global functor to get var definition from
struct FGetVarDef {
using FType = IRFunctor<VarExpr (const NodeRef&)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
.set_dispatch<Let>([](const Let* op) {
return op->var;
})
.set_dispatch<LetStmt>([](const LetStmt* op) {
return op->var;
})
.set_dispatch<For>([](const For* op) {
return op->loop_var;
})
.set_dispatch<Allocate>([](const Allocate* op) {
return op->buffer_var;
});
struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst;
}
static FTypeStmt& vtable_stmt() { // NOLINT(*)
static FTypeStmt inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr)
.set_dispatch<Let>([](const Let* op, VarExpr var) {
std::shared_ptr<Let> x = std::make_shared<Let>(*op);
x->var = var;
return Expr(x);
});
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op);
x->var = var;
return Stmt(x);
})
.set_dispatch<For>([](const For* op, VarExpr var) {
std::shared_ptr<For> x = std::make_shared<For>(*op);
x->loop_var = var;
return Stmt(x);
});
class IRVerifySSA : public IRVisitor {
class IRVerifySSA final : public IRVisitor {
public:
bool is_ssa{true};
void Visit(const NodeRef& n) final {
if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) {
VarExpr v = fget_var_def(n);
if (defined_.count(v.get()) != 0) {
is_ssa = false; return;
} else {
defined_[v.get()] = 1;
IRVisitor::Visit(n);
}
void Visit_(const Let* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
}
IRVisitor::Visit(n);
void Visit_(const LetStmt* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
}
void Visit_(const For* op) final {
MarkDef(op->loop_var.get());
IRVisitor::Visit_(op);
}
void Visit_(const Allocate* op) final {
MarkDef(op->buffer_var.get());
IRVisitor::Visit_(op);
}
private:
void MarkDef(const Variable* v) {
if (defined_.count(v) != 0) {
is_ssa = false; return;
} else {
defined_[v] = 1;
}
}
std::unordered_map<const Variable*, int> defined_;
};
class IRConvertSSA : public IRMutator {
class IRConvertSSA final : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_expr();
if (fget_var_def.can_dispatch(expr)) {
VarExpr v = fget_var_def(expr);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
CHECK(expr.as<Allocate>() == nullptr)
<< "One allocation in two places, cannot rename buffer in allocate";
new_var = Variable::make(v->type, v->name_hint);
Expr Mutate_(const Variable* op, const Expr& e) final {
if (scope_.count(op)) {
return scope_[op].back();
} else {
defined_.insert(v.get());
return e;
}
}
Expr Mutate_(const Let* op, const Expr& e) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
VarExpr new_var = Variable::make(v.type(), v->name_hint);
scope_[v.get()].push_back(new_var);
Expr new_expr = IRMutator::Mutate(expr);
Expr body = IRMutator::Mutate(op->body);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_expr, new_var);
return Let::make(new_var, value, body);
} else {
return new_expr;
defined_.insert(v.get());
return IRMutator::Mutate_(op, e);
}
}
} else if (expr.as<Variable>()) {
const Variable* v = expr.as<Variable>();
if (scope_.count(v) != 0) {
return scope_[v].back();
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
if (scope_.count(op->buffer_var.get())) {
return Load::make(
op->type, scope_[op->buffer_var.get()].back(),
op->index, op->predicate);
} else {
return expr;
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (scope_.count(op->buffer_var.get())) {
return Store::make(
scope_[op->buffer_var.get()].back(), op->value,
op->index, op->predicate);
} else {
Expr e = IRMutator::Mutate(expr);
return e;
return stmt;
}
}
Stmt Mutate(Stmt stmt) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_stmt();
if (fget_var_def.can_dispatch(stmt)) {
VarExpr v = fget_var_def(stmt);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
new_var = Variable::make(v->type, v->name_hint);
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
VarExpr new_var = Variable::make(v.type(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = IRMutator::Mutate(op->body);
scope_[v.get()].pop_back();
return LetStmt::make(new_var, value, body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.type(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt new_stmt = IRMutator::Mutate(stmt);
Stmt stmt = IRMutator::Mutate_(op, s);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_stmt, new_var);
op = stmt.as<For>();
return For::make(
new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.type(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
scope_[v.get()].pop_back();
op = stmt.as<Allocate>();
return Allocate::make(
new_var, op->type, op->extents, op->condition,
op->body, op->new_expr, op->free_function);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (const Variable* v = op->node.as<Variable>()) {
if (op->attr_key == attr::storage_scope) {
const Allocate* alloc = op->body.as<Allocate>();
if (alloc && op->node.same_as(alloc->buffer_var)) {
Stmt new_alloc = Mutate(op->body);
if (new_alloc.same_as(op->body)) return s;
alloc = new_alloc.as<Allocate>();
CHECK(alloc);
return AttrStmt::make(
alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
if (scope_.count(v) && scope_[v].size() != 0) {
return AttrStmt::make(
scope_[v].back(), op->attr_key, op->value, op->body);
} else {
return new_stmt;
return stmt;
}
} else {
return IRMutator::Mutate(stmt);
return IRMutator::Mutate_(op, s);
}
}
......
import tvm
def test_lower_rfactor():
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B.op].bind(xi, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
fapi = tvm.lower(s, [A, B])
if __name__ == "__main__":
test_lower_rfactor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment