Commit 7e7c24e1 by tqchen

temp checkin

parent 8e04361c
......@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
......
......@@ -45,6 +45,49 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min";
};
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm;
using Halide::Internal::StringImm;
using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add;
using Halide::Internal::Sub;
using Halide::Internal::Mul;
using Halide::Internal::Div;
using Halide::Internal::Mod;
using Halide::Internal::Min;
using Halide::Internal::Max;
using Halide::Internal::EQ;
using Halide::Internal::NE;
using Halide::Internal::LT;
using Halide::Internal::LE;
using Halide::Internal::GT;
using Halide::Internal::GE;
using Halide::Internal::And;
using Halide::Internal::Or;
using Halide::Internal::Not;
using Halide::Internal::Select;
using Halide::Internal::Load;
using Halide::Internal::Ramp;
using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Let;
using Halide::Internal::LetStmt;
using Halide::Internal::AssertStmt;
using Halide::Internal::ProducerConsumer;
using Halide::Internal::For;
using Halide::Internal::Store;
using Halide::Internal::Provide;
using Halide::Internal::Allocate;
using Halide::Internal::Free;
using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
} // namespace ir
} // namespace tvm
......
......@@ -14,11 +14,35 @@ namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
* \param ir The root of the IR DAG.
* \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
bool VerifySSA(const IRNodeRef& ir);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt stmt);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
} // namespace ir
} // namespace tvm
......
......@@ -10,29 +10,128 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
struct SetVarDef {
// get var definition from node
using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
static FGetVarDef& vtable_get_var_def() { // NOLINT(*)
static FGetVarDef inst; return inst;
}
static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*)
static FSetVarExpr inst; return inst;
}
static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*)
static FSetVarStmt inst; return inst;
}
};
// return a new node to
using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
// return a new node to
using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
inline const Variable* GetVarDef(const IRNodeRef& n) {
if (n.as<Let>()) {
return n.as<Let>()->var.get();
} else if (n.as<LetStmt>()) {
return n.as<LetStmt>()->var.get();
} else if (n.as<For>()) {
return n.as<For>()->loop_var.get();
} else if (n.as<Allocate>()) {
return n.as<Allocate>()->buffer_var.get();
} else {
return nullptr;
}
}
inline Expr ResetVar(const Expr& n, VarExpr var) {
if (n.as<Let>()) {
std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<Allocate>()) {
}
}
inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
if (n.as<LetStmt>()) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<For>()) {
std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
x->loop_var = var;
return Expr(x);
} else {
LOG(FATAL) << "not reached";
}
}
class IRVerifySSA : public IRVisitor {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
bool is_ssa{true};
std::unordered_set<const Variable*> defined;
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
const Variable* v = GetVarDef(n);
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
if (defined.count(v) != 0) {
is_ssa = false; return;
} else {
defined.insert(v);
}
}
return IRMutator::Mutate(expr);
IRVisitor::Visit(n);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
};
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const auto& f = IRConvertSSA::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const auto& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
std::unordered_set<const Variable*> defined;
};
temple<>
TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
VarExpr var = op->var;
if (m->defined.count(var.get()) != 0) {
var = Variable::make(var->type, var->name_hint);
}
// insert scope before recursion.
m->scope[var.get()].push_back(var);
Expr new_expr = Mutate(e);
m->scope[var.get()].pop_back();
if (!var.same_as(op->var)) {
std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
x->var = var;
return Expr(x);
} else {
return new_expr;
}
});
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
} // namespace ir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment