/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Loop unrolling as in Halide pipeline.
 * \file unroll_loop.cc
 */
// Unrolls the loop as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

class LoopUnroller : public IRMutator {
 public:
  explicit LoopUnroller(int auto_max_step,
                        int auto_max_depth,
                        int auto_max_extent,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_max_depth_(auto_max_depth),
        auto_max_extent_(auto_max_extent),
        explicit_unroll_(explicit_unroll) {
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key == "pragma_auto_unroll_max_step") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      std::swap(value, auto_max_step_);
      Stmt ret = this->Mutate(op->body);
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      bool explicit_unroll = value;
      std::swap(explicit_unroll, explicit_unroll_);
      Stmt ret = this->Mutate(op->body);
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
      return IRMutator::Mutate_(op, stmt);
    }
  }

  Stmt Mutate_(const For* op, const Stmt& s) {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<For>();
    int value = GetExtent(op);
    // condition for auto unroll
    bool auto_unroll = (
        op->for_type == ForType::Serial &&
        value >= 0 &&
        normal_loop_depth_ == 0 &&
        unroll_depth_ <= auto_max_depth_);

    auto_unroll = auto_unroll && (
        value * step_count_ <= auto_max_step_||
        value <= auto_max_extent_);

    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      auto_unroll = true;
    }

    if (auto_unroll) {
      step_count_  *=  value;
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }

    if ((auto_unroll && explicit_unroll_) ||
        // unroll loops with extent = 1, no matter how many steps in body
        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
      return Unroll(op);
    } else {
      if (auto_unroll) {
        if (op->for_type != ForType::Unrolled) {
          return For::make(
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
      return stmt;
    }
  }

  Stmt Mutate_(const Store* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  Stmt Mutate_(const Block* op, const Stmt& stmt) final {
    Stmt first = this->Mutate(op->first);
    // cleanup state
    int step_count = step_count_;
    int unroll_depth = unroll_depth_;
    int normal_loop_depth = normal_loop_depth_;
    step_count_ = 0;
    unroll_depth_ = 0;
    normal_loop_depth_ = 0;
    // work on rest part
    Stmt rest = this->Mutate(op->rest);
    step_count_ += step_count;
    normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
    unroll_depth_ = std::max(unroll_depth_, unroll_depth);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return stmt;
    } else {
      return Block::make(first, rest);
    }
  }

  Stmt Unroll(const For* op) {
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
    if (value == 0) return Evaluate::make(0);
    Stmt body = op->body;
    Map<Var, Expr> vmap;
    Stmt unrolled;
    for (int i = 0; i < value; ++i) {
      Var lv(op->loop_var.node_);
      vmap.Set(lv, op->min + make_const(op->loop_var.type(), i));
      Stmt step = Substitute(body, vmap);
      if (unrolled.defined()) {
        unrolled = Block::make(unrolled, step);
      } else {
        unrolled = step;
      }
    }
    return unrolled;
  }

 private:
  // returns the extent of the loop if it's a constant integer, otherwise return -1
  int GetExtent(const For* op) {
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    const IntImm  *v1 = extent.as<IntImm>();
    const UIntImm *v2 = extent.as<UIntImm>();
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
    return value;
  }

  // maximum number of step to perform auto unroll.
  int auto_max_step_;
  int auto_max_depth_;
  // max extent of loop to auto unroll
  // this not not count the total steps, only count the number of loops
  int auto_max_extent_;
  bool explicit_unroll_;
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled
  int step_count_{0};
};


/*!
 * \brief Run the loop unroller over a statement.
 * \param stmt The statement to transform.
 * \param auto_max_step Bound on extent * counted-steps for auto-unroll.
 * \param auto_max_depth Maximum unroll nesting depth for auto-unroll.
 * \param auto_max_extent Constant extents at most this value are auto-unrolled.
 * \param explicit_unroll Expand selected loops instead of only re-tagging them.
 * \return The input unchanged if nothing was rewritten, otherwise the rewritten
 *         statement passed through ConvertSSA.
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll) {
  LoopUnroller unroller(auto_max_step, auto_max_depth,
                        auto_max_extent, explicit_unroll);
  Stmt rewritten = unroller.Mutate(stmt);
  // Untouched IR needs no further processing.
  if (rewritten.same_as(stmt)) {
    return rewritten;
  }
  // Unrolling copies loop bodies; presumably ConvertSSA is needed to restore
  // unique variable definitions across the copies.
  return ConvertSSA(rewritten);
}

/*!
 * \brief Fully unroll a single For statement.
 * \param stmt Must be a For node; anything else is a fatal error.
 * \return The loop body repeated once per iteration (see LoopUnroller::Unroll).
 */
Stmt UnrollLoopExplicitly(Stmt stmt) {
  const For* loop = stmt.as<For>();
  if (loop == nullptr) {
    LOG(FATAL) << "attempted to unroll a non-loop statement";
  }
  // Unroll() does not read the auto-unroll thresholds, so they are all zero.
  LoopUnroller unroller(0, 0, 0, false);
  return unroller.Unroll(loop);
}

}  // namespace ir
}  // namespace tvm