/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Logic related to cross-thread reduction, used by ComputeOpNode.
 * \file cross_thread_reduction.cc
 */
#include <tvm/ir_pass.h>
#include "./compute_op.h"
#include "./op_util.h"

namespace tvm {
using namespace ir;

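/*!
 * \brief Build the lowered statement for a reduction performed cooperatively
 *  across threads via the tvm_thread_allreduce intrinsic.
 *
 *  Roughly (illustrative sketch only, the exact form depends on the schedule):
 *
 *    attr [reduce_temp0] storage_scope = "local"
 *    allocate reduce_temp0[1]
 *      attr [combiner] reduce_scope
 *        tvm_thread_allreduce(n, source..., cond, reduce_temp..., thread vars...)
 *      if (store predicate && bound checks)
 *        output[...] = reduce_temp0[0]
 */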
Stmt MakeCrossThreadReduction(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
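  // The compute axes of the op serve as the store indices of the output.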
  Array<Expr> args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
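  // Build the loop nest over the leaf iteration variables together with the
  // bound-check predicates implied by the schedule; value_map records how
  // each IterVar is expressed in terms of the generated loop variables.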
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
  auto conds = schedule::MakeBoundCheck(
      stage, dom_map, value_map, false,
      std::unordered_set<IterVar>());

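  // Every output of this op must be a Reduce node; the code below relies on
  // all outputs sharing the same combiner, source, and condition (they
  // differ only in their value_index).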
  size_t size = self->body.size();
  CHECK_GT(size, 0);
  std::vector<const Reduce*> reduces(size);
  for (size_t i = 0; i < size; ++i) {
    const Reduce* reduce = self->body[i].as<Reduce>();
    CHECK(reduce);
    reduces[i] = reduce;
  }
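  // Fold the reduction's own condition together with the bound checks; the
  // combined predicate is passed to the allreduce intrinsic.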
  Expr cond = reduces[0]->condition;
  for (Expr v : conds) {
    cond = cond && v;
  }
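  // Arguments of tvm_thread_allreduce: the number of outputs, the source
  // value of each output, the combined condition, one result buffer per
  // output, and (appended below) the participating thread variables.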
  Array<Expr> freduce_args;
  freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    freduce_args.push_back(reduces[0]->source[i]);
  }
  freduce_args.push_back(cond);
  std::vector<Var> res_handles(size);
  for (size_t idx = 0; idx < size; ++idx) {
    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
    freduce_args.push_back(res_handles[idx]);
  }

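  // For every reduction axis that is bound to a thread, pass the bound
  // thread variable to the intrinsic as well, so it knows which threads
  // participate in the reduction.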
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }
  // The store predicate of the stage, when defined, restricts which threads
  // write the reduction result back to the output.
  std::vector<Expr> thread_head_check;
  if (stage->store_predicate.defined()) {
    thread_head_check.emplace_back(stage->store_predicate);
  }

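  // Emit the allreduce intrinsic, annotated with the reduction combiner via
  // a reduce_scope attribute for the later lowering pass.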
  Stmt reduce_body = Evaluate::make(Call::make(
      Handle(),
      ir::intrinsic::tvm_thread_allreduce,
      freduce_args, Call::Intrinsic));
  reduce_body = AttrStmt::make(
      reduces[0]->combiner,
      attr::reduce_scope,
      make_zero(Handle()),
      reduce_body);
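  // Load each reduction result from its local buffer and store it into the
  // corresponding output of this operation.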
  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    Type t = reduces[idx]->type;
    assigns[idx] = Provide::make(
      stage->op, idx,
      Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
  }
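  // Combine the write-backs and guard them with the store predicate and the
  // bound checks, so only the designated threads perform the store.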
  Stmt assign_body = Block::make(assigns);
  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
  Stmt body = Block::make(reduce_body, assign_body);
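  // Wrap the body with a single-element, thread-local buffer allocation for
  // each reduction result.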
  for (size_t idx = size; idx != 0; --idx) {
    body = Allocate::make(
      res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
    body = AttrStmt::make(
      res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
  }
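  // Substitute the iteration variables with their loop expressions and wrap
  // everything in the generated loop nest.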
  body = op::Substitute(body, value_map);
  return MergeNest(nest, body);
}
}  // namespace tvm