Commit be8de13f by tqchen

Enable IRFunctor based IRMutator

parent 0a392dd0
Subproject commit ec84af1359c841df622f683048968348381e328a Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b
...@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range { ...@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
* \param end The end of the range. * \param end The end of the range.
*/ */
Range(Expr begin, Expr end); Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
}; };
/*! \brief Domain is a multi-dimensional range */ /*! \brief Domain is a multi-dimensional range */
...@@ -74,6 +76,8 @@ class RDomain : public NodeRef { ...@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
inline Var i0() const { inline Var i0() const {
return index(0); return index(0);
} }
// low level constructor
static RDomain make(Array<Var> index, Domain domain);
}; };
/*! \brief use RDom as alias of RDomain */ /*! \brief use RDom as alias of RDomain */
...@@ -88,8 +92,8 @@ class RDomainNode : public Node { ...@@ -88,8 +92,8 @@ class RDomainNode : public Node {
Domain domain; Domain domain;
/*! \brief constructor */ /*! \brief constructor */
RDomainNode() {} RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain) RDomainNode(Array<Var> index, Domain domain)
: index(std::move(index)), domain(std::move(domain)) { : index(index), domain(domain) {
} }
const char* type_key() const override { const char* type_key() const override {
return "RDomain"; return "RDomain";
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include <type_traits> #include <string>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
...@@ -28,7 +28,12 @@ using Halide::select; ...@@ -28,7 +28,12 @@ using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
using Var = Halide::VarExpr;
class Var : public Halide::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
};
} // namespace tvm } // namespace tvm
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor
*/
class IRMutator {
public:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief templatized base class of subclass of IRMutator
*
* Use "curiously recurring template pattern" to implement mutate for you.
* Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt
*
* \note This only implement direct subclass from IRMutator, similar code
* can be created to implement deeper subclassing when needed.
*/
class IRMutatorExample : public IRMutator {
public:
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
Stmt mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
...@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end) ...@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
// TODO(tqchen) add simplify to end - begin // TODO(tqchen) add simplify to end - begin
} }
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
RDomain::RDomain(Domain domain) { RDomain::RDomain(Domain domain) {
std::vector<Var> index; std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) { for (size_t i = 0; i < domain.size(); ++i) {
...@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) { ...@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
std::move(idx), std::move(domain)); std::move(idx), std::move(domain));
} }
RDomain RDomain::make(Array<Var> index, Domain domain) {
return RDomain(std::make_shared<RDomainNode>(index, domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode); TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm } // namespace tvm
...@@ -20,7 +20,7 @@ namespace Internal { ...@@ -20,7 +20,7 @@ namespace Internal {
using tvm::ir::Reduce; using tvm::ir::Reduce;
template<> template<>
void ExprNode<Reduce>::accept(IRVisitor *v) const { void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet"; LOG(FATAL) << "Reduce do not work with IRVisitor yet";
} }
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}
// namespace to register the functors.
namespace {
using namespace Halide::Internal;
// const expr
inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) {
return e;
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
std::vector<Expr> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = m->mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<Expr>(new_arr);
}
}
inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
std::vector<Range> new_dom(rdom->domain.size());
bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
Expr new_min = m->mutate(r->min);
Expr new_extent = m->mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent);
}
if (!changed) {
return rdom;
} else {
return RDomain::make(rdom->index, Domain(new_dom));
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) {
return e;
} else {
return Reduce::make(op->op, new_source, new_rdom);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
.set_dispatch<FloatImm>(ReturnSelfExpr)
.set_dispatch<StringImm>(ReturnSelfExpr)
.set_dispatch<Variable>(ReturnSelfExpr);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type, value);
}
});
// binary operator
template<typename T>
inline Expr Binary(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
Expr b = m->mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
return T::make(a, b);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
if (a.same_as(op->a)) {
return e;
} else {
return Not::make(a);
}
})
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->mutate(op->condition);
Expr t = m->mutate(op->true_value);
Expr f = m->mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
return Select::make(cond, t, f);
}
})
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
}
})
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->mutate(op->base);
Expr stride = m->mutate(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return e;
} else {
return Ramp::make(base, stride, op->lanes);
}
})
.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Broadcast::make(value, op->lanes);
}
})
.set_dispatch<Call>([](const Call *op, const Expr& e, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(op->type, op->name, new_args, op->call_type,
op->func, op->value_index);
}
})
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr body = m->mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Stmt body = m->mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Expr message = m->mutate(op->message);
if (condition.same_as(op->condition) && message.same_as(op->message)) {
return s;
} else {
return AssertStmt::make(condition, message);
}
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
Stmt body = m->mutate(op->body);
if (body.same_as(op->body)) {
return s;
} else {
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->mutate(op->min);
Expr extent = m->mutate(op->extent);
Stmt body = m->mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
})
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr index = m->mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
})
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) {
return s;
} else {
return Provide::make(op->func, new_values, new_args);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->type,
new_extents, condition, body,
new_expr, op->free_function);
}
})
.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) {
return s;
})
.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->mutate(old_min);
Expr new_extent = m->mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
if (!bounds_changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->types, new_bounds,
condition, body);
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->mutate(op->first);
Stmt rest = m->mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Stmt then_case = m->mutate(op->then_case);
Stmt else_case = m->mutate(op->else_case);
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->mutate(op->value);
if (v.same_as(op->value)) {
return s;
} else {
return Evaluate::make(v);
}
});
} // namespace
} // namespace ir
} // namespace tvm
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace {
using namespace tvm::ir;
using namespace Halide::Internal;
using namespace Halide;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
VarExpr var;
int int_val;
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
static FMutateExpr &vtable_expr();
};
// implement vtable
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return IntImm::make(Int(32), vm->int_val);
} else {
return e;
}
});
} // namespace
TEST(IRMutator, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment