/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_virtual_thread.cc
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "../../arith/compute_expr.h"

namespace tvm {
namespace tir {

// If expression is touched by var.
class ExprTouched final : public StmtExprVisitor {
 public:
  explicit ExprTouched(const std::unordered_set<const VarNode*> &touched,
                       bool check_write)
      : touched_var_(touched), check_write_(check_write) {}

  void VisitExpr(const PrimExpr& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    StmtExprVisitor::VisitExpr(n);
  }
    void VisitStmt(const Stmt& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    StmtExprVisitor::VisitStmt(n);
  }
  void VisitExpr_(const LoadNode *op) final {
    HandleUseVar(op->buffer_var.get());
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const VarNode *op) final {
    HandleUseVar(op);
  }
  void VisitExpr_(const CallNode *op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      int rw_mask = 0;
      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
      const VarNode* buffer_var = op->args[1].as<VarNode>();
      CHECK(buffer_var);
      // read
      if (rw_mask & 1) {
        HandleUseVar(buffer_var);
      }
      if (rw_mask & 2) {
        HandleWriteVar(buffer_var);
      }
      this->VisitExpr(op->args[2]);
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void HandleUseVar(const VarNode* var) {
    auto it = touched_var_.find(var);
    if (it != touched_var_.end()) {
      expr_touched_ = true;
    }
    // rember the used vars
    // in case the var get touched later in a loop.
    if (!expr_touched_) {
      used_vars_.push_back(var);
    }
  }
  void HandleWriteVar(const VarNode* var) {
    write_vars_.push_back(var);
  }
  // the fields.
  bool expr_touched_{false};
  std::vector<const VarNode*> used_vars_;
  std::vector<const VarNode*> write_vars_;
  const std::unordered_set<const VarNode*>& touched_var_;
  bool check_write_;
};

// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public StmtVisitor {
 public:
  void VisitStmt_(const LetStmtNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    Record(op->var.get(), tc);
    this->VisitStmt(op->body);
  }
  void VisitStmt_(const StoreNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    tc(op->index);
    Record(op->buffer_var.get(), tc);
  }
  void VisitStmt_(const ForNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->min);
    tc(op->extent);
    Record(op->loop_var.get(), tc);
    this->VisitStmt(op->body);
  }
  // external function call
  void VisitStmt_(const EvaluateNode* op) final {
    ExprTouched tc(touched_var_, true);
    tc(op->value);
    for (const VarNode* var : tc.write_vars_) {
      Record(var, tc);
    }
  }
  void VisitStmt_(const AllocateNode* op) final {
    ExprTouched tc(touched_var_, false);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      tc(op->extents[i]);
    }
    tc.VisitExpr(op->condition);
    if (op->new_expr.defined()) {
      tc(op->new_expr);
    }
    Record(op->buffer_var.get(), tc);
    this->VisitStmt(op->body);
  }
  void Record(const VarNode* var,
              const ExprTouched& tc) {
    if (touched_var_.count(var)) return;
    if (tc.expr_touched_) {
      touched_var_.insert(var);
    } else {
      for (const VarNode* r : tc.used_vars_) {
        if (r != var) {
          affect_[r].push_back(var);
        }
      }
    }
  }

  std::unordered_set<const VarNode*>
  TouchedVar(const Stmt& stmt,
             const VarNode* var) {
    touched_var_.insert(var);
    this->VisitStmt(stmt);
    // do a DFS to push affect around dependency.
    std::vector<const VarNode*> pending(
        touched_var_.begin(), touched_var_.end());
    while (!pending.empty()) {
      const VarNode* v = pending.back();
      pending.pop_back();
      for (const VarNode* r : affect_[v]) {
        if (!touched_var_.count(r)) {
          touched_var_.insert(r);
          pending.push_back(r);
        }
      }
    }
    return std::move(touched_var_);
  }

 private:
  // Whether variable is touched by the thread variable.
  std::unordered_set<const VarNode*> touched_var_;
  // x -> all the buffers x read from
  std::unordered_map<const VarNode*,
                     std::vector<const VarNode*> > affect_;
};


// Inject virtual thread loop:
// rewrite the buffer access pattern when necessary.
//
// VTInjector expands one virtual thread variable `var` with `num_threads`
// threads. A statement whose mutation reads a variable from `touched_var`
// (or that otherwise requires it, see trigger_base_inject_) is handed to
// InjectVTLoop. InjectVTLoop either unrolls it num_threads times with `var`
// replaced by a constant, or wraps it in a serial loop over a fresh index
// variable. Allocations that are touched (or all allocations, when sharing
// is not allowed) gain a leading dimension of size num_threads. Loads,
// stores and tvm_access_ptr offsets into those buffers are shifted by
// var * stride.
class VTInjector : public StmtExprMutator {
 public:
  // constructor
  //  var:         the virtual thread variable being expanded.
  //  num_threads: number of virtual threads (unroll count / loop extent).
  //  touched_var: vars/buffers considered dependent on var. Stored by
  //               reference, so it must outlive this mutator.
  //  allow_share: when false, every allocation is replicated per thread.
  //               Every Store/Evaluate and coproc scope then forces loop
  //               injection, and tvm_context_id is replaced by var.
  VTInjector(Var var,
             int num_threads,
             const std::unordered_set<const VarNode*>& touched_var,
             bool allow_share)
      : var_(var), num_threads_(num_threads),
        touched_var_(touched_var), allow_share_(allow_share) {
  }
  // Inject VTLoop when needed.
  // Post-order hook for every statement. After mutating s, the already-mutated
  // statement is wrapped when two conditions hold: a touched var was seen (or
  // trigger_base_inject_ is set), and no VT loop is active yet. Inside an
  // active VT loop the flags are just cleared.
  Stmt VisitStmt(const Stmt& s) final {
    CHECK(!visit_touched_var_);
    auto stmt = StmtExprMutator::VisitStmt(s);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_)  {
        return InjectVTLoop(stmt, false);
      }
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }
  // Variable
  // A buffer var whose allocation was remapped must not appear as a bare
  // expression: its per-thread address could not be expressed here.
  PrimExpr VisitExpr_(const VarNode* op) final {
    CHECK(!alloc_remap_.count(op))
        << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return GetRef<PrimExpr>(op);
  }
  // Shift a flat index into the slice owned by the current virtual thread.
  // alloc_extent is the per-thread stride stored in alloc_remap_.
  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
    return index + var_ * alloc_extent;
  }
  // Load
  // Marks the expression touched if the buffer itself is touched. Loads from
  // a remapped buffer get their index shifted by var * stride.
  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<LoadNode>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return LoadNode::make(op->dtype, op->buffer_var,
                        RewriteIndex(op->index, it->second),
                        op->predicate);
    } else {
      return expr;
    }
  }
  // Expression.
  // tvm_access_ptr(dtype, buffer, offset, extent, rw_mask) on a remapped
  // buffer: shift the offset by var * stride. The stride stored in
  // alloc_remap_ includes the allocation's lanes, so it is divided by the
  // lanes of args[0]'s dtype here. tvm_context_id becomes var when sharing
  // is not allowed.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_remap_.find(buffer);
      if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
      visit_touched_var_ = true;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      PrimExpr stride =
          it->second / make_const(offset.dtype(), dtype.lanes());
      offset = stride * var_ + offset;
      return CallNode::make(
          op->dtype, op->name,
          {op->args[0], op->args[1], offset, extent, op->args[4]},
          op->call_type);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  // Without sharing, every Evaluate statement is forced under the VT loop.
  Stmt VisitStmt_(const EvaluateNode* op) final {
    trigger_base_inject_ = !allow_share_;
    return StmtExprMutator::VisitStmt_(op);
  }
  // Store
  // A store into a touched buffer marks the statement touched. Without
  // sharing, every store forces injection. Stores into remapped buffers get
  // their index shifted by var * stride.
  Stmt VisitStmt_(const StoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<StoreNode>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    trigger_base_inject_ = !allow_share_;
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return StoreNode::make(op->buffer_var,
                         op->value,
                         RewriteIndex(op->index, it->second),
                         op->predicate);
    } else {
      return stmt;
    }
  }
  // Attribute
  // The whole AttrStmt goes under the VT loop in two cases: its value is
  // touched, or (without sharing) it opens a coproc / coproc uop scope.
  // Otherwise only the body is mutated.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope ||
                op->attr_key == attr::coproc_scope)) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else {
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      } else {
        return AttrStmtNode::make(op->node, op->attr_key, value, body);
      }
    }
  }
  // LetStmt
  // If the bound value is touched, the whole LetStmt (binding and body) is
  // injected, so the binding is evaluated once per virtual thread.
  Stmt VisitStmt_(const LetStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return LetStmtNode::make(op->var, value, body);
    }
  }
  // For
  // Loops must start at zero. If the extent is touched, the whole loop is
  // injected. Each loop bumps max_loop_depth_ so that InjectVTLoop at an
  // enclosing level emits a serial loop instead of unrolling.
  Stmt VisitStmt_(const ForNode* op) final {
    CHECK(is_zero(op->min));
    PrimExpr extent = this->VisitExpr(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    ++max_loop_depth_;
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return ForNode::make(
          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }
  // IfThenElse
  // If the condition is touched, the whole if is injected. Otherwise both
  // branches are mutated, and the reported loop depth is the max over them.
  Stmt VisitStmt_(const IfThenElseNode* op) final {
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->VisitStmt(op->then_case);
    Stmt else_case;
    if (op->else_case.defined()) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->VisitStmt(op->else_case);
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return GetRef<Stmt>(op);
    } else {
      return IfThenElseNode::make(condition, then_case, else_case);
    }
  }

  // Seq
  // Each child is mutated (and possibly injected) independently with a fresh
  // loop-depth counter. The maximum over all children is kept.
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    CHECK_EQ(max_loop_depth_, 0);
    auto fmutate = [this](const Stmt& s) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      max_loop_depth_ = std::max(max_loop_depth_, temp);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }
  // Allocate
  // The allocation goes under the VT loop in two cases: it uses new_expr, or
  // its condition or extents are touched. Otherwise, if the buffer is touched
  // or sharing is disallowed, a leading dimension of size num_threads is
  // added. The per-thread stride (product of the original extents times the
  // dtype lanes) is then recorded in alloc_remap_ before the body is
  // mutated, so accesses in the body get rewritten.
  Stmt VisitStmt_(const AllocateNode* op) final {
    if (op->new_expr.defined() && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }

    bool changed = false;
    Array<PrimExpr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(GetRef<Stmt>(op), true);
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;

    Stmt body;
    // always rewrite if not allow sharing.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // place v on highest dimension.
      PrimExpr stride = arith::ComputeReduce<MulNode>(
          op->extents, PrimExpr()) * op->dtype.lanes();
      Array<PrimExpr> other;
      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
      for (PrimExpr e : extents) {
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
      alloc_remap_[op->buffer_var.get()] = stride;
      // Mutate the body.
      body = this->VisitStmt(op->body);
    } else {
      // Mutate the body.
      body = this->VisitStmt(op->body);
    }
    if (!changed &&
        body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return GetRef<Stmt>(op);
    } else {
      return AllocateNode::make(
          op->buffer_var, op->dtype,
          extents, condition, body,
          op->new_expr, op->free_function);
    }
  }

  // inject vthread loop
  // before_mutation: when true, `stmt` has not been mutated yet. It is mutated
  // here with vt_loop_injected_ set, so nested visitors do not inject again.
  // The result is unrolled with var replaced by each constant 0..n-1 when two
  // conditions hold: the mutated subtree reported no loops, and the thread
  // count is below 16. Otherwise it becomes a serial loop over a fresh
  // variable named "<var>.s".
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->VisitStmt(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      // do unrolling if it is inside innermost content.
      Array<Stmt> seq;
      for (int i = 0; i < num_threads_; ++i) {
        seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
      }
      return SeqStmt::Flatten(seq);
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->dtype);
      Map<Var, PrimExpr> values{{var_, idx}};
      stmt = Substitute(stmt, values);
      return ForNode::make(idx, make_zero(idx.dtype()),
                       make_const(idx.dtype(), num_threads_),
                       ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression get touched.
  bool visit_touched_var_{false};
  // Trigger base stmt: forces injection of the current statement even if it
  // read no touched var (set by Store/Evaluate when sharing is disallowed).
  bool trigger_base_inject_{false};
  // Loop nesting depth (approximately) seen in the most recently mutated
  // subtree; 0 lets InjectVTLoop unroll instead of emitting a loop.
  int max_loop_depth_{0};
  // The variables that get touched.
  const std::unordered_set<const VarNode*>& touched_var_;
  // Whether allow sharing.
  bool allow_share_;
  // The allocations that get touched -> extent
  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
};


// VirtualThreadInjector: replaces every AttrStmt with key
// attr::virtual_thread by its body, expanded by VTInjector. A thread tag of
// exactly "vthread" allows buffers that do not depend on the thread var to be
// shared across virtual threads. Any other tag disables sharing.
class VirtualThreadInjector : public StmtMutator {
 public:
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // Mutate children first, so nested virtual threads are expanded before
    // the enclosing one.
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AttrStmtNode>();
    if (op->attr_key == attr::virtual_thread) {
      IterVar iv = Downcast<IterVar>(op->node);
      bool allow_share = iv->thread_tag == "vthread";
      // The thread count must be a constant. Fail with a clear message rather
      // than dereferencing a null pointer when it is not.
      const IntImmNode* extent = op->value.as<IntImmNode>();
      CHECK(extent != nullptr)
          << "virtual_thread extent must be a constant integer, but got "
          << op->value;
      int nthread = static_cast<int>(extent->value);
      // Collect the vars/buffers whose value depends on the thread var.
      VarTouchedAnalysis vs;
      auto touched = vs.TouchedVar(op->body, iv->var.get());
      // The injector keeps a reference to `touched`; it outlives the call.
      VTInjector injector(iv->var, nthread, touched, allow_share);
      return injector(op->body);
    } else {
      return stmt;
    }
  }

  // This pass only handles flattened (Load/Store) IR.
  Stmt VisitStmt_(const ProvideNode* op) final {
    LOG(FATAL) << "Need to call StorageFlatten first";
    return GetRef<Stmt>(op);
  }
};

// Expand every virtual_thread attribute in `stmt`, then run ConvertSSA on the
// result. The SSA pass is presumably needed because unrolling duplicates
// variable definitions.
Stmt InjectVirtualThread(Stmt stmt) {
  VirtualThreadInjector injector;
  Stmt expanded = injector(std::move(stmt));
  return ConvertSSA(std::move(expanded));
}

}  // namespace tir
}  // namespace tvm