/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \brief Lift the specified AttrStmt scope to an outer statement when
 *   all of its sub-statements carry the same scope.
 * \file lift_attr_scope.cc
 */
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <utility>
#include <vector>

#include "ir_util.h"

namespace tvm {
namespace tir {
32 33 34

// NOTE: this optimization can only be applied
// to a few specified attr keys
35
class AttrScopeLifter : public StmtMutator {
36 37 38 39 40
 public:
  explicit AttrScopeLifter(std::string attr_key)
      : attr_key_(attr_key) {}

  Stmt Lift(Stmt stmt) {
41
    stmt = operator()(std::move(stmt));
42
    if (attr_node_.defined()) {
43
      stmt = AttrStmtNode::make(
44 45 46 47 48 49
          attr_node_, attr_key_, attr_value_, stmt);
    }
    return stmt;
  }

  // do not go beyond
50
  Stmt VisitStmt_(const AllocateNode* op) final {
51
    Stmt stmt = StmtMutator::VisitStmt_(op);
52
    op = stmt.as<AllocateNode>();
53
    if (attr_node_.defined()) {
54
      Stmt body = AttrStmtNode::make(
55 56
          attr_node_, attr_key_, attr_value_, op->body);
      // undefine them
57
      attr_node_ = ObjectRef();
58
      attr_value_ = PrimExpr();
59
      return AllocateNode::make(
60
        op->buffer_var, op->dtype,
61 62 63 64 65 66 67
        op->extents, op->condition, body,
        op->new_expr, op->free_function);
    } else {
      return stmt;
    }
  }

68
  Stmt VisitStmt_(const AttrStmtNode* op) final {
69 70 71 72 73
    if (op->attr_key == attr_key_) {
      attr_node_ = op->node;
      attr_value_ = op->value;
      return op->body;
    } else {
74
      return StmtMutator::VisitStmt_(op);
75 76 77
    }
  }

78 79 80
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    // remember the decorations.
    std::vector<ObjectRef> attr_node;
81
    std::vector<PrimExpr> attr_value;
82 83 84

    auto fmutate = [&](const Stmt& s) {
      attr_node_ = ObjectRef();
85
      attr_value_ = PrimExpr();
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
      Stmt ret = this->VisitStmt(s);
      attr_node.push_back(attr_node_);
      attr_value.push_back(attr_value_);
      return ret;
    };
    Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
    if (attr_node.size() == 0) return ret;

    op = ret.as<SeqStmtNode>();
    CHECK(op != nullptr);
    Array<Stmt> reorg;
    // check if all decorations are common.
    for (size_t begin = 0; begin < attr_node.size();) {
      size_t end = begin + 1;
      while (end < attr_node.size() &&
             attr_node[end].same_as(attr_node[begin]) &&
             ValueSame(attr_value[end], attr_value[begin])) {
        ++end;
      }
      // covers everything
      // lift attr to parent.
      if (begin == 0 && end == attr_node.size()) {
        attr_node_ = attr_node[0];
        attr_value_ = attr_value[0];
        return ret;
      }
      // construct subsegments.
      Array<Stmt> seq;
      for (size_t i = begin; i < end; ++i) {
        seq.push_back(op->seq[i]);
      }
      Stmt stmt = SeqStmt::Flatten(seq);
      if (attr_node[begin].defined()) {
119
        stmt = AttrStmtNode::make(
120 121 122 123
            attr_node[begin], attr_key_, attr_value[begin], stmt);
      }
      reorg.push_back(stmt);
      begin = end;
124
    }
125
    attr_node_ = ObjectRef();
126
    attr_value_ = PrimExpr();
127
    return SeqStmt::Flatten(reorg);
128 129
  }

130
  Stmt VisitStmt_(const IfThenElseNode* op) final {
131
    if (!op->else_case.defined()) {
132
      return StmtMutator::VisitStmt_(op);
133
    }
134
    Stmt then_case = this->VisitStmt(op->then_case);
135
    ObjectRef first_node;
136
    PrimExpr first_value;
137 138
    std::swap(first_node, attr_node_);
    std::swap(first_value, attr_value_);
139
    Stmt else_case = this->VisitStmt(op->else_case);
140 141
    if (attr_node_.defined() &&
        attr_value_.defined() &&
142 143 144 145
        first_node.defined() &&
        first_value.defined() &&
        attr_node_.same_as(first_node) &&
        ValueSame(attr_value_, first_value)) {
146 147
      if (then_case.same_as(op->then_case) &&
          else_case.same_as(op->else_case)) {
148
        return GetRef<Stmt>(op);
149
      } else {
150
        return IfThenElseNode::make(op->condition, then_case, else_case);
151 152
      }
    } else {
153
      if (first_node.defined()) {
154
        then_case = AttrStmtNode::make(
155
            first_node, attr_key_, first_value, then_case);
156 157
      }
      if (attr_node_.defined()) {
158
        else_case = AttrStmtNode::make(
159 160
            attr_node_, attr_key_, attr_value_, else_case);
        // undefine them
161
        attr_node_ = ObjectRef();
162
        attr_value_ = PrimExpr();
163 164 165
      }
      if (then_case.same_as(op->then_case) &&
          else_case.same_as(op->else_case)) {
166
        return GetRef<Stmt>(op);
167
      } else {
168
        return IfThenElseNode::make(op->condition, then_case, else_case);
169 170 171 172 173
      }
    }
  }

 private:
174
  // value comparison that also compares content of int constant
175
  static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
176
    if (a.same_as(b)) return true;
177
    if (!a.defined() || !b.defined()) return false;
178
    if (a->type_index() != b->type_index()) return false;
179
    if (a.dtype() != b.dtype()) return false;
180 181
    if (const IntImmNode* op = a.as<IntImmNode>()) {
      return op->value == b.as<IntImmNode>()->value;
182 183 184 185
    }
    return false;
  }

186
  std::string attr_key_;
187
  ObjectRef attr_node_;
188
  PrimExpr attr_value_;
189 190 191
};

/*!
 * \brief Hoist AttrStmt scopes with the given key to enclosing statements
 *   where possible (see AttrScopeLifter).
 * \param stmt The statement to transform.
 * \param attr_key The attribute key whose scopes are lifted.
 * \return The transformed statement.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  return AttrScopeLifter(std::move(attr_key)).Lift(std::move(stmt));
}

}  // namespace tir
}  // namespace tvm