Commit 8e04361c by tqchen

Refactor IR Pass

parent ff6b8d82
Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b
Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
......@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
class Var : public Halide::VarExpr {
......
......@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr mutate(Expr expr) {
virtual Expr Mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
......@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt mutate(Stmt stmt) {
virtual Stmt Mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
......@@ -58,28 +58,21 @@ class IRMutator {
*/
class IRMutatorExample : public IRMutator {
public:
Expr mutate(Expr expr) final {
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt mutate(Stmt stmt) final {
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt));
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_PASS_H_
......@@ -24,7 +24,7 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual void visit(const IRNodeRef& node) {
virtual void Visit(const IRNodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
......
......@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
};
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
class TensorNode : public FunctionBaseNode {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
......@@ -125,6 +125,12 @@ class TensorNode : public Node {
v->Visit("dim_var", &dim_var);
v->Visit("source", &source);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::Mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
}
} // namespace ir
} // namespace tvm
......@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
void visit(const IRNodeRef& node) final {
void Visit(const IRNodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::visit(node);
IRVisitor::Visit(node);
f_(node);
}
......@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std::function<void(const IRNodeRef&)> f_;
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit v(fvisit);
v.visit(node);
}
// namespace to register the functors.
namespace {
......@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) {
v->visit(arr[i]);
v->Visit(arr[i]);
}
}
inline void VisitRDom(RDomain rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
v->visit(r->min);
v->visit(r->extent);
v->Visit(r->min);
v->Visit(r->extent);
}
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->visit(op->source);
v->Visit(op->source);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
......@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
});
// binary operator
template<typename T>
inline void Binary(const T* op, IRVisitor* v) {
v->visit(op->a);
v->visit(op->b);
v->Visit(op->a);
v->Visit(op->b);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
......@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* op, IRVisitor* v) {
v->visit(op->a);
v->Visit(op->a);
})
.set_dispatch<Select>([](const Select *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->true_value);
v->visit(op->false_value);
v->Visit(op->condition);
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->visit(op->index);
v->Visit(op->index);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->visit(op->base);
v->visit(op->stride);
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
})
.set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v);
})
.set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->body);
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->body);
v->Visit(op->value);
v->Visit(op->body);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->message);
v->Visit(op->condition);
v->Visit(op->message);
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->visit(op->body);
v->Visit(op->body);
})
.set_dispatch<For>([](const For *op, IRVisitor* v) {
v->visit(op->min);
v->visit(op->extent);
v->visit(op->body);
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
})
.set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->index);
v->Visit(op->value);
v->Visit(op->index);
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
......@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
v->visit(op->extents[i]);
v->Visit(op->extents[i]);
}
v->visit(op->body);
v->visit(op->condition);
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->visit(op->new_expr);
v->Visit(op->new_expr);
}
})
.set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
v->visit(op->bounds[i]->min);
v->visit(op->bounds[i]->extent);
v->Visit(op->bounds[i]->min);
v->Visit(op->bounds[i]->extent);
}
v->visit(op->body);
v->visit(op->condition);
v->Visit(op->body);
v->Visit(op->condition);
})
.set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->visit(op->first);
v->visit(op->rest);
v->Visit(op->first);
v->Visit(op->rest);
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->then_case);
v->visit(op->else_case);
v->Visit(op->condition);
v->Visit(op->then_case);
v->Visit(op->else_case);
})
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
});
} // namespace
......
......@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public:
VarExpr var;
int int_val;
Expr mutate(Expr expr) final {
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
f(expr, expr, this) : IRMutator::Mutate(expr));
}
static FMutateExpr &vtable_expr();
};
......@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.mutate(z);
auto zz = mu.Mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
TEST(IRMutator, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -2,6 +2,7 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment