Commit 8e04361c by tqchen

Refactor IR Pass

parent ff6b8d82
Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
...@@ -27,6 +27,7 @@ using Halide::abs; ...@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select; using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
......
...@@ -29,7 +29,7 @@ class IRMutator { ...@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression * \brief mutate expression
* \return the mutated expr * \return the mutated expr
*/ */
virtual Expr mutate(Expr expr) { virtual Expr Mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr(); static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this); return f(expr, expr, this);
} }
...@@ -37,7 +37,7 @@ class IRMutator { ...@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression * \brief mutate expression
* \return the mutated stmt * \return the mutated stmt
*/ */
virtual Stmt mutate(Stmt stmt) { virtual Stmt Mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt(); static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this); return f(stmt, stmt, this);
} }
...@@ -58,28 +58,21 @@ class IRMutator { ...@@ -58,28 +58,21 @@ class IRMutator {
*/ */
class IRMutatorExample : public IRMutator { class IRMutatorExample : public IRMutator {
public: public:
Expr mutate(Expr expr) final { Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr(); static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ? return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr)); f(expr, expr, this) : IRMutator::Mutate(expr));
} }
Stmt mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt(); static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ? return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt)); f(stmt, stmt, this) : IRMutator::Mutate(stmt));
} }
// to be implemented by child class // to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*) static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*) static FMutateStmt& vtable_stmt(); // NOLINT(*)
}; };
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MUTATOR_H_ #endif // TVM_IR_MUTATOR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_PASS_H_
...@@ -24,7 +24,7 @@ class IRVisitor { ...@@ -24,7 +24,7 @@ class IRVisitor {
/*! /*!
* \brief recursively visit an IR node * \brief recursively visit an IR node
*/ */
virtual void visit(const IRNodeRef& node) { virtual void Visit(const IRNodeRef& node) {
static const FVisit& f = vtable(); static const FVisit& f = vtable();
if (node.defined()) f(node, this); if (node.defined()) f(node, this);
} }
......
...@@ -101,7 +101,7 @@ class Tensor : public FunctionRef { ...@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public Node { class TensorNode : public FunctionBaseNode {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
...@@ -125,6 +125,12 @@ class TensorNode : public Node { ...@@ -125,6 +125,12 @@ class TensorNode : public Node {
v->Visit("dim_var", &dim_var); v->Visit("dim_var", &dim_var);
v->Visit("source", &source); v->Visit("source", &source);
} }
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape, static Tensor make(Array<Expr> shape,
std::string name, std::string name,
Type dtype, Type dtype,
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::Mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
}
} // namespace ir
} // namespace tvm
...@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor { ...@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public: public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {} explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
void visit(const IRNodeRef& node) final { void Visit(const IRNodeRef& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get()); visited_.insert(node.get());
IRVisitor::visit(node); IRVisitor::Visit(node);
f_(node); f_(node);
} }
...@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor { ...@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std::function<void(const IRNodeRef&)> f_; std::function<void(const IRNodeRef&)> f_;
std::unordered_set<const Node*> visited_; std::unordered_set<const Node*> visited_;
}; };
} // namespace } // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst; static FVisit inst; return inst;
} }
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit v(fvisit);
v.visit(node);
}
// namespace to register the functors. // namespace to register the functors.
namespace { namespace {
...@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) { ...@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline void VisitArray(Array<Expr> arr, IRVisitor* v) { inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) { for (size_t i = 0; i < arr.size(); i++) {
v->visit(arr[i]); v->Visit(arr[i]);
} }
} }
inline void VisitRDom(RDomain rdom, IRVisitor* v) { inline void VisitRDom(RDomain rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom->domain.size(); i++) { for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i]; Range r = rdom->domain[i];
v->visit(r->min); v->Visit(r->min);
v->visit(r->extent); v->Visit(r->extent);
} }
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) { .set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v); VisitRDom(op->rdom, v);
v->visit(op->source); v->Visit(op->source);
}); });
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) { .set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}); });
// binary operator // binary operator
template<typename T> template<typename T>
inline void Binary(const T* op, IRVisitor* v) { inline void Binary(const T* op, IRVisitor* v) {
v->visit(op->a); v->Visit(op->a);
v->visit(op->b); v->Visit(op->b);
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* op, IRVisitor* v) { .set_dispatch<Not>([](const Not* op, IRVisitor* v) {
v->visit(op->a); v->Visit(op->a);
}) })
.set_dispatch<Select>([](const Select *op, IRVisitor* v) { .set_dispatch<Select>([](const Select *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->true_value); v->Visit(op->true_value);
v->visit(op->false_value); v->Visit(op->false_value);
}) })
.set_dispatch<Load>([](const Load *op, IRVisitor* v) { .set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->visit(op->index); v->Visit(op->index);
}) })
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) { .set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->visit(op->base); v->Visit(op->base);
v->visit(op->stride); v->Visit(op->stride);
}) })
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) { .set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}) })
.set_dispatch<Call>([](const Call *op, IRVisitor* v) { .set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
}) })
.set_dispatch<Let>([](const Let *op, IRVisitor* v) { .set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->body); v->Visit(op->body);
}); });
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) { .set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) { .set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->message); v->Visit(op->message);
}) })
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) { .set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<For>([](const For *op, IRVisitor* v) { .set_dispatch<For>([](const For *op, IRVisitor* v) {
v->visit(op->min); v->Visit(op->min);
v->visit(op->extent); v->Visit(op->extent);
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<Store>([](const Store *op, IRVisitor* v) { .set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->index); v->Visit(op->index);
}) })
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) { .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
...@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
}) })
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) { .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
v->visit(op->extents[i]); v->Visit(op->extents[i]);
} }
v->visit(op->body); v->Visit(op->body);
v->visit(op->condition); v->Visit(op->condition);
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
v->visit(op->new_expr); v->Visit(op->new_expr);
} }
}) })
.set_dispatch<Free>(NoOp) .set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) { .set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds // Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
v->visit(op->bounds[i]->min); v->Visit(op->bounds[i]->min);
v->visit(op->bounds[i]->extent); v->Visit(op->bounds[i]->extent);
} }
v->visit(op->body); v->Visit(op->body);
v->visit(op->condition); v->Visit(op->condition);
}) })
.set_dispatch<Block>([](const Block *op, IRVisitor* v) { .set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->visit(op->first); v->Visit(op->first);
v->visit(op->rest); v->Visit(op->rest);
}) })
.set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) { .set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->then_case); v->Visit(op->then_case);
v->visit(op->else_case); v->Visit(op->else_case);
}) })
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) { .set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}); });
} // namespace } // namespace
......
...@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator { ...@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public: public:
VarExpr var; VarExpr var;
int int_val; int int_val;
Expr mutate(Expr expr) final { Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr(); static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ? return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr)); f(expr, expr, this) : IRMutator::Mutate(expr));
} }
static FMutateExpr &vtable_expr(); static FMutateExpr &vtable_expr();
}; };
...@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) { ...@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const mu; IRVar2Const mu;
mu.var = y; mu.var = y;
mu.int_val = 10; mu.int_val = 10;
auto zz = mu.mutate(z); auto zz = mu.Mutate(z);
std::ostringstream os; std::ostringstream os;
os << zz; os << zz;
CHECK(os.str() == "(x + 10)"); CHECK(os.str() == "(x + 10)");
} }
TEST(IRMutator, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) { TEST(IRVisitor, CountVar) {
using namespace Halide::Internal; using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment