Commit 38f03f1f by tqchen

SSA Pass

parent 7e7c24e1
Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507
export CXX=g++
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
......
......@@ -27,7 +27,9 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
......
......@@ -28,7 +28,7 @@ bool VerifySSA(const IRNodeRef& ir);
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt stmt);
Stmt ConvertSSA(const Stmt& stmt);
/*!
* \brief inline all calls of f in stmt.
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
struct SetVarDef {
// get var definition from node
using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
static FGetVarDef& vtable_get_var_def() { // NOLINT(*)
static FGetVarDef inst; return inst;
}
static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*)
static FSetVarExpr inst; return inst;
}
static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*)
static FSetVarStmt inst; return inst;
}
};
// return a new node to
using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
// return a new node to
using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
inline const Variable* GetVarDef(const IRNodeRef& n) {
if (n.as<Let>()) {
return n.as<Let>()->var.get();
} else if (n.as<LetStmt>()) {
return n.as<LetStmt>()->var.get();
} else if (n.as<For>()) {
return n.as<For>()->loop_var.get();
} else if (n.as<Allocate>()) {
return n.as<Allocate>()->buffer_var.get();
} else {
return nullptr;
}
}
inline Expr ResetVar(const Expr& n, VarExpr var) {
if (n.as<Let>()) {
std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<Allocate>()) {
}
}
inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
if (n.as<LetStmt>()) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<For>()) {
std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
x->loop_var = var;
return Expr(x);
} else {
LOG(FATAL) << "not reached";
}
}
class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
std::unordered_set<const Variable*> defined;
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
const Variable* v = GetVarDef(n);
if (v != nullptr) {
if (defined.count(v) != 0) {
is_ssa = false; return;
} else {
defined.insert(v);
}
}
IRVisitor::Visit(n);
}
};
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const auto& f = IRConvertSSA::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const auto& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
std::unordered_set<const Variable*> defined;
};
temple<>
TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
VarExpr var = op->var;
if (m->defined.count(var.get()) != 0) {
var = Variable::make(var->type, var->name_hint);
}
// insert scope before recursion.
m->scope[var.get()].push_back(var);
Expr new_expr = Mutate(e);
m->scope[var.get()].pop_back();
if (!var.same_as(op->var)) {
std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
x->var = var;
return Expr(x);
} else {
return new_expr;
}
});
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace ir {
namespace {
// global functor to get var definition from
struct FGetVarDef {
using FType = IRFunctor<VarExpr (const IRNodeRef&)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
.set_dispatch<Let>([](const Let* op) {
return op->var;
})
.set_dispatch<LetStmt>([](const LetStmt* op) {
return op->var;
})
.set_dispatch<For>([](const For* op) {
return op->loop_var;
})
.set_dispatch<Allocate>([](const Allocate* op) {
return op->buffer_var;
});
struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst;
}
static FTypeStmt& vtable_stmt() { // NOLINT(*)
static FTypeStmt inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr)
.set_dispatch<Let>([](const Let* op, VarExpr var) {
std::shared_ptr<Let> x = std::make_shared<Let>(*op);
x->var = var;
return Expr(x);
});
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op);
x->var = var;
return Stmt(x);
})
.set_dispatch<For>([](const For* op, VarExpr var) {
std::shared_ptr<For> x = std::make_shared<For>(*op);
x->loop_var = var;
return Stmt(x);
});
class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) {
VarExpr v = fget_var_def(n);
if (defined_.count(v.get()) != 0) {
is_ssa = false; return;
} else {
defined_[v.get()] = 1;
}
}
IRVisitor::Visit(n);
}
private:
std::unordered_map<const Variable*, int> defined_;
};
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_expr();
if (fget_var_def.can_dispatch(expr)) {
VarExpr v = fget_var_def(expr);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
CHECK(expr.as<Allocate>() == nullptr)
<< "One allocation in two places, cannot rename buffer in allocate";
new_var = Variable::make(v->type, v->name_hint);
} else {
defined_.insert(v.get());
}
scope_[v.get()].push_back(new_var);
Expr new_expr = IRMutator::Mutate(expr);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_expr, new_var);
} else {
return new_expr;
}
} else if (expr.as<Variable>()) {
const Variable* v = expr.as<Variable>();
if (scope_.count(v) != 0) {
return scope_[v].back();
} else {
return expr;
}
} else {
Expr e = IRMutator::Mutate(expr);
return e;
}
}
Stmt Mutate(Stmt stmt) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_stmt();
if (fget_var_def.can_dispatch(stmt)) {
VarExpr v = fget_var_def(stmt);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
new_var = Variable::make(v->type, v->name_hint);
} else {
defined_.insert(v.get());
}
scope_[v.get()].push_back(new_var);
Stmt new_stmt = IRMutator::Mutate(stmt);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_stmt, new_var);
} else {
return new_stmt;
}
} else {
return IRMutator::Mutate(stmt);
}
}
private:
std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
std::unordered_set<const Variable*> defined_;
};
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
Stmt ConvertSSA(const Stmt& stmt) {
return IRConvertSSA().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -3,23 +3,25 @@
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
TEST(IRSSA, Convert) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
auto z = let + let;
CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
CHECK(ir::VerifySSA(z_ssa));
}
TEST(IRSSA, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
CHECK(ir::VerifySSA(z));
}
int main(int argc, char ** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment