/*!
 *  Copyright (c) 2017 by Contributors
 * \file inject_virtual_thread.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

// If expression is touched by var.
class ExprTouched final : public IRVisitor {
 public:
  explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
                       bool check_write)
      : touched_var_(touched), check_write_(check_write) {}
  void Visit(const NodeRef& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    IRVisitor::Visit(n);
  }
  void Visit_(const Load *op) final {
    HandleUseVar(op->buffer_var.get());
    IRVisitor::Visit_(op);
  }
  void Visit_(const Variable *op) final {
    HandleUseVar(op);
  }
  void Visit_(const Call *op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      int rw_mask = 0;
      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
      const Variable* buffer_var = op->args[1].as<Variable>();
      CHECK(buffer_var);
      // read
      if (rw_mask & 1) {
        HandleUseVar(buffer_var);
      }
      if (rw_mask & 2) {
        HandleWriteVar(buffer_var);
      }
      this->Visit(op->args[2]);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void HandleUseVar(const Variable* var) {
    auto it = touched_var_.find(var);
    if (it != touched_var_.end()) {
      expr_touched_ = true;
    }
    // rember the used vars
    // in case the var get touched later in a loop.
    if (!expr_touched_) {
      used_vars_.push_back(var);
    }
  }
  void HandleWriteVar(const Variable* var) {
    write_vars_.push_back(var);
  }
  // the fields.
  bool expr_touched_{false};
  std::vector<const Variable*> used_vars_;
  std::vector<const Variable*> write_vars_;
  const std::unordered_set<const Variable*>& touched_var_;
  bool check_write_;
};

// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public IRVisitor {
 public:
  void Visit_(const LetStmt *op) {
    ExprTouched tc(touched_var_, false);
    tc.Visit(op->value);
    Record(op->var.get(), tc);
    this->Visit(op->body);
  }
  void Visit_(const Store *op) {
    ExprTouched tc(touched_var_, false);
    tc.Visit(op->value);
    tc.Visit(op->index);
    Record(op->buffer_var.get(), tc);
  }
  void Visit_(const For *op) {
    ExprTouched tc(touched_var_, false);
    tc.Visit(op->min);
    tc.Visit(op->extent);
    Record(op->loop_var.get(), tc);
    this->Visit(op->body);
  }
  // external function call
  void Visit_(const Evaluate *op) {
    ExprTouched tc(touched_var_, true);
    tc.Visit(op->value);
    for (const Variable* var : tc.write_vars_) {
      Record(var, tc);
    }
  }
  void Visit_(const Allocate *op) {
    ExprTouched tc(touched_var_, false);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      tc.Visit(op->extents[i]);
    }
    tc.Visit(op->condition);
    if (op->new_expr.defined()) {
      tc.Visit(op->new_expr);
    }
    Record(op->buffer_var.get(), tc);
    this->Visit(op->body);
  }
  void Record(const Variable* var,
              const ExprTouched& tc) {
    if (touched_var_.count(var)) return;
    if (tc.expr_touched_) {
      touched_var_.insert(var);
    } else {
      for (const Variable* r : tc.used_vars_) {
        if (r != var) {
          affect_[r].push_back(var);
        }
      }
    }
  }

  std::unordered_set<const Variable*>
  TouchedVar(const Stmt& stmt,
             const Variable* var) {
    touched_var_.insert(var);
    this->Visit(stmt);
    // do a DFS to push affect around dependency.
    std::vector<const Variable*> pending(
        touched_var_.begin(), touched_var_.end());
    while (!pending.empty()) {
      const Variable* v = pending.back();
      pending.pop_back();
      for (const Variable* r : affect_[v]) {
        if (!touched_var_.count(r)) {
          touched_var_.insert(r);
          pending.push_back(r);
        }
      }
    }
    return std::move(touched_var_);
  }

 private:
  // Whether variable is touched by the thread variable.
  std::unordered_set<const Variable*> touched_var_;
  // x -> all the buffers x read from
  std::unordered_map<const Variable*,
                     std::vector<const Variable*> > affect_;
};


// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
//
// The mutator walks the body of a virtual_thread scope. Whenever it meets
// a statement that uses a touched variable (one that depends on the
// vthread variable) it wraps that statement into a loop, or an unrolled
// sequence, over the vthread index. Allocations that are touched (or all
// allocations when sharing is not allowed) get an extra leading dimension
// of size num_threads, and accesses to them are offset by var * stride.
class VTInjector : public IRMutator {
 public:
  using IRMutator::Mutate;
  // constructor
  //  var: the virtual thread variable.
  //  num_threads: the number of virtual threads.
  //  touched_var: variables affected by var; held by reference, so it
  //               must outlive this injector.
  //  allow_share: whether untouched buffers may be shared among threads.
  VTInjector(Var var,
             int num_threads,
             const std::unordered_set<const Variable*>& touched_var,
             bool allow_share)
      : var_(var), num_threads_(num_threads),
        touched_var_(touched_var), allow_share_(allow_share) {
  }
  // Inject VTLoop when needed.
  // The touched flag must be clear on entry: every statement handler
  // resets it before descending into child statements.
  Stmt Mutate(Stmt stmt) final {
    CHECK(!visit_touched_var_)
        << stmt->type_key() << stmt;
    stmt = IRMutator::Mutate(stmt);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_)  {
        // stmt is already mutated, so do not mutate it again.
        return InjectVTLoop(stmt, false);
      }
      // Already inside an injected loop: just clear the flags.
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }
  // Variable
  // A remapped buffer must only appear inside Load/Store/tvm_access_ptr,
  // where its index can be rewritten.
  Expr Mutate_(const Variable *op, const Expr& e) final {
    CHECK(!alloc_remap_.count(op))
        << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return e;
  }
  // Shift index into the slice of the current virtual thread.
  Expr RewriteIndex(Expr index, Expr alloc_extent) const {
    return index + var_ * alloc_extent;
  }
  // Load
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return Load::make(op->type, op->buffer_var,
                        RewriteIndex(op->index, it->second),
                        op->predicate);
    } else {
      return expr;
    }
  }
  // Expression.
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      // args: (dtype marker, buffer var, offset, extent, rw_mask)
      CHECK_EQ(op->args.size(), 5U);
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      auto it = alloc_remap_.find(buffer);
      if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
      visit_touched_var_ = true;
      Expr offset = Mutate(op->args[2]);
      Expr extent = Mutate(op->args[3]);
      // The recorded stride is multiplied by the allocation lanes,
      // while the offset here is divided by the lanes of dtype.
      Expr stride = arith::ComputeExpr<Div>(
          it->second, make_const(offset.type(), dtype.lanes()));
      offset = stride * var_ + offset;
      return Call::make(
          op->type, op->name,
          {op->args[0], op->args[1], offset, extent, op->args[4]},
          op->call_type);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      // Without sharing the context id is the virtual thread index.
      return allow_share_ ? e : var_;
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  // Evaluate: when sharing is not allowed, always request an injection
  // at this statement.
  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
    trigger_base_inject_ = !allow_share_;
    return IRMutator::Mutate_(op, s);
  }
  // Store
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    trigger_base_inject_ = !allow_share_;
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return Store::make(op->buffer_var,
                         op->value,
                         RewriteIndex(op->index, it->second),
                         op->predicate);
    } else {
      return stmt;
    }
  }
  // Attribute
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    Expr value = Mutate(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope ||
                op->attr_key == attr::coproc_scope)) {
      // Coprocessor scopes are wrapped as a whole when not sharing.
      return InjectVTLoop(s, true);
    } else {
      Stmt body = Mutate(op->body);
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
        return s;
      } else {
        return AttrStmt::make(op->node, op->attr_key, value, body);
      }
    }
  }
  // LetStmt
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    Expr value = this->Mutate(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return LetStmt::make(op->var, value, body);
    }
  }
  // For
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(is_zero(op->min));
    Expr extent = Mutate(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(s, true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    // Count this loop only after the body has been processed.
    ++max_loop_depth_;
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return For::make(
          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }
  // IfThenElse
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    Expr condition = this->Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->Mutate(op->then_case);
    Stmt else_case;
    // Test the else branch of the input statement, not the still-empty
    // local, so that an existing else branch is mutated and preserved.
    if (op->else_case.defined()) {
      // Track the loop depth of each branch separately, keep the max.
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->Mutate(op->else_case);
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    } else {
      return IfThenElse::make(condition, then_case, else_case);
    }
  }
  // Block
  Stmt Mutate_(const Block* op, const Stmt& s) final {
    CHECK_EQ(max_loop_depth_, 0);
    Stmt first = this->Mutate(op->first);
    // Track the loop depth of each half separately, keep the max.
    int temp = max_loop_depth_;
    max_loop_depth_ = 0;
    Stmt rest = this->Mutate(op->rest);
    max_loop_depth_ = std::max(max_loop_depth_, temp);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return s;
    } else {
      return Block::make(first, rest);
    }
  }
  // Allocate
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    // Allocations with a custom new expression are wrapped as a whole.
    if (op->new_expr.defined() && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    Expr condition = Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }

    bool changed = false;
    Array<Expr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      Expr new_ext = Mutate(op->extents[i]);
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(s, true);
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;

    Stmt body;
    // always rewrite if not allow sharing.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // place v on highest dimension.
      // stride = product of the original extents * lanes
      Expr stride = arith::ComputeReduce<Mul>(
          op->extents, Expr()) * op->type.lanes();
      Array<Expr> other;
      other.push_back(make_const(op->extents[0].type(), num_threads_));
      for (Expr e : extents) {
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
      alloc_remap_[op->buffer_var.get()] = stride;
      // Mutate the body.
      body = Mutate(op->body);
    } else {
      // Mutate the body.
      body = Mutate(op->body);
    }
    if (!changed &&
        body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return s;
    } else {
      return Allocate::make(
          op->buffer_var, op->type,
          extents, condition, body,
          op->new_expr, op->free_function);
    }
  }

  // inject vthread loop
  //  stmt: the statement to be wrapped.
  //  before_mutation: true when stmt has not been mutated yet; it is
  //    then mutated here with vt_loop_injected_ set, so that no nested
  //    injection happens inside it.
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->Mutate(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      // do unrolling if it is inside innermost content.
      Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
      for (int i = 1; i < num_threads_; ++i) {
        blk = Block::make(
            blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
      }
      return blk;
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->type);
      stmt = Substitute(stmt, {{var_, idx}});
      return For::make(idx, make_zero(idx.type()),
                       make_const(idx.type(), num_threads_),
                       ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression get touched.
  bool visit_touched_var_{false};
  // Trigger base stmt
  bool trigger_base_inject_{false};
  // the counter of loops in after mutation.
  int max_loop_depth_{0};
  // The variables that get touched.
  const std::unordered_set<const Variable*>& touched_var_;
  // Whether allow sharing.
  bool allow_share_;
  // The allocations that get touched -> extent
  std::unordered_map<const Variable*, Expr> alloc_remap_;
};


// Top-level mutator: replaces every virtual_thread attribute scope by its
// body with the virtual thread loop injected.
class VirtualThreadInjector : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    // Children are processed first, so nested virtual thread scopes
    // are already lowered when this scope is handled.
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<AttrStmt>();
    if (op->attr_key == attr::virtual_thread) {
      IterVar iv(op->node.node_);
      // Only the plain "vthread" tag allows sharing of untouched buffers.
      bool allow_share = iv->thread_tag == "vthread";
      // The number of virtual threads must be a compile time constant;
      // fail with a message instead of dereferencing a null pointer.
      const IntImm* nthread_imm = op->value.as<IntImm>();
      CHECK(nthread_imm != nullptr)
          << "The extent of virtual thread must be a constant integer, "
          << "but get " << op->value;
      int nthread = static_cast<int>(nthread_imm->value);
      VarTouchedAnalysis vs;
      auto touched = vs.TouchedVar(op->body, iv->var.get());
      // touched is held by reference inside the injector and must stay
      // alive until the mutation below has finished.
      VTInjector injecter(iv->var, nthread, touched, allow_share);
      return injecter.Mutate(op->body);
    } else {
      return stmt;
    }
  }

  // Provide nodes must have been removed before this pass runs.
  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    LOG(FATAL) << "Need to call StorageFlatten first";
    return s;
  }
};

// Pass entry point: lower all virtual_thread scopes in stmt, then convert
// the result to SSA form with ConvertSSA.
Stmt InjectVirtualThread(Stmt stmt) {
  VirtualThreadInjector injector;
  Stmt lowered = injector.Mutate(stmt);
  return ConvertSSA(lowered);
}

}  // namespace ir
}  // namespace tvm