/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_mutator.cc
 * \brief Default implementations of IRMutator::Mutate_ for every Stmt/Expr
 *  node kind, their dispatch-table registration, and the PackedFunc-driven
 *  IRTransform utility.
 *
 *  Every default Mutate_ follows the same pattern: mutate the child fields,
 *  and if every mutated child is identical (same_as) to the original one,
 *  return the original node unchanged; otherwise rebuild the node with
 *  ::make from the new children and the untouched fields.
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/packed_func_ext.h>
#include <unordered_set>
#include <vector>
#include "./ir_util.h"

namespace tvm {
namespace ir {

/*!
 * \brief IRMutator driven by user supplied pre-/post-order callbacks.
 *
 *  For each visited Stmt/Expr:
 *   - if only_enable is non-empty and the node's type index is not in it,
 *     only the default IRMutator recursion is applied to that node;
 *   - otherwise, if f_preorder is set and returns a defined value, that
 *     value replaces the node and its children are not visited;
 *   - otherwise the children are mutated first, then, if f_postorder is set
 *     and returns a defined value, that value replaces the mutated node.
 *
 * \note The callbacks and the type-index set are stored by reference, so
 *  they must outlive the transformer. IRTransform below only uses the
 *  transformer inside a single full-expression, which satisfies this.
 */
class IRTransformer final : public IRMutator {
 public:
  IRTransformer(const runtime::PackedFunc& f_preorder,
                const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder),
        f_postorder_(f_postorder),
        only_enable_(only_enable) {
  }
  Stmt Mutate(Stmt stmt) final {
    return MutateInternal<Stmt>(stmt);
  }
  Expr Mutate(Expr expr) final {
    return MutateInternal<Expr>(expr);
  }

 private:
  /*!
   * \brief Shared Stmt/Expr implementation of the callback protocol
   *  described on the class.
   */
  template<typename T>
  T MutateInternal(T node) {
    // Node kind filtered out: recurse with default behavior only, so that
    // enabled descendants still get their callbacks (children are visited
    // through the virtual Mutate of this class).
    if (only_enable_.size() &&
        !only_enable_.count(node->type_index())) {
      return IRMutator::Mutate(node);
    }
    if (f_preorder_ != nullptr) {
      T pre = f_preorder_(node);
      // A defined pre-order result short-circuits the recursion.
      if (pre.defined()) return pre;
    }
    node = IRMutator::Mutate(node);
    if (f_postorder_ != nullptr) {
      T post = f_postorder_(node);
      if (post.defined()) return post;
    }
    return node;
  }
  // The functions
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // type indices enabled.
  const std::unordered_set<uint32_t>& only_enable_;
};

/*!
 * \brief Transform ir_node with pre-/post-order PackedFunc callbacks.
 * \param ir_node The statement to transform.
 * \param f_preorder Callback invoked before visiting children (may be null).
 * \param f_postorder Callback invoked after visiting children (may be null).
 * \param only_enable Type keys (as StringImm) of the node kinds the callbacks
 *  apply to; an empty array enables every node kind.
 * \return The transformed statement.
 */
Stmt IRTransform(const Stmt& ir_node,
                 const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder,
                 const Array<Expr>& only_enable) {
  std::unordered_set<uint32_t> only_type_index;
  for (const Expr& s : only_enable) {
    // Guard the downcast: a non-StringImm entry would otherwise be a
    // null-pointer dereference instead of a diagnosable error.
    const StringImm* type_key = s.as<StringImm>();
    CHECK(type_key != nullptr)
        << "IRTransform: only_enable must contain only StringImm type keys";
    only_type_index.insert(Node::TypeKey2Index(type_key->value.c_str()));
  }
  return IRTransformer(f_preorder, f_postorder, only_type_index)
      .Mutate(ir_node);
}

/*! \brief Process-wide dispatch table used to mutate Expr nodes. */
IRMutator::FMutateExpr& IRMutator::vtable_expr() {  // NOLINT(*)
  static FMutateExpr inst; return inst;
}

/*! \brief Process-wide dispatch table used to mutate Stmt nodes. */
IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
  static FMutateStmt inst; return inst;
}

/*!
 * \brief Mutate every element of arr with m.
 * \note UpdateArray comes from ir_util.h; the same_as checks of callers rely
 *  on it returning arr itself when no element changed (confirm in ir_util.h).
 */
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
  return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
}

/*!
 * \brief Mutate the domain (min and extent) of every IterVar in rdom.
 *
 *  var, iter_type and thread_tag are carried over unchanged. The original
 *  array is returned when no min/extent changed.
 */
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
  std::vector<IterVar> new_dom(rdom.size());
  bool changed = false;
  for (size_t i = 0; i < rdom.size(); i++) {
    IterVar v = rdom[i];
    Range r = v->dom;
    Expr new_min = m->Mutate(r->min);
    Expr new_extent = m->Mutate(r->extent);
    if (!r->min.same_as(new_min)) changed = true;
    if (!r->extent.same_as(new_extent)) changed = true;
    new_dom[i] = IterVarNode::make(
        Range::make_by_min_extent(new_min, new_extent),
        v->var, v->iter_type, v->thread_tag);
  }
  if (!changed) {
    return rdom;
  } else {
    return Array<IterVar>(new_dom);
  }
}

// Mutate Stmt

// Register a dispatch entry that forwards OP nodes to IRMutator::Mutate_.
#define DISPATCH_TO_MUTATE_STMT(OP)                                  \
  set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) {   \
      return m->Mutate_(op, s);                                      \
    })

// AttrStmt: mutates value and body; node and attr_key are kept.
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
  Expr value = this->Mutate(op->value);
  Stmt body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
    return s;
  } else {
    return AttrStmt::make(op->node, op->attr_key, value, body);
  }
}

// LetStmt: mutates value and body; the bound var is kept.
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
  Expr value = this->Mutate(op->value);
  Stmt body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
    return s;
  } else {
    return LetStmt::make(op->var, value, body);
  }
}

// For: mutates min, extent and body; loop_var, for_type, device_api kept.
Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
  Expr min = this->Mutate(op->min);
  Expr extent = this->Mutate(op->extent);
  Stmt body = this->Mutate(op->body);
  if (min.same_as(op->min) &&
      extent.same_as(op->extent) &&
      body.same_as(op->body)) {
    return s;
  } else {
    return For::make(
        op->loop_var, min, extent, op->for_type, op->device_api, body);
  }
}

// Allocate: mutates extents, body, condition and (if present) new_expr,
// in that order; buffer_var, type and free_function are kept.
Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
  IRMutator* m = this;
  std::vector<Expr> new_extents;
  bool all_extents_unmodified = true;
  for (size_t i = 0; i < op->extents.size(); i++) {
    new_extents.push_back(m->Mutate(op->extents[i]));
    all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
  }
  Stmt body = m->Mutate(op->body);
  Expr condition = m->Mutate(op->condition);
  // new_expr is optional; an undefined one stays undefined.
  Expr new_expr;
  if (op->new_expr.defined()) {
    new_expr = m->Mutate(op->new_expr);
  }
  if (all_extents_unmodified &&
      body.same_as(op->body) &&
      condition.same_as(op->condition) &&
      new_expr.same_as(op->new_expr)) {
    return s;
  } else {
    return Allocate::make(
        op->buffer_var, op->type,
        new_extents, condition, body,
        new_expr, op->free_function);
  }
}

// IfThenElse: mutates condition, then_case and (if present) else_case.
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
  Expr condition = this->Mutate(op->condition);
  Stmt then_case = this->Mutate(op->then_case);
  Stmt else_case;
  if (op->else_case.defined()) {
    else_case = this->Mutate(op->else_case);
  }
  if (condition.same_as(op->condition) &&
      then_case.same_as(op->then_case) &&
      else_case.same_as(op->else_case)) {
    return s;
  } else {
    return IfThenElse::make(condition, then_case, else_case);
  }
}

// Store: mutates value, index and predicate; buffer_var is kept.
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
  Expr value = this->Mutate(op->value);
  Expr index = this->Mutate(op->index);
  Expr pred = this->Mutate(op->predicate);
  if (value.same_as(op->value) && index.same_as(op->index) &&
      pred.same_as(op->predicate)) {
    return s;
  } else {
    return Store::make(op->buffer_var, value, index, pred);
  }
}

// Provide: mutates args and value; func and value_index are kept.
Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
  auto new_args = MutateArray(op->args, this);
  auto new_value = this->Mutate(op->value);
  if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
    return s;
  } else {
    return Provide::make(op->func, op->value_index, new_value, new_args);
  }
}

// Realize: mutates every bound's min/extent, then body and condition;
// func, value_index and type are kept.
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
  IRMutator* m = this;
  Halide::Internal::Region new_bounds;
  bool bounds_changed = false;

  // Mutate the bounds
  for (size_t i = 0; i < op->bounds.size(); i++) {
    Expr old_min = op->bounds[i]->min;
    Expr old_extent = op->bounds[i]->extent;
    Expr new_min = m->Mutate(old_min);
    Expr new_extent = m->Mutate(old_extent);
    if (!new_min.same_as(old_min)) bounds_changed = true;
    if (!new_extent.same_as(old_extent)) bounds_changed = true;
    new_bounds.push_back(
        Range::make_by_min_extent(new_min, new_extent));
  }

  Stmt body = m->Mutate(op->body);
  Expr condition = m->Mutate(op->condition);
  if (!bounds_changed &&
      body.same_as(op->body) &&
      condition.same_as(op->condition)) {
    return s;
  } else {
    return Realize::make(op->func, op->value_index,
                         op->type, new_bounds,
                         condition, body);
  }
}

// Prefetch: only the bounds' min/extent are mutated.
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
  IRMutator* m = this;
  Halide::Internal::Region new_bounds;
  bool bounds_changed = false;

  // Mutate the bounds
  for (size_t i = 0; i < op->bounds.size(); i++) {
    Expr old_min = op->bounds[i]->min;
    Expr old_extent = op->bounds[i]->extent;
    Expr new_min = m->Mutate(old_min);
    Expr new_extent = m->Mutate(old_extent);
    if (!new_min.same_as(old_min)) bounds_changed = true;
    if (!new_extent.same_as(old_extent)) bounds_changed = true;
    new_bounds.push_back(
        Range::make_by_min_extent(new_min, new_extent));
  }

  if (!bounds_changed) {
    return s;
  } else {
    return Prefetch::make(op->func, op->value_index,
                          op->type, new_bounds);
  }
}

// Block: mutates first and rest.
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
  Stmt first = this->Mutate(op->first);
  Stmt rest = this->Mutate(op->rest);
  if (first.same_as(op->first) &&
      rest.same_as(op->rest)) {
    return s;
  } else {
    return Block::make(first, rest);
  }
}

// AssertStmt: mutates condition, message and body.
Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
  Expr condition = this->Mutate(op->condition);
  Expr message = this->Mutate(op->message);
  Stmt body = this->Mutate(op->body);

  if (condition.same_as(op->condition) &&
      message.same_as(op->message) &&
      body.same_as(op->body)) {
    return s;
  } else {
    return AssertStmt::make(condition, message, body);
  }
}

// ProducerConsumer: mutates body only; func and is_producer are kept.
Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
  Stmt body = this->Mutate(op->body);
  if (body.same_as(op->body)) {
    return s;
  } else {
    return ProducerConsumer::make(op->func, op->is_producer, body);
  }
}

// Evaluate: mutates the evaluated value.
Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
  Expr v = this->Mutate(op->value);
  if (v.same_as(op->value)) {
    return s;
  } else {
    return Evaluate::make(v);
  }
}

// Free: no Expr/Stmt children are visited; returned unchanged.
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
  return s;
}

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(Prefetch);


// Mutate Expr

// Register a dispatch entry that forwards OP nodes to IRMutator::Mutate_.
#define DISPATCH_TO_MUTATE_EXPR(OP)                                  \
  set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) {   \
      return m->Mutate_(op, e);                                      \
    })

// Variable: leaf node, returned unchanged.
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
  return e;
}

// Load: mutates index and predicate; type and buffer_var are kept.
Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
  Expr index = this->Mutate(op->index);
  Expr pred = this->Mutate(op->predicate);
  if (index.same_as(op->index) && pred.same_as(op->predicate)) {
    return e;
  } else {
    return Load::make(op->type, op->buffer_var, index, pred);
  }
}

// Let: mutates value and body; the bound var is kept.
Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
  Expr value = this->Mutate(op->value);
  Expr body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
    return e;
  } else {
    return Let::make(op->var, value, body);
  }
}

// Call: mutates args; type, name, call_type, func, value_index are kept.
Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
  auto new_args = MutateArray(op->args, this);
  if (op->args.same_as(new_args)) {
    return e;
  } else {
    return Call::make(op->type, op->name, new_args, op->call_type,
                      op->func, op->value_index);
  }
}

// Define Mutate_ for a binary node OP with operands a and b.
#define DEFINE_BIOP_EXPR_MUTATE_(OP)                                  \
  Expr IRMutator::Mutate_(const OP* op, const Expr& e) {              \
    Expr a = this->Mutate(op->a);                                     \
    Expr b = this->Mutate(op->b);                                     \
    if (a.same_as(op->a) &&                                           \
        b.same_as(op->b)) {                                           \
      return e;                                                       \
    } else {                                                          \
      return OP::make(a, b);                                          \
    }                                                                 \
  }

DEFINE_BIOP_EXPR_MUTATE_(Add)
DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
DEFINE_BIOP_EXPR_MUTATE_(EQ)
DEFINE_BIOP_EXPR_MUTATE_(NE)
DEFINE_BIOP_EXPR_MUTATE_(LT)
DEFINE_BIOP_EXPR_MUTATE_(LE)
DEFINE_BIOP_EXPR_MUTATE_(GT)
DEFINE_BIOP_EXPR_MUTATE_(GE)
DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or)

// Reduce: mutates axis domains, source and condition; combiner and
// value_index are kept.
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
  Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
  Array<Expr> new_source = MutateArray(op->source, this);
  Expr new_cond = this->Mutate(op->condition);
  if (op->axis.same_as(new_axis) &&
      op->source.same_as(new_source) &&
      op->condition.same_as(new_cond)) {
    return e;
  } else {
    return Reduce::make(
      op->combiner, new_source, new_axis, new_cond, op->value_index);
  }
}

// Cast: mutates value; the target type is kept.
Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
  Expr value = this->Mutate(op->value);
  if (value.same_as(op->value)) {
    return e;
  } else {
    return Cast::make(op->type, value);
  }
}

// Not: mutates the single operand a.
Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
  Expr a = this->Mutate(op->a);
  if (a.same_as(op->a)) {
    return e;
  } else {
    return Not::make(a);
  }
}

// Select: mutates condition, true_value and false_value.
Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
  Expr cond = this->Mutate(op->condition);
  Expr t = this->Mutate(op->true_value);
  Expr f = this->Mutate(op->false_value);
  if (cond.same_as(op->condition) &&
      t.same_as(op->true_value) &&
      f.same_as(op->false_value)) {
    return e;
  } else {
    return Select::make(cond, t, f);
  }
}

// Ramp: mutates base and stride; lanes is kept.
Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
  Expr base = this->Mutate(op->base);
  Expr stride = this->Mutate(op->stride);
  if (base.same_as(op->base) &&
      stride.same_as(op->stride)) {
    return e;
  } else {
    return Ramp::make(base, stride, op->lanes);
  }
}

// Broadcast: mutates value; lanes is kept.
Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
  Expr value = this->Mutate(op->value);
  if (value.same_as(op->value)) {
    return e;
  } else {
    return Broadcast::make(value, op->lanes);
  }
}

// Shuffle: mutates vectors; indices are kept.
Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
  auto new_vec = MutateArray(op->vectors, this);
  if (new_vec.same_as(op->vectors)) {
    return e;
  } else {
    return Shuffle::make(new_vec, op->indices);
  }
}

// Define Mutate_ for a leaf (immediate) node OP: returned unchanged.
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP)                        \
  Expr IRMutator::Mutate_(const OP *op, const Expr& e) {              \
    return e;                                                         \
  }

DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub)
.DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ)
.DISPATCH_TO_MUTATE_EXPR(NE)
.DISPATCH_TO_MUTATE_EXPR(LT)
.DISPATCH_TO_MUTATE_EXPR(LE)
.DISPATCH_TO_MUTATE_EXPR(GT)
.DISPATCH_TO_MUTATE_EXPR(GE)
.DISPATCH_TO_MUTATE_EXPR(And)
.DISPATCH_TO_MUTATE_EXPR(Or)
.DISPATCH_TO_MUTATE_EXPR(Reduce)
.DISPATCH_TO_MUTATE_EXPR(Cast)
.DISPATCH_TO_MUTATE_EXPR(Not)
.DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm)
.DISPATCH_TO_MUTATE_EXPR(Shuffle);

}  // namespace ir
}  // namespace tvm