unroll_loop.cc 6.36 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 *  Loop unrolling as in Halide pipeline.
 * \file unroll_loop.cc
 */
// Unrolls the loop as in Halide pipeline.
25 26
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"
32 33 34 35

namespace tvm {
namespace ir {

36
class LoopUnroller : public StmtExprMutator {
37
 public:
38
  explicit LoopUnroller(int auto_max_step,
39
                        int auto_max_depth,
40
                        int auto_max_extent,
41 42
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
43
        auto_max_depth_(auto_max_depth),
44
        auto_max_extent_(auto_max_extent),
45
        explicit_unroll_(explicit_unroll) {
46 47
  }

48
  Stmt VisitStmt_(const AttrStmtNode* op) final {
49
    if (op->attr_key == "pragma_auto_unroll_max_step") {
50
      int value = 0;
51 52
      CHECK(arith::GetConstInt(op->value, &value));
      std::swap(value, auto_max_step_);
53
      Stmt ret = this->VisitStmt(op->body);
54 55 56
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
57
      int value = 0;
58 59 60
      CHECK(arith::GetConstInt(op->value, &value));
      bool explicit_unroll = value;
      std::swap(explicit_unroll, explicit_unroll_);
61
      Stmt ret = this->VisitStmt(op->body);
62 63 64
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
65
      return StmtExprMutator::VisitStmt_(op);
66 67 68
    }
  }

69
  Stmt VisitStmt_(const ForNode* op) {
70
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
71
    op = stmt.as<ForNode>();
72
    int value = GetExtent(op);
73 74 75 76
    // condition for auto unroll
    bool auto_unroll = (
        op->for_type == ForType::Serial &&
        value >= 0 &&
77 78 79 80 81 82
        normal_loop_depth_ == 0 &&
        unroll_depth_ <= auto_max_depth_);

    auto_unroll = auto_unroll && (
        value * step_count_ <= auto_max_step_||
        value <= auto_max_extent_);
83

84 85 86
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
87
      auto_unroll = true;
88 89
    }

90 91 92 93 94 95 96
    if (auto_unroll) {
      step_count_  *=  value;
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }

97 98
    if ((auto_unroll && explicit_unroll_) ||
        // unroll loops with extent = 1, no matter how many steps in body
99
        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
100
      return Unroll(op);
101
    } else {
102 103
      if (auto_unroll) {
        if (op->for_type != ForType::Unrolled) {
104
          return ForNode::make(
105 106 107 108
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
109 110 111 112
      return stmt;
    }
  }

113
  Stmt VisitStmt_(const StoreNode* op) final {
114
    ++step_count_;
115
    return StmtExprMutator::VisitStmt_(op);
116 117
  }

118
  Stmt VisitStmt_(const EvaluateNode* op) final {
119
    ++step_count_;
120
    return StmtExprMutator::VisitStmt_(op);
121 122
  }

123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    auto fmutate = [this](const Stmt& s) {
      int step_count = step_count_;
      int unroll_depth = unroll_depth_;
      int normal_loop_depth = normal_loop_depth_;
      step_count_ = 0;
      unroll_depth_ = 0;
      normal_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      step_count_ += step_count;
      normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
      unroll_depth_ = std::max(unroll_depth_, unroll_depth);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
138 139
  }

140
  Stmt Unroll(const ForNode* op) {
141 142 143
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
144
    if (value == 0) return EvaluateNode::make(0);
145
    Stmt body = op->body;
146
    Map<Var, PrimExpr> vmap;
147
    Array<Stmt> unrolled;
148
    for (int i = 0; i < value; ++i) {
149
      vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
150
      Stmt step = Substitute(body, vmap);
151
      unrolled.push_back(step);
152
    }
153
    return SeqStmt::Flatten(unrolled);
154 155
  }

156
 private:
157
  // returns the extent of the loop if it's a constant integer, otherwise return -1
158
  int GetExtent(const ForNode* op) {
159
    // constant folding.
160
    PrimExpr extent = ir::Simplify(op->extent);
161 162
    const IntImmNode  *v1 = extent.as<IntImmNode>();
    const UIntImmNode *v2 = extent.as<UIntImmNode>();
163 164 165 166 167 168 169 170 171 172
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
    return value;
  }

173 174
  // maximum number of step to perform auto unroll.
  int auto_max_step_;
175
  int auto_max_depth_;
176 177 178
  // max extent of loop to auto unroll
  // this not not count the total steps, only count the number of loops
  int auto_max_extent_;
179
  bool explicit_unroll_;
180 181 182 183 184 185
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled
  int step_count_{0};
186 187 188
};


189 190
/*!
 * \brief Run LoopUnroller over a statement.
 *
 * \param stmt The statement to transform.
 * \param auto_max_step Forwarded to the LoopUnroller constructor.
 * \param auto_max_depth Forwarded to the LoopUnroller constructor.
 * \param auto_max_extent Forwarded to the LoopUnroller constructor.
 * \param explicit_unroll Forwarded to the LoopUnroller constructor.
 * \return The input itself when the unroller left it untouched, otherwise
 *         the mutated statement after ConvertSSA.
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll) {
  LoopUnroller unroller(auto_max_step,
                        auto_max_depth,
                        auto_max_extent,
                        explicit_unroll);
  Stmt mutated = unroller(stmt);
  // Nothing was rewritten, so the SSA conversion is skipped.
  if (mutated.same_as(stmt)) {
    return mutated;
  }
  return ConvertSSA(mutated);
}

206
/*!
 * \brief Expand a single loop statement via LoopUnroller::Unroll.
 *
 * \param stmt The statement to unroll; logs a fatal error if it is not a ForNode.
 * \return The result of LoopUnroller::Unroll on that loop.
 */
Stmt UnrollLoopExplicitly(Stmt stmt) {
  const ForNode* loop = stmt.as<ForNode>();
  if (loop == nullptr) {
    LOG(FATAL) << "attempted to unroll a non-loop statement";
  }
  // The limits are not consulted by Unroll, so they are all passed as zero/false.
  LoopUnroller unroller(0, 0, 0, false);
  return unroller.Unroll(loop);
}

214 215
}  // namespace ir
}  // namespace tvm