Commit 8e04361c by tqchen

Refactor IR Pass

parent ff6b8d82
Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
...@@ -27,6 +27,7 @@ using Halide::abs; ...@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select; using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
......
...@@ -29,7 +29,7 @@ class IRMutator { ...@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression * \brief mutate expression
* \return the mutated expr * \return the mutated expr
*/ */
virtual Expr mutate(Expr expr) { virtual Expr Mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr(); static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this); return f(expr, expr, this);
} }
...@@ -37,7 +37,7 @@ class IRMutator { ...@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression * \brief mutate expression
* \return the mutated stmt * \return the mutated stmt
*/ */
virtual Stmt mutate(Stmt stmt) { virtual Stmt Mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt(); static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this); return f(stmt, stmt, this);
} }
...@@ -58,28 +58,21 @@ class IRMutator { ...@@ -58,28 +58,21 @@ class IRMutator {
*/ */
class IRMutatorExample : public IRMutator { class IRMutatorExample : public IRMutator {
public: public:
Expr mutate(Expr expr) final { Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr(); static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ? return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr)); f(expr, expr, this) : IRMutator::Mutate(expr));
} }
Stmt mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt(); static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ? return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt)); f(stmt, stmt, this) : IRMutator::Mutate(stmt));
} }
// to be implemented by child class // to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*) static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*) static FMutateStmt& vtable_stmt(); // NOLINT(*)
}; };
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MUTATOR_H_ #endif // TVM_IR_MUTATOR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_PASS_H_
...@@ -24,7 +24,7 @@ class IRVisitor { ...@@ -24,7 +24,7 @@ class IRVisitor {
/*! /*!
* \brief recursively visit an IR node * \brief recursively visit an IR node
*/ */
virtual void visit(const IRNodeRef& node) { virtual void Visit(const IRNodeRef& node) {
static const FVisit& f = vtable(); static const FVisit& f = vtable();
if (node.defined()) f(node, this); if (node.defined()) f(node, this);
} }
......
...@@ -101,7 +101,7 @@ class Tensor : public FunctionRef { ...@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public Node { class TensorNode : public FunctionBaseNode {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
...@@ -125,6 +125,12 @@ class TensorNode : public Node { ...@@ -125,6 +125,12 @@ class TensorNode : public Node {
v->Visit("dim_var", &dim_var); v->Visit("dim_var", &dim_var);
v->Visit("source", &source); v->Visit("source", &source);
} }
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape, static Tensor make(Array<Expr> shape,
std::string name, std::string name,
Type dtype, Type dtype,
......
...@@ -8,32 +8,6 @@ ...@@ -8,32 +8,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).mutate(expr);
}
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst; static FMutateExpr inst; return inst;
} }
...@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { ...@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
bool changed = false; bool changed = false;
for (size_t i = 0; i < arr.size(); i++) { for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i]; Expr old_elem = arr[i];
Expr new_elem = m->mutate(old_elem); Expr new_elem = m->Mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true; if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem; new_arr[i] = new_elem;
} }
...@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) { ...@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
bool changed = false; bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) { for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i]; Range r = rdom->domain[i];
Expr new_min = m->mutate(r->min); Expr new_min = m->Mutate(r->min);
Expr new_extent = m->mutate(r->extent); Expr new_extent = m->Mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true; if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true; if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent); new_dom[i] = Range::make_with_min_extent(new_min, new_extent);
...@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) { ...@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) { .set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m); RDomain new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->mutate(op->source); Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) && if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) { op->source.same_as(new_source)) {
return e; return e;
...@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) { .set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value); Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return e; return e;
} else { } else {
...@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
// binary operator // binary operator
template<typename T> template<typename T>
inline Expr Binary(const T* op, const Expr& e, IRMutator* m) { inline Expr Binary(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a); Expr a = m->Mutate(op->a);
Expr b = m->mutate(op->b); Expr b = m->Mutate(op->b);
if (a.same_as(op->a) && if (a.same_as(op->a) &&
b.same_as(op->b)) { b.same_as(op->b)) {
return e; return e;
...@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) { .set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a); Expr a = m->Mutate(op->a);
if (a.same_as(op->a)) { if (a.same_as(op->a)) {
return e; return e;
} else { } else {
...@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} }
}) })
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) { .set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->mutate(op->condition); Expr cond = m->Mutate(op->condition);
Expr t = m->mutate(op->true_value); Expr t = m->Mutate(op->true_value);
Expr f = m->mutate(op->false_value); Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) && if (cond.same_as(op->condition) &&
t.same_as(op->true_value) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) { f.same_as(op->false_value)) {
...@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} }
}) })
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) { .set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->mutate(op->index); Expr index = m->Mutate(op->index);
if (index.same_as(op->index)) { if (index.same_as(op->index)) {
return e; return e;
} else { } else {
...@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} }
}) })
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) { .set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->mutate(op->base); Expr base = m->Mutate(op->base);
Expr stride = m->mutate(op->stride); Expr stride = m->Mutate(op->stride);
if (base.same_as(op->base) && if (base.same_as(op->base) &&
stride.same_as(op->stride)) { stride.same_as(op->stride)) {
return e; return e;
...@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} }
}) })
.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) { .set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value); Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return e; return e;
} else { } else {
...@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} }
}) })
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) { .set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value); Expr value = m->Mutate(op->value);
Expr body = m->mutate(op->body); Expr body = m->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return e; return e;
...@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) { .set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value); Expr value = m->Mutate(op->value);
Stmt body = m->mutate(op->body); Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return s;
...@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) { .set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition); Expr condition = m->Mutate(op->condition);
Expr message = m->mutate(op->message); Expr message = m->Mutate(op->message);
if (condition.same_as(op->condition) && message.same_as(op->message)) { if (condition.same_as(op->condition) && message.same_as(op->message)) {
return s; return s;
...@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) { .set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
Stmt body = m->mutate(op->body); Stmt body = m->Mutate(op->body);
if (body.same_as(op->body)) { if (body.same_as(op->body)) {
return s; return s;
} else { } else {
...@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) { .set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->mutate(op->min); Expr min = m->Mutate(op->min);
Expr extent = m->mutate(op->extent); Expr extent = m->Mutate(op->extent);
Stmt body = m->mutate(op->body); Stmt body = m->Mutate(op->body);
if (min.same_as(op->min) && if (min.same_as(op->min) &&
extent.same_as(op->extent) && extent.same_as(op->extent) &&
body.same_as(op->body)) { body.same_as(op->body)) {
...@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) { .set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value); Expr value = m->Mutate(op->value);
Expr index = m->mutate(op->index); Expr index = m->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) { if (value.same_as(op->value) && index.same_as(op->index)) {
return s; return s;
} else { } else {
...@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
std::vector<Expr> new_extents; std::vector<Expr> new_extents;
bool all_extents_unmodified = true; bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->mutate(op->extents[i])); new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
} }
Stmt body = m->mutate(op->body); Stmt body = m->Mutate(op->body);
Expr condition = m->mutate(op->condition); Expr condition = m->Mutate(op->condition);
Expr new_expr; Expr new_expr;
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
new_expr = m->mutate(op->new_expr); new_expr = m->Mutate(op->new_expr);
} }
if (all_extents_unmodified && if (all_extents_unmodified &&
body.same_as(op->body) && body.same_as(op->body) &&
...@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min; Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent; Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->mutate(old_min); Expr new_min = m->Mutate(old_min);
Expr new_extent = m->mutate(old_extent); Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true; if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true; if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back( new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent)); Range::make_by_min_extent(new_min, new_extent));
} }
Stmt body = m->mutate(op->body); Stmt body = m->Mutate(op->body);
Expr condition = m->mutate(op->condition); Expr condition = m->Mutate(op->condition);
if (!bounds_changed && if (!bounds_changed &&
body.same_as(op->body) && body.same_as(op->body) &&
condition.same_as(op->condition)) { condition.same_as(op->condition)) {
...@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) { .set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->mutate(op->first); Stmt first = m->Mutate(op->first);
Stmt rest = m->mutate(op->rest); Stmt rest = m->Mutate(op->rest);
if (first.same_as(op->first) && if (first.same_as(op->first) &&
rest.same_as(op->rest)) { rest.same_as(op->rest)) {
return s; return s;
...@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) { .set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition); Expr condition = m->Mutate(op->condition);
Stmt then_case = m->mutate(op->then_case); Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->mutate(op->else_case); Stmt else_case = m->Mutate(op->else_case);
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
...@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
} }
}) })
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) { .set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->mutate(op->value); Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) { if (v.same_as(op->value)) {
return s; return s;
} else { } else {
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::Mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
}
} // namespace ir
} // namespace tvm
...@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor { ...@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public: public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {} explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
void visit(const IRNodeRef& node) final { void Visit(const IRNodeRef& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get()); visited_.insert(node.get());
IRVisitor::visit(node); IRVisitor::Visit(node);
f_(node); f_(node);
} }
...@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor { ...@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std::function<void(const IRNodeRef&)> f_; std::function<void(const IRNodeRef&)> f_;
std::unordered_set<const Node*> visited_; std::unordered_set<const Node*> visited_;
}; };
} // namespace } // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst; static FVisit inst; return inst;
} }
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit v(fvisit);
v.visit(node);
}
// namespace to register the functors. // namespace to register the functors.
namespace { namespace {
...@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) { ...@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline void VisitArray(Array<Expr> arr, IRVisitor* v) { inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) { for (size_t i = 0; i < arr.size(); i++) {
v->visit(arr[i]); v->Visit(arr[i]);
} }
} }
inline void VisitRDom(RDomain rdom, IRVisitor* v) { inline void VisitRDom(RDomain rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom->domain.size(); i++) { for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i]; Range r = rdom->domain[i];
v->visit(r->min); v->Visit(r->min);
v->visit(r->extent); v->Visit(r->extent);
} }
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) { .set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v); VisitRDom(op->rdom, v);
v->visit(op->source); v->Visit(op->source);
}); });
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) { .set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}); });
// binary operator // binary operator
template<typename T> template<typename T>
inline void Binary(const T* op, IRVisitor* v) { inline void Binary(const T* op, IRVisitor* v) {
v->visit(op->a); v->Visit(op->a);
v->visit(op->b); v->Visit(op->b);
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* op, IRVisitor* v) { .set_dispatch<Not>([](const Not* op, IRVisitor* v) {
v->visit(op->a); v->Visit(op->a);
}) })
.set_dispatch<Select>([](const Select *op, IRVisitor* v) { .set_dispatch<Select>([](const Select *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->true_value); v->Visit(op->true_value);
v->visit(op->false_value); v->Visit(op->false_value);
}) })
.set_dispatch<Load>([](const Load *op, IRVisitor* v) { .set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->visit(op->index); v->Visit(op->index);
}) })
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) { .set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->visit(op->base); v->Visit(op->base);
v->visit(op->stride); v->Visit(op->stride);
}) })
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) { .set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}) })
.set_dispatch<Call>([](const Call *op, IRVisitor* v) { .set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
}) })
.set_dispatch<Let>([](const Let *op, IRVisitor* v) { .set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->body); v->Visit(op->body);
}); });
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) { .set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) { .set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->message); v->Visit(op->message);
}) })
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) { .set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<For>([](const For *op, IRVisitor* v) { .set_dispatch<For>([](const For *op, IRVisitor* v) {
v->visit(op->min); v->Visit(op->min);
v->visit(op->extent); v->Visit(op->extent);
v->visit(op->body); v->Visit(op->body);
}) })
.set_dispatch<Store>([](const Store *op, IRVisitor* v) { .set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
v->visit(op->index); v->Visit(op->index);
}) })
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) { .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
...@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
}) })
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) { .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
v->visit(op->extents[i]); v->Visit(op->extents[i]);
} }
v->visit(op->body); v->Visit(op->body);
v->visit(op->condition); v->Visit(op->condition);
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
v->visit(op->new_expr); v->Visit(op->new_expr);
} }
}) })
.set_dispatch<Free>(NoOp) .set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) { .set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds // Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
v->visit(op->bounds[i]->min); v->Visit(op->bounds[i]->min);
v->visit(op->bounds[i]->extent); v->Visit(op->bounds[i]->extent);
} }
v->visit(op->body); v->Visit(op->body);
v->visit(op->condition); v->Visit(op->condition);
}) })
.set_dispatch<Block>([](const Block *op, IRVisitor* v) { .set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->visit(op->first); v->Visit(op->first);
v->visit(op->rest); v->Visit(op->rest);
}) })
.set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) { .set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) {
v->visit(op->condition); v->Visit(op->condition);
v->visit(op->then_case); v->Visit(op->then_case);
v->visit(op->else_case); v->Visit(op->else_case);
}) })
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) { .set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->visit(op->value); v->Visit(op->value);
}); });
} // namespace } // namespace
......
...@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator { ...@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public: public:
VarExpr var; VarExpr var;
int int_val; int int_val;
Expr mutate(Expr expr) final { Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr(); static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ? return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr)); f(expr, expr, this) : IRMutator::Mutate(expr));
} }
static FMutateExpr &vtable_expr(); static FMutateExpr &vtable_expr();
}; };
...@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) { ...@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const mu; IRVar2Const mu;
mu.var = y; mu.var = y;
mu.int_val = 10; mu.int_val = 10;
auto zz = mu.mutate(z); auto zz = mu.Mutate(z);
std::ostringstream os; std::ostringstream os;
os << zz; os << zz;
CHECK(os.str() == "(x + 10)"); CHECK(os.str() == "(x + 10)");
} }
TEST(IRMutator, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) { TEST(IRVisitor, CountVar) {
using namespace Halide::Internal; using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment