Commit be8de13f by tqchen

Enable IRFunctor based IRMutator

parent 0a392dd0
Subproject commit ec84af1359c841df622f683048968348381e328a Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b
...@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range { ...@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
* \param end The end of the range. * \param end The end of the range.
*/ */
Range(Expr begin, Expr end); Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
}; };
/*! \brief Domain is a multi-dimensional range */ /*! \brief Domain is a multi-dimensional range */
...@@ -74,6 +76,8 @@ class RDomain : public NodeRef { ...@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
inline Var i0() const { inline Var i0() const {
return index(0); return index(0);
} }
// low level constructor
static RDomain make(Array<Var> index, Domain domain);
}; };
/*! \brief use RDom as alias of RDomain */ /*! \brief use RDom as alias of RDomain */
...@@ -88,8 +92,8 @@ class RDomainNode : public Node { ...@@ -88,8 +92,8 @@ class RDomainNode : public Node {
Domain domain; Domain domain;
/*! \brief constructor */ /*! \brief constructor */
RDomainNode() {} RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain) RDomainNode(Array<Var> index, Domain domain)
: index(std::move(index)), domain(std::move(domain)) { : index(index), domain(domain) {
} }
const char* type_key() const override { const char* type_key() const override {
return "RDomain"; return "RDomain";
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include <type_traits> #include <string>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
...@@ -28,7 +28,12 @@ using Halide::select; ...@@ -28,7 +28,12 @@ using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
using Var = Halide::VarExpr;
class Var : public Halide::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
};
} // namespace tvm } // namespace tvm
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor
*/
class IRMutator {
public:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief templatized base class of subclass of IRMutator
*
* Use "curiously recurring template pattern" to implement mutate for you.
* Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt
*
* \note This only implement direct subclass from IRMutator, similar code
* can be created to implement deeper subclassing when needed.
*/
class IRMutatorExample : public IRMutator {
public:
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
Stmt mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
...@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end) ...@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
// TODO(tqchen) add simplify to end - begin // TODO(tqchen) add simplify to end - begin
} }
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
RDomain::RDomain(Domain domain) { RDomain::RDomain(Domain domain) {
std::vector<Var> index; std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) { for (size_t i = 0; i < domain.size(); ++i) {
...@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) { ...@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
std::move(idx), std::move(domain)); std::move(idx), std::move(domain));
} }
RDomain RDomain::make(Array<Var> index, Domain domain) {
return RDomain(std::make_shared<RDomainNode>(index, domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode); TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm } // namespace tvm
...@@ -20,7 +20,7 @@ namespace Internal { ...@@ -20,7 +20,7 @@ namespace Internal {
using tvm::ir::Reduce; using tvm::ir::Reduce;
template<> template<>
void ExprNode<Reduce>::accept(IRVisitor *v) const { void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet"; LOG(FATAL) << "Reduce do not work with IRVisitor yet";
} }
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace {
using namespace tvm::ir;
using namespace Halide::Internal;
using namespace Halide;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
VarExpr var;
int int_val;
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
static FMutateExpr &vtable_expr();
};
// implement vtable
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return IntImm::make(Int(32), vm->int_val);
} else {
return e;
}
});
} // namespace
TEST(IRMutator, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment