/*!
 *  Copyright (c) 2017 by Contributors
 *
 * \brief Lift specified AttrStmt scope to outer if
 *   the body contains the same scope.
 * \file lift_attr_scope.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <string>
#include <utility>
#include <vector>
#include "ir_util.h"

namespace tvm {
namespace ir {

// NOTE: this optimization can only be applied
// to a few specified attr keys
/*!
 * \brief Mutator that strips AttrStmt nodes whose key equals attr_key_ and
 *  re-attaches them at the outermost scope where neighbouring statements
 *  share the same (node, value) pair.
 *
 *  State protocol: after Mutate() returns, attr_node_ / attr_value_ describe
 *  an attribute that was removed from the returned statement and still has
 *  to be re-attached by an enclosing scope. Both are undefined when nothing
 *  is pending.
 */
class AttrScopeLifter : public IRMutator {
 public:
  explicit AttrScopeLifter(std::string attr_key)
      : attr_key_(std::move(attr_key)) {}

  /*!
   * \brief Run the lifting pass on a statement.
   * \param stmt The statement to transform.
   * \return The transformed statement. An attribute still pending at the
   *  root is wrapped around the whole result.
   */
  Stmt Lift(Stmt stmt) {
    stmt = Mutate(stmt);
    if (attr_node_.defined()) {
      stmt = AttrStmt::make(attr_node_, attr_key_, attr_value_, stmt);
      // The attribute is now attached; clear it so this lifter does not
      // carry stale state if Lift is called again.
      attr_node_ = NodeRef();
      attr_value_ = Expr();
    }
    return stmt;
  }

  // do not go beyond: an attribute pending from the Allocate's body is
  // re-attached directly around that body instead of being lifted further.
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    if (!attr_node_.defined()) {
      return stmt;
    }
    Stmt body = AttrStmt::make(attr_node_, attr_key_, attr_value_, op->body);
    // The attribute has been consumed here; nothing is pending any more.
    attr_node_ = NodeRef();
    attr_value_ = Expr();
    return Allocate::make(op->buffer_var, op->type, op->extents,
                          op->condition, body, op->new_expr,
                          op->free_function);
  }

  // Strip a matching AttrStmt and record its (node, value) as pending.
  // The body is returned as-is and is not traversed any further.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key != attr_key_) {
      return IRMutator::Mutate_(op, s);
    }
    attr_node_ = op->node;
    attr_value_ = op->value;
    return op->body;
  }

  // Flatten the block into a statement list and regroup it so that
  // consecutive statements sharing an attribute are wrapped only once.
  Stmt Mutate_(const Block* op, const Stmt& s) final {
    std::vector<Stmt> seq;
    FlattenSeq(op->first, &seq);
    FlattenSeq(op->rest, &seq);
    seq = MutateSeq(seq);
    if (seq.size() == 2 &&
        seq[0].same_as(op->first) &&
        seq[1].same_as(op->rest)) {
      return s;
    }
    return MergeSeq(seq);
  }

  // For an if/else, the attribute is lifted above the IfThenElse only when
  // both branches carry the same (node, value); otherwise each branch gets
  // its own attribute back. An if without else uses the default traversal.
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    if (!op->else_case.defined()) {
      return IRMutator::Mutate_(op, s);
    }
    Stmt then_case = this->Mutate(op->then_case);
    // Stash whatever the then-branch left pending and start the
    // else-branch with an empty slot.
    NodeRef then_node;
    Expr then_value;
    std::swap(then_node, attr_node_);
    std::swap(then_value, attr_value_);
    Stmt else_case = this->Mutate(op->else_case);

    const bool branches_agree =
        attr_node_.defined() && attr_value_.defined() &&
        then_node.defined() && then_value.defined() &&
        attr_node_.same_as(then_node) &&
        ValueSame(attr_value_, then_value);
    if (!branches_agree) {
      // Cannot lift: put each branch's attribute back where it came from.
      if (then_node.defined()) {
        then_case = AttrStmt::make(then_node, attr_key_, then_value, then_case);
      }
      if (attr_node_.defined()) {
        else_case = AttrStmt::make(attr_node_, attr_key_, attr_value_,
                                   else_case);
        // undefine them
        attr_node_ = NodeRef();
        attr_value_ = Expr();
      }
    }
    // When the branches agree, the else-branch's attribute stays pending
    // in attr_node_/attr_value_ for an enclosing scope to attach.
    if (then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    }
    return IfThenElse::make(op->condition, then_case, else_case);
  }

 private:
  /*!
   * \brief Append the statements of s to res, flattening nested Blocks.
   *  Consumer ProducerConsumer nodes are dropped and their bodies
   *  flattened; producer nodes are kept as a single unit.
   */
  void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
    if (const Block* op = s.as<Block>()) {
      FlattenSeq(op->first, res);
      FlattenSeq(op->rest, res);
    } else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
      if (!op->is_producer) {
        FlattenSeq(op->body, res);
      } else {
        res->emplace_back(s);
      }
    } else {
      res->emplace_back(s);
    }
  }

  /*!
   * \brief Mutate each statement and group runs that share an attribute.
   *
   *  Consecutive statements whose pending (node, value) match are merged
   *  into one Block under a single AttrStmt. If the entire sequence forms
   *  one group, that group is returned unwrapped, with its attribute left
   *  pending in attr_node_/attr_value_. Otherwise every group is wrapped
   *  and nothing is left pending.
   */
  std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
    std::vector<Stmt> res_seq;
    // Attribute shared by the group currently being accumulated.
    NodeRef curr_node;
    Expr curr_value;
    Stmt curr_stmt;
    for (const Stmt& stmt : seq) {
      attr_node_ = NodeRef();
      attr_value_ = Expr();
      Stmt rest = this->Mutate(stmt);
      if (attr_node_.defined() && attr_value_.defined() &&
          curr_node.defined() && curr_value.defined() &&
          attr_node_.same_as(curr_node) &&
          ValueSame(attr_value_, curr_value)) {
        // Same attribute as the running group: extend it.
        curr_stmt = Block::make(curr_stmt, rest);
      } else {
        // Attribute changed: close the running group and start a new one.
        if (curr_stmt.defined()) {
          if (curr_node.defined()) {
            curr_stmt = AttrStmt::make(curr_node, attr_key_, curr_value,
                                       curr_stmt);
          }
          res_seq.push_back(curr_stmt);
        }
        curr_stmt = rest;
        curr_node = attr_node_;
        curr_value = attr_value_;
      }
    }

    if (curr_stmt.defined()) {
      if (res_seq.empty()) {
        // Single group: keep attr_node_, attr_value_ pending so the
        // caller can lift the attribute further out.
        return {curr_stmt};
      }
      if (curr_node.defined()) {
        curr_stmt = AttrStmt::make(curr_node, attr_key_, curr_value,
                                   curr_stmt);
      }
      res_seq.push_back(curr_stmt);
      // Every group was wrapped above; nothing remains pending.
      attr_node_ = NodeRef();
      attr_value_ = Expr();
    }
    return res_seq;
  }

  /*!
   * \brief Value comparison that also compares the content of integer
   *  constants. Two exprs are the same if they are the same node, or both
   *  IntImm (or both UIntImm) of the same type with equal values.
   *  Node types are detected with as<>(), not by comparing type_key()
   *  C-string pointers, which compares addresses rather than contents.
   */
  static bool ValueSame(const Expr& a, const Expr& b) {
    if (a.same_as(b)) return true;
    if (!a.defined() || !b.defined()) return false;
    if (a.type() != b.type()) return false;
    const IntImm* a_int = a.as<IntImm>();
    const IntImm* b_int = b.as<IntImm>();
    if (a_int != nullptr && b_int != nullptr) {
      return a_int->value == b_int->value;
    }
    const UIntImm* a_uint = a.as<UIntImm>();
    const UIntImm* b_uint = b.as<UIntImm>();
    if (a_uint != nullptr && b_uint != nullptr) {
      return a_uint->value == b_uint->value;
    }
    return false;
  }

  // Attribute key this lifter operates on.
  std::string attr_key_;
  // Pending attribute node/value (see the class-level state protocol).
  NodeRef attr_node_;
  Expr attr_value_;
};

/*!
 * \brief Lift AttrStmt scopes with key attr_key as far out as possible.
 * \param stmt The statement to transform.
 * \param attr_key The attribute key to lift.
 * \return The transformed statement.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  return AttrScopeLifter(attr_key).Lift(stmt);
}

}  // namespace ir
}  // namespace tvm