/*!
 *  Copyright (c) 2017 by Contributors
 *
 * \brief Inject double buffering optimization for data fetch.
 * \file inject_double_buffer.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

// Visitor that gathers the variables named by attr::double_buffer_scope
// annotations. A variable is removed from the set again when it is
// subsequently visited as a Variable node.
class DoubleBufferDetector : public IRVisitor {
 public:
  // Record the annotated node for double buffer scopes; the default
  // traversal of the attribute statement runs in every case.
  void Visit_(const AttrStmt* op) final {
    const bool is_db_scope = (op->attr_key == attr::double_buffer_scope);
    if (is_db_scope) {
      touched_.insert(op->node.as<Variable>());
    }
    IRVisitor::Visit_(op);
  }

  // Drop the variable from the candidate set if it is currently in it.
  void Visit_(const Variable* op) final {
    // Erasing by key is a no-op when the key is absent, so no lookup is
    // needed beforehand.
    touched_.erase(op);
  }
  // The set of touched variables.
  std::unordered_set<const Variable*> touched_;
};


// Mutator that removes attr::double_buffer_write annotations: each such
// attribute statement is replaced by the mutated form of its body.
class StripDoubleBufferWrite : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    // Every other attribute goes through the default mutation.
    if (op->attr_key != attr::double_buffer_write) {
      return IRMutator::Mutate_(op, s);
    }
    return Mutate(op->body);
  }
};

// Mutator that rewrites buffers annotated with attr::double_buffer_scope:
//  - the allocation gets an extra leading extent of 2 and is moved around
//    the loop that encloses the double buffer scope,
//  - stores and loads of the buffer are offset by a switch expression
//    multiplied by the recorded stride,
//  - the producer body is emitted once before the loop (for iteration 0)
//    and, inside the loop, shifted by one iteration.
class DoubleBufferInjector : public IRMutator {
 public:
  // split_loop: when non-zero, the loop that owns a double buffer is
  // explicitly unrolled by this factor (see Mutate_(const For*)).
  explicit DoubleBufferInjector(int split_loop)
      : split_loop_(split_loop) {}

  // Run the rewrite on stmt.
  // Returns stmt unchanged when the detector reports no candidate variable;
  // otherwise returns the mutated statement passed through ConvertSSA.
  Stmt Inject(const Stmt& stmt) {
    DoubleBufferDetector detector;
    detector.Visit(stmt);
    if (detector.touched_.empty()) return stmt;
    for (const Variable* v : detector.touched_) {
      dbuffer_info_[v] = StorageEntry();
    }
    return ConvertSSA(this->Mutate(stmt));
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      auto it = dbuffer_info_.find(buf);
      if (it != dbuffer_info_.end()) {
        // Remember the scope and drop the annotation here; it is emitted
        // again next to the enlarged allocation in Mutate_(const Allocate*).
        const StringImm* scope = op->value.as<StringImm>();
        CHECK(scope != nullptr)
            << "storage_scope value of a double buffer must be a StringImm";
        it->second.scope = scope->value;
        return Mutate(op->body);
      } else {
        return IRMutator::Mutate_(op, s);
      }
    } else if (op->attr_key == attr::double_buffer_scope) {
      return MakeProducer(op, s);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it != dbuffer_info_.end()) {
      // The stride must be known before the body is mutated, because the
      // Store/Load handlers use it to offset the indices.
      it->second.stride = arith::ComputeReduce<Mul>
          (op->extents, Expr()) * op->type.lanes();
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<Allocate>();
      // New shape: a leading extent of 2 followed by the original extents.
      Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
      for (Expr e : op->extents) {
        new_extents.push_back(e);
      }
      // The loop is recorded by MakeProducer while the body is mutated.
      CHECK(it->second.loop != nullptr);
      // Queue the scope annotation and the enlarged allocation (with
      // placeholder bodies) for the owning loop; Mutate_(const For*) wraps
      // them around that loop with MergeNest.
      // NOTE(review): the queue is only consumed when that loop's handler
      // finishes, which presumably requires this Allocate to be located
      // inside the loop body -- confirm against the lowering pipeline.
      auto& alloc_nest = loop_allocs_[it->second.loop];
      alloc_nest.emplace_back(AttrStmt::make(
          op->buffer_var, attr::storage_scope,
          StringImm::make(it->second.scope),
          Evaluate::make(0)));
      alloc_nest.emplace_back(Allocate::make(
          op->buffer_var, op->type, new_extents, op->condition,
          Evaluate::make(0)));
      // The original allocation is dropped; only its body remains in place.
      return op->body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const For* op, const Stmt& s) final {
    loop_nest_.push_back(op);
    Stmt stmt = IRMutator::Mutate_(op, s);
    auto it = loop_pre_.find(op);
    if (it != loop_pre_.end()) {
      const For* old_loop = stmt.as<For>();
      // MakeProducer prefetches with the loop variable set to 0 and guards
      // the shifted producer with loop_var + 1 < extent. Both are only
      // valid for a loop that starts at 0, whether or not it is split.
      CHECK(is_zero(old_loop->min))
          << "Double buffering requires the loop to start from zero";
      if (split_loop_ != 0) {
        // Explicitly unroll the loop
        CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
            << "It is better to split with multiple of 2";
        Expr zero = old_loop->min;
        // The main loop covers (extent - 1) / factor outer iterations;
        // whatever remains is handled by the guarded tail below.
        Expr new_ext = arith::ComputeExpr<Sub>(
            old_loop->extent, make_const(old_loop->loop_var.type(), 1));
        Expr factor = make_const(new_ext.type(), split_loop_);
        Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
        Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
        Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
        std::unordered_map<const Variable*, Expr> vmap;
        std::vector<Stmt> loop_seq;
        // One copy of the body per unrolled step: loop_var = outer * factor + i.
        for (int32_t i = 0; i < split_loop_; ++i) {
          vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i);
          loop_seq.emplace_back(Substitute(old_loop->body, vmap));
        }
        Stmt loop = For::make(
            outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
            MergeSeq(loop_seq));
        // tail: the double_buffer_write annotations are stripped and each
        // copy is guarded so that it only runs for indices below the extent.
        std::vector<Stmt> tail_seq;
        Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
        for (int32_t i = 0; i < split_loop_; ++i) {
          Expr idx = tail_base + make_const(tail_base.type(), i);
          vmap[old_loop->loop_var.get()] = idx;
          tail_seq.emplace_back(
              IfThenElse::make(idx < old_loop->extent,
                               Substitute(tail_body, vmap)));
        }
        stmt = Block::make(loop, MergeSeq(tail_seq));
      }
      // Prefetch statements collected by MakeProducer run before the loop.
      stmt = Block::make(MergeSeq(it->second), stmt);
    }
    it = loop_allocs_.find(op);
    if (it != loop_allocs_.end()) {
      // Wrap the enlarged allocations around the loop (and its prefetch).
      stmt = MergeNest(it->second, stmt);
    }
    loop_nest_.pop_back();
    return stmt;
  }

  // Offset stores to a double buffer by switch_write_var * stride.
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it != dbuffer_info_.end()) {
      const StorageEntry& entry = it->second;
      CHECK(in_double_buffer_scope_);
      CHECK(entry.stride.defined());
      return Store::make(op->buffer_var,
                         op->value,
                         entry.switch_write_var * entry.stride + op->index,
                         op->predicate);
    } else {
      return stmt;
    }
  }

  // Offset loads from a double buffer by switch_read_var * stride.
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it != dbuffer_info_.end()) {
      // Named `entry` so that it does not shadow the parameter `e`.
      const StorageEntry& entry = it->second;
      CHECK(entry.stride.defined());
      CHECK(entry.switch_read_var.defined());
      return Load::make(op->type,
                        op->buffer_var,
                        entry.switch_read_var * entry.stride + op->index,
                        op->predicate);
    } else {
      return expr;
    }
  }

  // A double-buffered variable must not be reached as a Variable node by
  // this mutator.
  Expr Mutate_(const Variable* op, const Expr& e) final {
    CHECK(!dbuffer_info_.count(op));
    return e;
  }

 private:
  // Rewrite the body of a double_buffer_scope attribute.
  // Records the innermost enclosing loop as the owner of the buffer, queues
  // a copy of the body with loop variable and write switch set to 0 to run
  // before that loop, and returns the body shifted to loop_var + 1, wrapped
  // in attr::double_buffer_write and guarded by loop_var + 1 < extent.
  Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
    const VarExpr buffer(op->node.node_);
    CHECK_NE(loop_nest_.size(), 0U)
        << "Double buffer scope must be inside a loop";
    auto it = dbuffer_info_.find(buffer.get());
    if (it == dbuffer_info_.end()) {
      LOG(WARNING) << "Skip double buffer scope " << op->node;
      return Mutate(op->body);
    }
    StorageEntry& e = it->second;
    e.loop = loop_nest_.back();
    Expr zero = make_const(e.loop->loop_var.type(), 0);
    Expr one = make_const(e.loop->loop_var.type(), 1);
    Expr two = make_const(e.loop->loop_var.type(), 2);
    Expr loop_shift = e.loop->loop_var + one;
    e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
                             e.loop->loop_var.type());
    e.switch_read_var = e.loop->loop_var % two;
    // Restore the previous flag instead of forcing it to false, so that a
    // scope nested inside another double buffer scope does not switch the
    // flag off for the remainder of the enclosing scope.
    const bool enclosing_scope = in_double_buffer_scope_;
    in_double_buffer_scope_ = true;
    Stmt body = Mutate(op->body);
    in_double_buffer_scope_ = enclosing_scope;
    std::unordered_map<const Variable*, Expr> vmap;
    // Prefetch before the loop: iteration 0 written with switch value 0.
    vmap[e.switch_write_var.get()] = zero;
    vmap[e.loop->loop_var.get()] = zero;
    loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
    // Inside the loop: produce the data of the next iteration into the
    // half selected by (loop_var + 1) % 2.
    vmap[e.loop->loop_var.get()] = loop_shift;
    vmap[e.switch_write_var.get()] = loop_shift % two;
    body = Substitute(body, vmap);
    body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
    body = IfThenElse::make(loop_shift < e.loop->extent, body);
    return body;
  }
  // Storage entry for those who need double buffering.
  struct StorageEntry {
    // The size of the buffer
    Expr stride;
    // The loop that encloses the double buffer scope of this buffer.
    const For* loop{nullptr};
    // The switch variable used when writing.
    VarExpr switch_write_var;
    // The switch expression used when reading.
    Expr switch_read_var;
    // The storage scope.
    std::string scope;
  };
  // Loop split factor; 0 disables the explicit unrolling.
  int32_t split_loop_;
  // Whether we are inside double buffer scope.
  bool in_double_buffer_scope_{false};
  // The current loop nest
  std::vector<const For*> loop_nest_;
  // The allocs to be wrapped around the loop
  std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
  // The stmt to be appended before the loop
  std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
  // Double buffering state of each candidate buffer variable
  std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};


// Pass entry point: constructs a DoubleBufferInjector with the given
// split_loop setting and returns the result of running it on stmt.
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
  DoubleBufferInjector injector(split_loop);
  return injector.Inject(stmt);
}
}  // namespace ir
}  // namespace tvm