/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Loop unrolling as in Halide pipeline.
 * \file unroll_loop.cc
 */
// Unrolls the loop as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

class LoopUnroller : public StmtExprMutator {
 public:
  explicit LoopUnroller(int auto_max_step,
                        int auto_max_depth,
                        int auto_max_extent,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_max_depth_(auto_max_depth),
        auto_max_extent_(auto_max_extent),
        explicit_unroll_(explicit_unroll) {
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == "pragma_auto_unroll_max_step") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      std::swap(value, auto_max_step_);
      Stmt ret = this->VisitStmt(op->body);
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      bool explicit_unroll = value;
      std::swap(explicit_unroll, explicit_unroll_);
      Stmt ret = this->VisitStmt(op->body);
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  Stmt VisitStmt_(const ForNode* op) {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    int value = GetExtent(op);
    // condition for auto unroll
    bool auto_unroll = (
        op->for_type == ForType::Serial &&
        value >= 0 &&
        normal_loop_depth_ == 0 &&
        unroll_depth_ <= auto_max_depth_);

    auto_unroll = auto_unroll && (
        value * step_count_ <= auto_max_step_||
        value <= auto_max_extent_);

    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      auto_unroll = true;
    }

    if (auto_unroll) {
      step_count_  *=  value;
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }

    if ((auto_unroll && explicit_unroll_) ||
        // unroll loops with extent = 1, no matter how many steps in body
        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
      return Unroll(op);
    } else {
      if (auto_unroll) {
        if (op->for_type != ForType::Unrolled) {
          return ForNode::make(
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
      return stmt;
    }
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    ++step_count_;
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const EvaluateNode* op) final {
    ++step_count_;
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const SeqStmtNode* op) final {
    auto fmutate = [this](const Stmt& s) {
      int step_count = step_count_;
      int unroll_depth = unroll_depth_;
      int normal_loop_depth = normal_loop_depth_;
      step_count_ = 0;
      unroll_depth_ = 0;
      normal_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      step_count_ += step_count;
      normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
      unroll_depth_ = std::max(unroll_depth_, unroll_depth);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }

  Stmt Unroll(const ForNode* op) {
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
    if (value == 0) return EvaluateNode::make(0);
    Stmt body = op->body;
    Map<Var, PrimExpr> vmap;
    Array<Stmt> unrolled;
    for (int i = 0; i < value; ++i) {
      vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
      Stmt step = Substitute(body, vmap);
      unrolled.push_back(step);
    }
    return SeqStmt::Flatten(unrolled);
  }

 private:
  // returns the extent of the loop if it's a constant integer, otherwise return -1
  int GetExtent(const ForNode* op) {
    // constant folding.
    PrimExpr extent = ir::Simplify(op->extent);
    const IntImmNode  *v1 = extent.as<IntImmNode>();
    const UIntImmNode *v2 = extent.as<UIntImmNode>();
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
    return value;
  }

  // maximum number of step to perform auto unroll.
  int auto_max_step_;
  int auto_max_depth_;
  // max extent of loop to auto unroll
  // this not not count the total steps, only count the number of loops
  int auto_max_extent_;
  bool explicit_unroll_;
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled
  int step_count_{0};
};


Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll) {
  Stmt ret = LoopUnroller(
      auto_max_step,
      auto_max_depth,
      auto_max_extent,
      explicit_unroll)(stmt);
  if (!ret.same_as(stmt)) {
    return ConvertSSA(ret);
  } else {
    return ret;
  }
}

Stmt UnrollLoopExplicitly(Stmt stmt) {
  const ForNode* op = stmt.as<ForNode>();
  if (!op) {
    LOG(FATAL) << "attempted to unroll a non-loop statement";
  }
  return LoopUnroller(0, 0, 0, false).Unroll(op);
}

}  // namespace ir
}  // namespace tvm