/*!
 *  Copyright (c) 2017 by Contributors
 * \file lower_thread_allreduce.cc
 * \brief Lower allreduce to device-implementable IR.
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

class ThreadAllreduceBuilder final : public IRMutator {
 public:
  explicit ThreadAllreduceBuilder(int warp_size)
      : warp_size_(warp_size) {}

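  // Track the attribute scopes that matter for the rewrite:
  // thread_extent and reduce_scope attributes are pushed onto stacks while
  // their bodies are visited; storage_scope attributes over buffers that
  // have been remapped to shared memory are dropped, since the remapped
  // Allocate emits its own storage_scope attribute.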
  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      thread_extents_.push_back(op);
      Stmt ret = IRMutator::Mutate_(op, s);
      thread_extents_.pop_back();
      return ret;
    } else if (op->attr_key == attr::storage_scope) {
      Stmt ret = IRMutator::Mutate_(op, s);
      op = ret.as<AttrStmt>();
      const Variable* v = op->node.as<Variable>();
      if (alloc_remap_.count(v)) {
        return op->body;
      } else {
        return ret;
      }
    } else if (op->attr_key == attr::reduce_scope) {
      const CommReducerNode *combiner = op->node.as<CommReducerNode>();
      CHECK(combiner);
      reduce_combiner_.push_back(combiner);
      Stmt ret = IRMutator::Mutate_(op, s);
      reduce_combiner_.pop_back();
      return ret;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
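  // Rewrite calls to the tvm_thread_allreduce intrinsic in place.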
  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Evaluate>();
    const Call* call = op->value.as<Call>();
    if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
      return MakeAllreduce(call);
    } else {
      return stmt;
    }
  }
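  // Replace the original allocation of a reduction result buffer with the
  // shared-memory allocation recorded in alloc_remap_, marking accesses to
  // the shared buffer as volatile.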
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      const Allocate* repl = it->second.as<Allocate>();
      // use volatile access to shared buffer.
      stmt = AttrStmt::make(
          repl->buffer_var, attr::volatile_scope, 1, op->body);
      stmt = Allocate::make(
          repl->buffer_var, repl->type,
          repl->extents, repl->condition, stmt);
      stmt = AttrStmt::make(
          repl->buffer_var, attr::storage_scope,
          StringImm::make("shared"), stmt);
      return stmt;
    } else {
      return stmt;
    }
  }
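  // Loads from the original reduction result buffer become loads of the
  // reduced value from the shared buffer (recorded in load_remap_).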
  Expr Mutate_(const Load* op, const Expr& e) final {
    auto it = load_remap_.find(op->buffer_var.get());
    if (it != load_remap_.end()) {
      CHECK(is_zero(op->index));
      return it->second;
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }

 private:
  // Thread entry
  struct ThreadEntry {
    runtime::ThreadScope scope;
    IterVar iv;
    int extent;
    // comparator
    bool operator<(const ThreadEntry& other) const {
      return scope.dim_index < other.scope.dim_index;
    }
  };
  // Lower a single tvm_thread_allreduce call into shared-memory reduction code.
  Stmt MakeAllreduce(const Call* call) {
    CHECK(!reduce_combiner_.empty());
    const CommReducerNode *combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

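    // Argument layout of tvm_thread_allreduce:
    //   args[0]                   - number of reduction values (size)
    //   args[1 .. size]           - the values being reduced
    //   args[size+1]              - the reduction predicate
    //   args[size+2 .. 2*size+1]  - the result buffer vars
    //   args[2*size+2 .. ]        - the IterVars of the reduction threads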
    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
    CHECK(size_of_args) << call->args[0]->type_key();
    CHECK_EQ(size, size_of_args->value);
    Array<Expr> inits = combiner->identity_element;
    std::vector<Expr> values(size);
    std::vector<Type> types(size);
    Expr cond = call->args[size+1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1+idx];
      if (!is_one(cond)) {
        values[idx] = Select::make(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].type();
    }
    std::vector<const Variable*> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      const Variable* buffer = call->args[2+size+idx].as<Variable>();
      CHECK(buffer);
      buffers[idx] = buffer;
    }

    std::unordered_set<const Variable*> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const Variable* v = call->args[i].as<Variable>();
      CHECK(v);
      reduce_set.insert(v);
    }
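    // Split the surrounding thread extents into the threads being reduced
    // over (vred) and the threads that index independent reduction groups
    // (vpar); only rank-1 (threadIdx.*) scopes participate.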
    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmt* attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv(attr->node.node_);
      e.scope = runtime::ThreadScope::make(iv->thread_tag);
      e.iv = iv;
      CHECK_LE(e.scope.rank, 1);
      CHECK_GE(e.scope.dim_index, 0)
          << "vthread does not work with cross-thread reduction";
      if (e.scope.rank == 1) {
        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
            << "Need constant extent for reduce set " << iv;
        if (reduce_set.count(iv->var.get())) {
          vred.push_back(e);
          ++nmatch;
        } else {
          vpar.push_back(e);
        }
      }
    }
    CHECK_EQ(nmatch, reduce_set.size())
        << "Not all reduce indices are present in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // The flattened extents of the reduce and group indices.
    int reduce_extent, group_extent;
    int threadx_extent = 1;
    Expr reduce_index = FlattenThread(vred, &reduce_extent);
    Expr group_index = FlattenThread(vpar, &group_extent);
    if (reduce_extent == 1) {
      // special case, no reduction is needed.
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        Expr pred = const_true(types[i].lanes());
        Var buffer_var(call->args[2+size+i].node_);
        stores[i] = Store::make(buffer_var, values[i], 0, pred);
      }
      return Block::make(stores);
    }
    // Check whether threadIdx.x is involved in the reduction.
    if (vred[0].scope.dim_index == 0) {
      threadx_extent = vred[0].extent;
    }
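    // Emit: store each value into its shared buffer at (group, reduce) index,
    // synchronize, run the in-buffer tree reduction, and record how later
    // loads and allocations of the result buffers must be rewritten.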
    std::vector<Stmt> seq;
    std::vector<Var> shared_bufs(size);
    for (size_t idx = 0; idx < size; ++idx) {
      shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
      Expr pred = const_true(types[idx].lanes());
      seq.emplace_back(Store::make(
          shared_bufs[idx], values[idx],
          BufIndex(reduce_index, group_index, reduce_extent), pred));
    }
    seq.emplace_back(SyncThread("shared"));
    seq.emplace_back(MakeBufAllreduce(
        combiner, types, shared_bufs,
        reduce_index, group_index, reduce_extent, threadx_extent));
    for (size_t idx = 0; idx < size; ++idx) {
      CHECK(!load_remap_.count(buffers[idx]));
      Expr pred = const_true(types[idx].lanes());
      load_remap_[buffers[idx]] = Load::make(
        types[idx], shared_bufs[idx],
        BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
      alloc_remap_[buffers[idx]] = Allocate::make(
        shared_bufs[idx], types[idx],
        {Expr(group_extent), Expr(reduce_extent)},
        pred, Evaluate::make(0));
    }
    return MergeSeq(seq);
  }
  // Emit the in-buffer tree allreduce over the shared buffers.
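  // After each thread has stored its value into red_buf, the code emitted
  // here roughly looks like the following (illustrative sketch only; the
  // exact shape depends on reduce_extent, threadx_extent and the warp size,
  // and tid stands for the flattened reduce index):
  //
  //   if (tid < s) red_buf[tid] = combine(red_buf[tid], red_buf[tid + s])
  //   sync(shared)
  //   ... repeated with s halved each step ...
  //   if (tid < s) { remaining steps within a single warp }
  //   sync(shared)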
  Stmt MakeBufAllreduce(const CommReducerNode *combiner,
                        const std::vector<Type>& types,
                        const Array<Var>& shared_bufs,
                        Expr reduce_index,
                        Expr group_index,
                        int reduce_extent,
                        int threadx_extent) {
    // Round reduce_extent up to the next power of two.
    int reduce_align = 1;
    while (reduce_extent > reduce_align) {
      reduce_align = reduce_align << 1;
    }
    CHECK_GT(reduce_align, 1);
    std::vector<Stmt> seq;

    size_t size = shared_bufs.size();
    Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
    // make reduction
    auto freduce = [&](int offset) {
      Array<Expr> a, b;
      for (size_t i = 0; i < size; ++i) {
        b.push_back(Load::make(types[i], shared_bufs[i],
          BufIndex(reduce_index + offset, group_index, reduce_extent),
          const_true()));
        a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
      }
      Array<Expr> ret = (*combiner)(a, b);
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
      }
      return Block::make(stores);
    };
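    // The reduction proceeds in three steps:
    //   1. fold in the tail if reduce_extent is not a power of two,
    //   2. tree-reduce across warps, synchronizing through shared memory,
    //   3. finish the remaining steps within a single warp.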
    // Step 1: if reduce_extent is not a power of two, fold in the tail
    // elements first, guarded by a boundary condition.
    if (reduce_align > reduce_extent) {
      reduce_align = reduce_align >> 1;
      Expr cond = reduce_index < (reduce_extent - reduce_align);
      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    CHECK(threadx_extent >= 1 && warp_size_ >= 1);
    // Step 2: tree reduction across warps, synchronizing through shared memory.
    while (reduce_align > threadx_extent ||
           reduce_align > warp_size_) {
      reduce_align = reduce_align >> 1;
      Expr cond = reduce_index < reduce_align;
      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    // Step 3: finish the reduction within a single warp; only the lanes that
    // satisfy in_warp_cond run it, synchronizing at warp scope between steps.
    std::vector<Stmt> in_warp_seq;
    Expr in_warp_cond = reduce_index < (reduce_align >> 1);
    while (reduce_align > 1) {
      reduce_align = reduce_align >> 1;
      in_warp_seq.emplace_back(freduce(reduce_align));
      in_warp_seq.emplace_back(SyncThread("warp"));
    }
    if (in_warp_seq.size() != 0) {
      Stmt warp_body = MergeSeq(in_warp_seq);
      seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
      // Make the final value visible to the whole group before it is loaded.
      seq.emplace_back(SyncThread("shared"));
    }
    return MergeSeq(seq);
  }
  // Flatten a set of thread indices into a single linear index.
  // The combined extent is returned through out_total_extent.
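  // For example, threadIdx.x with extent 16 and threadIdx.y with extent 8
  // flatten to x + y * 16 with a total extent of 128 (assuming x has the
  // smaller dim_index and therefore comes first after sorting).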
  Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
                     int* out_total_extent) {
    int& total_extent = *out_total_extent;
    total_extent = 1;
    if (tvec.size() == 0) {
      return make_zero(Int(32));
    }

    Expr ret;
    for (const ThreadEntry& e : tvec) {
      if (ret.defined()) {
        ret = ret + e.iv->var * total_extent;
      } else {
        CHECK_EQ(total_extent, 1);
        ret = e.iv->var;
      }
      total_extent *= e.extent;
    }
    return ret;
  }
  // Make a synchronization barrier at the given storage scope.
  static Stmt SyncThread(const std::string& sync) {
    return Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_storage_sync,
                   {StringImm::make(sync)},
                   Call::Intrinsic));
  }
  // Compute the index into the shared reduction buffer.
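  // e.g. group_index g and reduce_index r map to g * reduce_extent + r,
  // or simply r when there is no grouping (group_index is zero).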
  static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
    if (!is_zero(group_index)) {
      return ir::Simplify(group_index * reduce_extent + reduce_index);
    } else {
      return reduce_index;
    }
  }
  // The warp size of the device.
  int warp_size_{1};

  // Surrounding thread_extent scopes.
  std::vector<const AttrStmt*> thread_extents_;
  // Surrounding reduce_scope combiners.
  std::vector<const CommReducerNode*> reduce_combiner_;
  // Maps a reduction result buffer to a load of the reduced value.
  std::unordered_map<const Variable *, Expr> load_remap_;
  // Maps a reduction result buffer to its shared-memory allocation.
  std::unordered_map<const Variable *, Stmt> alloc_remap_;
};

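// Entry point: rewrite all thread allreduce intrinsics in the body of a
// device function. warp_size is the warp width of the target device; steps
// that fit within a warp use warp-level rather than shared-memory barriers.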
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
  CHECK_NE(f->func_type, kHostFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
  return LoweredFunc(n);
}
}  // namespace ir
}  // namespace tvm