unroll_loop.cc 5.69 KB
Newer Older
1
/*!
2
 *  Copyright (c) 2017 by Contributors
3
 *  Loop unrolling as in Halide pipeline.
4
 * \file unroll_loop.cc
5
 */
6
// Unrolls the loop as in Halide pipeline.
7 8 9 10 11 12
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <vector>
13
#include "../arithmetic/compute_expr.h"
14 15 16 17 18 19

namespace tvm {
namespace ir {

class LoopUnroller : public IRMutator {
 public:
20
  explicit LoopUnroller(int auto_max_step,
21
                        int auto_max_depth,
22
                        int auto_max_extent,
23 24
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
25
        auto_max_depth_(auto_max_depth),
26
        auto_max_extent_(auto_max_extent),
27
        explicit_unroll_(explicit_unroll) {
28 29
  }

30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key == "pragma_auto_unroll_max_step") {
      int value;
      CHECK(arith::GetConstInt(op->value, &value));
      std::swap(value, auto_max_step_);
      Stmt ret = this->Mutate(op->body);
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
      int value;
      CHECK(arith::GetConstInt(op->value, &value));
      bool explicit_unroll = value;
      std::swap(explicit_unroll, explicit_unroll_);
      Stmt ret = this->Mutate(op->body);
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
      return IRMutator::Mutate_(op, stmt);
    }
  }

51
  Stmt Mutate_(const For* op, const Stmt& s) {
52 53
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<For>();
54
    int value = GetExtent(op);
55 56 57 58
    // condition for auto unroll
    bool auto_unroll = (
        op->for_type == ForType::Serial &&
        value >= 0 &&
59 60 61 62 63 64
        normal_loop_depth_ == 0 &&
        unroll_depth_ <= auto_max_depth_);

    auto_unroll = auto_unroll && (
        value * step_count_ <= auto_max_step_||
        value <= auto_max_extent_);
65

66 67 68
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
69
      auto_unroll = true;
70 71
    }

72 73 74 75 76 77 78
    if (auto_unroll) {
      step_count_  *=  value;
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }

79
    if (auto_unroll && explicit_unroll_) {
80
      return Unroll(op);
81
    } else {
82 83
      if (auto_unroll) {
        if (op->for_type != ForType::Unrolled) {
84
          return For::make(
85 86 87 88
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
      return stmt;
    }
  }

  Stmt Mutate_(const Store* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  Stmt Mutate_(const Block* op, const Stmt& stmt) final {
    Stmt first = this->Mutate(op->first);
    // cleanup state
    int step_count = step_count_;
    int unroll_depth = unroll_depth_;
    int normal_loop_depth = normal_loop_depth_;
    step_count_ = 0;
    unroll_depth_ = 0;
    normal_loop_depth_ = 0;
    // work on rest part
    Stmt rest = this->Mutate(op->rest);
    step_count_ += step_count;
    normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
    unroll_depth_ = std::max(unroll_depth_, unroll_depth);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return stmt;
    } else {
      return Block::make(first, rest);
122 123 124
    }
  }

125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
  Stmt Unroll(const For* op) {
    using arith::ComputeExpr;
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
    if (value == 0) return Evaluate::make(0);
    Stmt body = op->body;
    Map<Var, Expr> vmap;
    Stmt unrolled;
    for (int i = 0; i < value; ++i) {
      Var lv(op->loop_var.node_);
      vmap.Set(lv,
               ComputeExpr<Add>(
                       op->min, make_const(op->loop_var.type(), i)));
      Stmt step = Substitute(body, vmap);
      if (unrolled.defined()) {
        unrolled = Block::make(unrolled, step);
      } else {
        unrolled = step;
      }
    }
    return unrolled;
  }

149
 private:
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
  // returns the extent of the loop if it's a constant integer, otherwise return -1
  int GetExtent(const For* op) {
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    const IntImm  *v1 = extent.as<IntImm>();
    const UIntImm *v2 = extent.as<UIntImm>();
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
    return value;
  }

166 167
  // maximum number of step to perform auto unroll.
  int auto_max_step_;
168
  int auto_max_depth_;
169 170 171
  // max extent of loop to auto unroll
  // this not not count the total steps, only count the number of loops
  int auto_max_extent_;
172
  bool explicit_unroll_;
173 174 175 176 177 178
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled
  int step_count_{0};
179 180 181
};


182 183
/*!
 * \brief Run the loop unroller over a statement.
 * \param stmt The statement to transform.
 * \param auto_max_step Forwarded to LoopUnroller.
 * \param auto_max_depth Forwarded to LoopUnroller.
 * \param auto_max_extent Forwarded to LoopUnroller.
 * \param explicit_unroll Forwarded to LoopUnroller.
 * \return The input statement when nothing changed, otherwise the
 *         transformed statement after ConvertSSA.
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll) {
  LoopUnroller unroller(auto_max_step,
                        auto_max_depth,
                        auto_max_extent,
                        explicit_unroll);
  Stmt mutated = unroller.Mutate(stmt);
  // ConvertSSA is only applied when the pass actually rewrote something.
  return mutated.same_as(stmt) ? mutated : ConvertSSA(mutated);
}

199 200 201 202 203 204 205 206
/*!
 * \brief Expand a single loop statement into copies of its body.
 * \param stmt The statement to expand; it must be a For node, otherwise
 *        a fatal error is logged.
 * \return The result of LoopUnroller::Unroll on the loop.
 */
Stmt UnrollLoopExplicitly(Stmt stmt) {
  const For* loop = stmt.as<For>();
  if (loop == nullptr) {
    LOG(FATAL) << "attempted to unroll a non-loop statement";
  }
  // Unroll() does not consult the thresholds, so placeholder values are used.
  LoopUnroller unroller(0, 0, 0, false);
  return unroller.Unroll(loop);
}

207 208
}  // namespace ir
}  // namespace tvm