Commit 7e7c24e1 by tqchen

temp checkin

parent 8e04361c
...@@ -27,6 +27,7 @@ using Halide::abs; ...@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select; using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionBaseNode; using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
......
...@@ -45,6 +45,49 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -45,6 +45,49 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Max = "Max"; static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min"; static constexpr const char* Min = "Min";
}; };
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm;
using Halide::Internal::StringImm;
using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add;
using Halide::Internal::Sub;
using Halide::Internal::Mul;
using Halide::Internal::Div;
using Halide::Internal::Mod;
using Halide::Internal::Min;
using Halide::Internal::Max;
using Halide::Internal::EQ;
using Halide::Internal::NE;
using Halide::Internal::LT;
using Halide::Internal::LE;
using Halide::Internal::GT;
using Halide::Internal::GE;
using Halide::Internal::And;
using Halide::Internal::Or;
using Halide::Internal::Not;
using Halide::Internal::Select;
using Halide::Internal::Load;
using Halide::Internal::Ramp;
using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Let;
using Halide::Internal::LetStmt;
using Halide::Internal::AssertStmt;
using Halide::Internal::ProducerConsumer;
using Halide::Internal::For;
using Halide::Internal::Store;
using Halide::Internal::Provide;
using Halide::Internal::Allocate;
using Halide::Internal::Free;
using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -14,11 +14,35 @@ namespace tvm { ...@@ -14,11 +14,35 @@ namespace tvm {
namespace ir { namespace ir {
/*! /*!
* \brief Substitute occurance of IRNode in expr * \brief verifies whether the IR stmt or Expr is in SSA form.
* \param replacements The replacement rule of substitution * That is: each VarExpr is defined and assigned once(in Let/For)
* \param expr The expression to be substituted. *
* \param ir The root of the IR DAG.
* \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form.
*/ */
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr); bool VerifySSA(const IRNodeRef& ir);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt stmt);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -10,29 +10,128 @@ ...@@ -10,29 +10,128 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace { namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator { struct SetVarDef {
// get var definition from node
using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
static FGetVarDef& vtable_get_var_def() { // NOLINT(*)
static FGetVarDef inst; return inst;
}
static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*)
static FSetVarExpr inst; return inst;
}
static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*)
static FSetVarStmt inst; return inst;
}
};
// return a new node to
using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
// return a new node to
using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
inline const Variable* GetVarDef(const IRNodeRef& n) {
if (n.as<Let>()) {
return n.as<Let>()->var.get();
} else if (n.as<LetStmt>()) {
return n.as<LetStmt>()->var.get();
} else if (n.as<For>()) {
return n.as<For>()->loop_var.get();
} else if (n.as<Allocate>()) {
return n.as<Allocate>()->buffer_var.get();
} else {
return nullptr;
}
}
inline Expr ResetVar(const Expr& n, VarExpr var) {
if (n.as<Let>()) {
std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<Allocate>()) {
}
}
inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
if (n.as<LetStmt>()) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<For>()) {
std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
x->loop_var = var;
return Expr(x);
} else {
LOG(FATAL) << "not reached";
}
}
class IRVerifySSA : public IRVisitor {
public: public:
Expr Mutate(Expr expr) final { bool is_ssa{true};
const IRNode* v = expr.get(); std::unordered_set<const Variable*> defined;
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
const Variable* v = GetVarDef(n);
if (v != nullptr) { if (v != nullptr) {
auto it = replacements_.find(v); if (defined.count(v) != 0) {
if (it != replacements_.end()) { is_ssa = false; return;
return it->second; } else {
defined.insert(v);
} }
} }
return IRMutator::Mutate(expr); IRVisitor::Visit(n);
} }
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements) };
: replacements_(replacements) {}
private: class IRConvertSSA : public IRMutator {
const std::unordered_map<const IRNode*, Expr>& replacements_; public:
Expr Mutate(Expr expr) final {
static const auto& f = IRConvertSSA::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const auto& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
std::unordered_set<const Variable*> defined;
}; };
temple<>
TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
VarExpr var = op->var;
if (m->defined.count(var.get()) != 0) {
var = Variable::make(var->type, var->name_hint);
}
// insert scope before recursion.
m->scope[var.get()].push_back(var);
Expr new_expr = Mutate(e);
m->scope[var.get()].pop_back();
if (!var.same_as(op->var)) {
std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
x->var = var;
return Expr(x);
} else {
return new_expr;
}
});
} // namespace } // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) { bool VerifySSA(const IRNodeRef& ir) {
return IRSubstitute(replacements).Mutate(expr); IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment