/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Logics related to tensorize, used by ComputeOpNode.
 * \file tensorize.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "./op_util.h"
#include "./compute_op.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {

using namespace ir;
using namespace op;

// Detect the region of input and output to be tensrized.
// out_dom: the domain of root iter vars in output op
// in_region: region of each input tensor.
// return The location of the tensorized scope start.
size_t InferTensorizeRegion(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    std::unordered_map<IterVar, Range>* out_dom,
    std::unordered_map<Tensor, Array<Range> >* in_region) {
  // Get the bound of the tensorized scope.
  bool found_point = false;
  size_t loc_scope = 0;
  std::unordered_map<IterVar, IntSet> up_state;
  // Loop over the leafs
  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
    IterVar iv = stage->leaf_iter_vars[i - 1];
    CHECK(iv->iter_type == kDataPar ||
          iv->iter_type == kCommReduce);
    auto vit = dom_map.find(iv);
    CHECK(vit != dom_map.end());
    const Range& vrange = vit->second;
    if (is_one(vrange->extent)) {
      up_state[iv] = IntSet::single_point(vrange->min);
    } else if (found_point) {
      CHECK(is_zero(vrange->min));
      up_state[iv] = IntSet::single_point(iv->var);
    } else {
      up_state[iv] = IntSet::range(vrange);
    }
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (!found_point) {
        CHECK(!attr->bind_thread.defined())
            << "Donot allow thread in tensorize scope";
      }
      if (attr->iter_type == kTensorized) {
        CHECK(!found_point) << "Donot allow two tensorized point";
        found_point = true;
        loc_scope = i - 1;
      }
    }
  }
  CHECK(found_point);
  // Get domain of the tensorized scope.
  schedule::PassUpDomain(stage, dom_map, &up_state);
  // Get domains if inputs
  std::unordered_map<Tensor, TensorDom> in_dom;
  std::unordered_map<const Variable*, IntSet> temp_dmap;
  Array<Tensor> inputs = self->InputTensors();
  for (Tensor t : inputs) {
    in_dom.emplace(t, TensorDom(t.ndim()));
  }
  for (IterVar iv : self->root_iter_vars()) {
    IntSet iset = up_state.at(iv);
    (*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
    temp_dmap[iv->var.get()] = iset;
  }
  // Input domains
  self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
  Range none;
  for (const auto& kv : in_dom) {
    Array<Range> vec;
    const Tensor& t = kv.first;
    for (size_t i = 0; i < t.ndim(); ++i) {
      Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
      CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
      vec.push_back(std::move(r));
    }
    (*in_region)[t] = std::move(vec);
  }
  return loc_scope;
}

/*!
 * \brief Verify that no loop predicate uses a variable bound inside the tensorized scope.
 *
 * The variables bound by For, AttrStmt (whose node is an IterVar) and LetStmt
 * statements in the main and init nests, starting at the tensorized leaf, are
 * collected. A main or init predicate that uses any of them is a fatal error.
 *
 * \param self The compute op being tensorized (not used by the check).
 * \param stage The stage that schedules self.
 * \param n The loop nest structure used in compute.
 * \param tloc The index in stage->leaf_iter_vars where the tensorized scope starts.
 */
void VerifyTensorizeLoopNest(const ComputeOpNode* self,
                             const Stage& stage,
                             const ComputeLoopNest& n,
                             size_t tloc) {
  // The nests hold one more entry than there are leaf iter vars.
  CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
  CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
        n.init_nest.size() == 0);
  std::unordered_set<const Variable*> banned;
  // Record the variable bound by a statement, if it binds one.
  auto f_push_banned = [&banned](const Stmt& s) {
    if (const For* loop = s.as<For>()) {
      banned.insert(loop->loop_var.get());
      return;
    }
    if (const AttrStmt* attr = s.as<AttrStmt>()) {
      if (const IterVarNode* iv = attr->node.as<IterVarNode>()) {
        banned.insert(iv->var.get());
      }
      return;
    }
    if (const LetStmt* let = s.as<LetStmt>()) {
      banned.insert(let->var.get());
    }
  };
  const bool has_init_nest = n.init_nest.size() != 0;
  const size_t num_leaf = stage->leaf_iter_vars.size();
  for (size_t leaf = tloc; leaf < num_leaf; ++leaf) {
    // Nest entry leaf + 1 belongs to leaf iter var number leaf.
    const size_t level = leaf + 1;
    for (const Stmt& s : n.main_nest[level]) {
      f_push_banned(s);
    }
    if (has_init_nest) {
      for (const Stmt& s : n.init_nest[level]) {
        f_push_banned(s);
      }
    }
  }
  // Reject a predicate that refers to any banned variable.
  auto f_check_pred = [&banned](const Expr& pred) {
    if (ir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize failed, split condition "
                 << pred << " relies on var defined inside tensorize scope";
    }
  };
  for (const Expr& pred : n.main_predicates) {
    f_check_pred(pred);
  }
  for (const Expr& pred : n.init_predicates) {
    f_check_pred(pred);
  }
}


// Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public IRMutator {
 public:
  Expr Mutate_(const Call* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    if (op->call_type == Call::Halide) {
      Tensor t = Operation(op->func.node_).output(op->value_index);
      auto it = in_remap_.find(t);
      if (it != in_remap_.end()) {
        const InputEntry& e = it->second;
        CHECK_EQ(op->args.size(), e.region.size());
        Array<Expr> args;
        for (size_t i = e.start; i < e.region.size(); ++i) {
          args.push_back(op->args[i] - e.region[i]->min);
        }
        return Call::make(
            op->type, e.tensor->op->name, args,
            op->call_type, e.tensor->op, e.tensor->value_index);
      }
    }
    return expr;
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  Expr Mutate_(const Reduce* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Reduce>();
    Array<IterVar> axis;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      auto it = axis_remap_.find(op->axis[i]);
      if (it != axis_remap_.end()) {
        axis.push_back(it->second);
      }
    }
    return Reduce::make(
        op->combiner, op->source, axis, op->condition, op->value_index);
  }

  void Init(const ComputeOpNode* self,
            const Stage& stage,
            const std::unordered_map<IterVar, Range>& out_dom,
            const std::unordered_map<Tensor, Array<Range> >& in_region,
            const TensorIntrin& intrin,
            Map<Var, Range>* compute_intrin_iter_space) {
    CHECK(self == stage->op.get());
    // input remap.
    Array<Tensor> inputs = self->InputTensors();
    CHECK_EQ(inputs.size(), intrin->inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      InputEntry e;
      e.tensor = intrin->inputs[i];
      e.region = Array<Range>(in_region.at(inputs[i]));
      CHECK_GE(e.region.size(), e.tensor.ndim());
      // Enable fuzzy matching, to match [1, n, m] to [n, m]
      e.start = e.region.size() - e.tensor.ndim();
      for (size_t i = 0; i < e.start; ++i) {
        CHECK(is_one(e.region[i]->extent))
            << "Tensorize " << intrin->name << ":"
            << " Input dimension mismatch with tensor intrin "
            << " expected shape=" << e.tensor->shape
            << ", given region=" << e.region;
      }
      in_remap_[inputs[i]] = e;
    }
    // output remap
    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
    CHECK(intrin_compute) << "Only support compute intrinsic for now";
    CHECK_GE(self->axis.size(), intrin_compute->axis.size())
        << "Tensorize: Output mismatch with tensor intrin ";
    // Enable fuzzy matching, to match [1, n, m] to [n, m]
    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->axis[i]);
      CHECK(is_one(r->extent))
          << "Tensorize: Output mismatch with tensor intrin "
          << " intrin-dim=" << intrin_compute->axis.size()
          << ", tensorize-dim=" << self->axis.size();
      var_remap_[self->axis[i]->var.get()] = r->min;
    }
    // Assume we tensorize at regin axis i [min, min + extent)
    // The corresponding intrinsic axis is j [0, extent)
    // Remap index i to j + min
    for (size_t i = axis_start; i < self->axis.size(); ++i) {
      IterVar iv = self->axis[i];
      IterVar target_iv = intrin_compute->axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
    // Remap reduction axis
    CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
        << "Tensorize: Reduction dimension mismatch with tensor intrin";
    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->reduce_axis[i]);
      CHECK(is_one(r->extent))
          << "Tensorize: Reduction mismatch with tensor intrin "
          << " intrin-dim=" << intrin_compute->reduce_axis.size()
          << ", tensorize-dim=" << self->reduce_axis.size();
      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
    }
    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
      IterVar iv = self->reduce_axis[i];
      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
  }

 private:
  // Input entry
  struct InputEntry {
    Tensor tensor;
    size_t start;
    Array<Range> region;
  };
  // input data remap
  std::unordered_map<Tensor, InputEntry> in_remap_;
  // variable remap.
  std::unordered_map<const Variable*, Expr> var_remap_;
  // IterVar remap.
  std::unordered_map<IterVar, IterVar> axis_remap_;
};

/*!
 * \brief Rewrite the body of the stage into the index space of the intrinsic.
 * \param self The compute op being tensorized.
 * \param stage The stage that schedules self.
 * \param out_dom The domain of the root iter vars within the tensorized scope.
 * \param in_region The region of each input tensor within the tensorized scope.
 * \param intrin The tensor intrinsic to match against.
 * \param compute_intrin_iter_space Output: domains of the intrinsic iter vars.
 * \return One rewritten expression per entry of self->body, in the same order.
 */
Array<Expr> MatchTensorizeBody(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& out_dom,
    const std::unordered_map<Tensor, Array<Range> >& in_region,
    const TensorIntrin& intrin,
    Map<Var, Range>* compute_intrin_iter_space) {
  TensorIntrinMatcher matcher;
  matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space);
  Array<Expr> rewritten;
  const size_t num_body = self->body.size();
  for (size_t i = 0; i < num_body; ++i) {
    rewritten.push_back(matcher.Mutate(self->body[i]));
  }
  return rewritten;
}

void VerifyTensorizeBody(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& out_dom,
    const std::unordered_map<Tensor, Array<Range> >& in_region,
    const TensorIntrin& intrin) {
  Map<Var, Range> compute_intrin_iter_space;
  Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin,
                                        &compute_intrin_iter_space);
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  CHECK(intrin_compute) << "Only support compute intrinsic for now";
  CHECK_EQ(body.size(), intrin_compute->body.size())
      << "Tensorize failed: body size mismatch";
  for (size_t i = 0; i < body.size(); ++i) {
    Expr lhs = Simplify(body[i], compute_intrin_iter_space);
    lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
    Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
    rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
    if (lhs.type() != rhs.type()) {
      LOG(FATAL)
          << "Failed to match the data type with TensorIntrin "
          << intrin->name << "'s declaration "
          << " provided=" << lhs.type()
          << ", intrin=" << rhs.type();
    }
    CHECK(Equal(lhs, rhs))
        << "Failed to match the compute with TensorIntrin "
        << intrin->name << "'s declaration "
        << " provided= " << lhs
        << ", intrin=  " << rhs;
  }
}

/*!
 * \brief Transform the update part when there is no init func in tensorizing.
 *
 * For every reduction leaf iter var that comes before the tensorized one, the
 * condition "var > min" is collected. The result runs the update func when any
 * collected condition holds and the body func otherwise.
 *
 * \param stage The stage for tensorizing.
 * \param dom_map The range of each iter var.
 * \param n The loop nest structure used in compute.
 * \param body The body func in tensorize intrin.
 * \param update The update func in tensorize intrin.
 * \return Transformed result.
 */
Stmt TransformUpdate(const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     const ComputeLoopNest& n,
                     Stmt body,
                     Stmt update) {
  Array<Expr> conds;
  std::unordered_set<const Variable*> banned;
  for (IterVar iv : stage->leaf_iter_vars) {
    // Stop at the tensorized iter var.
    auto attr_it = stage->iter_var_attrs.find(iv);
    if (attr_it != stage->iter_var_attrs.end() &&
        (*attr_it).second->iter_type == kTensorized) {
      break;
    }
    if (iv->iter_type != kCommReduce) {
      continue;
    }
    auto vit = dom_map.find(iv);
    CHECK(vit != dom_map.end());
    conds.push_back(likely(iv->var > vit->second->min));
    banned.insert(iv->var.get());
  }
  // A predicate on one of the reduction variables would clash with the
  // conditions collected above.
  for (const Expr& pred : n.main_predicates) {
    if (ir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize update transform failed, the condition "
                 << pred << " has a conflict with the reset condition";
    }
  }

  Expr any_cond = arith::ComputeReduce<ir::Or>(conds, const_true(1));
  return IfThenElse::make(any_cond, update, body);
}

/*!
 * \brief Build the statement that implements a tensorized compute stage.
 *
 * The tensorized region is inferred and verified against the intrinsic, the
 * input and output tensors are bound to the buffers of the intrinsic, and the
 * loops at and after the tensorized leaf iter var are replaced by the
 * statements of the intrinsic.
 *
 * \param self The compute op being tensorized.
 * \param stage The stage that schedules self.
 * \param dom_map The range of each iter var.
 * \param debug_keep_trivial_loop Forwarded to ComputeLoopNest::make.
 * \return The loop nest around the body of the intrinsic, or around its
 *         init/update statements when the reduction has to be split.
 */
Stmt MakeTensorize(const ComputeOpNode* self,
                   const Stage& stage,
                   const std::unordered_map<IterVar, Range>& dom_map,
                   bool debug_keep_trivial_loop) {
  std::unordered_map<IterVar, Range> out_dom;
  std::unordered_map<Tensor, Array<Range> > in_region;
  size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
  TensorIntrin intrin = stage->iter_var_attrs.at(
      stage->leaf_iter_vars[tloc])->tensor_intrin;
  CHECK(intrin.defined());
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
  VerifyTensorizeLoopNest(self, stage, n, tloc);
  VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
  // Start bind data.
  Stmt nop = Evaluate::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = self->InputTensors();
  CHECK_EQ(inputs.size(), intrin->inputs.size())
      << "Tensorize failed: input size mismatch ";
  // Validate the intrinsic, in particular its buffer count, before
  // intrin->buffers is indexed by the binding loops below.
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  CHECK(intrin_compute) << "Only support compute intrinsic for now";
  CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
  CHECK_EQ(intrin_compute->body.size(), self->body.size());
  // input binding: buffers[0 .. inputs.size()) belong to the inputs
  for (size_t i = 0; i < intrin->inputs.size(); ++i) {
    Tensor tensor = inputs[i];
    Buffer buffer = intrin->buffers[i];
    Array<NodeRef> bind_spec{buffer, tensor};
    auto it = in_region.find(tensor);
    CHECK(it != in_region.end());
    const Array<Range>& region = it->second;
    // The bound region is passed as a flat (min, extent) list.
    Array<Expr> tuple;
    for (const Range r : region) {
      tuple.push_back(r->min);
      tuple.push_back(r->extent);
    }
    input_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }
  // output binding: the remaining buffers belong to the outputs, which all
  // share the region given by the output domain
  Array<Expr> tuple;
  for (IterVar iv : self->axis) {
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    tuple.push_back(it->second->min);
    tuple.push_back(it->second->extent);
  }
  for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
    Tensor tensor = stage->op.output(i - intrin->inputs.size());
    Buffer buffer = intrin->buffers[i];
    Array<NodeRef> bind_spec{buffer, tensor};
    output_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }
  // Check variable remap
  std::unordered_map<const Variable*, Expr> vmap;
  ir::ArgBinder binder(&vmap);
  CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
      << "Tensorization fail: reduction axis size do not match";
  // Leading reduce axes without a counterpart in the intrinsic must be trivial.
  size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
  for (size_t i = 0; i < start; ++i) {
    IterVar iv = self->reduce_axis[i];
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    CHECK(is_one(it->second->extent))
        << "Tensorization fail: reduction axis size do not match";
  }
  // Bind the domain of each intrinsic reduce axis to [0, extent) of the
  // matching reduce axis of the stage.
  for (size_t i = start; i < self->reduce_axis.size(); ++i) {
    IterVar iv = self->reduce_axis[i];
    IterVar target = intrin_compute->reduce_axis[i - start];
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    binder.Bind(target->dom->min, make_const(iv->dom->min.type(), 0),
                "tensir_intrin.reduction.min");
    binder.Bind(target->dom->extent, it->second->extent,
                "tensir_intrin.reduction.extent");
  }
  if (tloc <= n.num_common_loop) {
    // Do no need to split reduction
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(op::MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(intrin->body.defined())
        << "Normal store op for intrin " << intrin << " is not defined";
    Stmt body = MergeNest(output_bind_nest, intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = Substitute(body, n.main_vmap);
    return MergeNest(nest, body);
  } else {
    // Need to split reduction
    CHECK(intrin->reduce_update.defined())
        << "Reduction update op for intrin " << intrin << " is not defined";
    // Need init and update steps
    CHECK_NE(self->reduce_axis.size(), 0U);
    // Loops shared by init and update, then the loops of the update only.
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > update_nest(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));

    if (intrin->reduce_init.defined()) {
      // init nest
      std::vector<std::vector<Stmt> > init_nest(
          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
      init = Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update
      Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, Block::make(init, update));
    } else {
      // When init op is not available, use body op for reset in the first iter.
      CHECK(intrin->body.defined())
          << "Normal body op for intrin " << intrin << " is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n,
                                    intrin->body,
                                    intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}

// Register functions for unittests.
// test.op.InferTensorizeRegion(stage, dom_map) returns a two element array:
// the inferred output domain map followed by the inferred input region map.
TVM_REGISTER_API("test.op.InferTensorizeRegion")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0];
    Map<IterVar, Range> dmap = args[1];
    CHECK(stage->op.as<ComputeOpNode>());
    const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
    std::unordered_map<IterVar, Range> out_dom;
    std::unordered_map<Tensor, Array<Range> > in_region;
    // The returned scope location is not exposed to the caller.
    InferTensorizeRegion(compute, stage, as_unordered_map(dmap),
                         &out_dom, &in_region);
    Map<IterVar, Range> out_dom_map(out_dom);
    Map<Tensor, Array<Range> > in_region_map(in_region);
    *ret = Array<NodeRef>{out_dom_map, in_region_map};
  });

// test.op.MatchTensorizeBody(stage, out_dom, in_region, intrin) returns the
// body of the stage rewritten into the index space of the intrinsic.
TVM_REGISTER_API("test.op.MatchTensorizeBody")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0];
    Map<IterVar, Range> out_dom = args[1];
    Map<Tensor, Array<Range> > in_region = args[2];
    TensorIntrin intrin = args[3];
    CHECK(stage->op.as<ComputeOpNode>());
    const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
    // The iteration space of the intrinsic is collected but not returned.
    Map<Var, Range> vrange;
    Array<Expr> matched = MatchTensorizeBody(compute,
                                             stage,
                                             as_unordered_map(out_dom),
                                             as_unordered_map(in_region),
                                             intrin,
                                             &vrange);
    *ret = matched;
  });
}  // namespace tvm