/*!
 *  Copyright (c) 2017 by Contributors
 *  Loop unrolling.
 * \file unroll_loop.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <vector>

#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

class LoopUnroller : public IRMutator {
 public:
19 20 21 22 23 24
  explicit LoopUnroller(int auto_max_step,
                        int auto_min_depth,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_min_depth_(auto_min_depth),
        explicit_unroll_(explicit_unroll) {
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
  }

  Stmt Mutate_(const For* op, const Stmt& s) {
    Stmt stmt = s;
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    const IntImm* v1 = extent.as<IntImm>();
    const UIntImm* v2 = extent.as<UIntImm>();
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
40 41 42
    bool auto_unroll = (op->for_type == ForType::Serial &&
                        value >= 0 && value <= auto_max_step_ &&
                        loop_depth_ >= auto_min_depth_);
43 44 45
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
46
      auto_unroll = true;
47 48
    }

49
    if (auto_unroll && explicit_unroll_) {
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
      using arith::ComputeExpr;
      if (value == 0) return Evaluate::make(0);
      Stmt body = op->body;
      Map<Var, Expr> vmap;
      Stmt unrolled;
      for (int i = 0; i < value; ++i) {
        Var lv(op->loop_var.node_);
        vmap.Set(lv,
                 ComputeExpr<Add>(
                     op->min, make_const(op->loop_var.type(), i)));
        Stmt step = Substitute(body, vmap);
        if (unrolled.defined()) {
          unrolled = Block::make(unrolled, step);
        } else {
          unrolled = step;
        }
      }
67 68 69 70
      ++loop_depth_;
      Stmt ret = this->Mutate(unrolled);
      --loop_depth_;
      return ret;
71
    } else {
72 73 74 75 76 77 78 79 80 81 82 83
      ++loop_depth_;
      Stmt ret = IRMutator::Mutate_(op, stmt);
      if (auto_unroll) {
        op = ret.as<For>();
        if (op->for_type != ForType::Unrolled) {
          ret = For::make(
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
      --loop_depth_;
      return ret;
84 85 86 87
    }
  }

 private:
88 89 90 91 92
  // maximum number of step to perform auto unroll.
  int auto_max_step_;
  int auto_min_depth_;
  bool explicit_unroll_;
  int loop_depth_{0};
93 94 95
};


96 97 98 99 100 101 102 103 104 105 106 107 108
/*!
 * \brief Unroll loops in `stmt` with LoopUnroller.
 * \param stmt The statement to transform.
 * \param auto_max_step Largest constant extent of a Serial loop that is
 *        unrolled automatically.
 * \param auto_min_depth Minimum number of enclosing loops a Serial loop needs
 *        before it is considered for automatic unrolling.
 * \param explicit_unroll If true, replace unrolled loops by copies of their
 *        body; otherwise only mark them ForType::Unrolled.
 * \return `stmt` itself when nothing changed; otherwise the mutated statement
 *         passed through ConvertSSA (presumably because duplicated bodies
 *         would re-bind the same variables — NOTE(review): confirm).
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_min_depth,
                bool explicit_unroll) {
  Stmt ret = LoopUnroller(
      auto_max_step,
      auto_min_depth,
      explicit_unroll).Mutate(stmt);
  if (ret.same_as(stmt)) {
    return ret;
  }
  return ConvertSSA(ret);
}

}  // namespace ir
}  // namespace tvm