/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Tensor Compute Op.
 * \file tensor_compute_op.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "./compute_op.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace te {
using namespace tir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& ref, ReprPrinter* printer) {
    // Textual form: tensor_compute_op(<name>, <node address>).
    const auto* compute_op = static_cast<const TensorComputeOpNode*>(ref.get());
    printer->stream << "tensor_compute_op(" << compute_op->name << ", "
                    << compute_op << ")";
  });

// Register TensorComputeOpNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);

int TensorComputeOpNode::num_outputs() const {
  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}

DataType TensorComputeOpNode::output_dtype(size_t i) const {
  // Output buffers are laid out right after the input buffers.
  const size_t buffer_index = inputs.size() + i;
  const auto& out_buffer = intrin->buffers[buffer_index];
  return out_buffer->dtype;
}

Operation TensorComputeOpNode::make(std::string name,
                                    std::string tag,
                                    Array<IterVar> axis,
                                    Array<IterVar> reduce_axis,
                                    int schedulable_ndim,
                                    TensorIntrin intrin,
                                    Array<Tensor> tensors,
                                    Array<Region> regions,
                                    Array<PrimExpr> scalar_inputs) {
  // Build the node field by field; every by-value argument is moved into it.
  auto node = make_object<TensorComputeOpNode>();
  // Identification.
  node->name = std::move(name);
  node->tag = std::move(tag);
  // Iteration space.
  node->axis = std::move(axis);
  node->reduce_axis = std::move(reduce_axis);
  node->schedulable_ndim = schedulable_ndim;  // plain int: nothing to move
  // The intrinsic and the operands it is applied to.
  node->intrin = std::move(intrin);
  node->inputs = std::move(tensors);
  node->input_regions = std::move(regions);
  node->scalar_inputs = std::move(scalar_inputs);
  return Operation(node);
}

// Expose the factory as the global function "te.TensorComputeOp".
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(&TensorComputeOpNode::make);


Array<Tensor> TensorComputeOpNode::InputTensors() const {
  // The input tensors are exactly the operands recorded at construction.
  return this->inputs;
}

/*!
 * \brief Return a copy of this op with input tensors substituted per \p rmap.
 *
 * Tensor references are rewritten both in the op's input list and inside the
 * statements of its (privately copied) TensorIntrin. If nothing changed,
 * \p self is returned unchanged.
 */
Operation TensorComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_object<TensorComputeOpNode>(*this);
  // Rewrite a private copy of the intrinsic so the original is left intact.
  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
  // Each statement of the intrinsic is optional. BuildProvide only requires
  // `body` when no reduce_init is supplied, so guard it like the
  // reduce_* statements instead of handing an undefined Stmt to ReplaceTensor.
  if (intrin->body.defined()) {
    intrin->body = ReplaceTensor(this->intrin->body, rmap);
  }
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    auto it = rmap.find(n->inputs[i]);
    if (it != rmap.end()) {
      n->inputs.Set(i, it->second);
    }
  }

  // n->intrin still points at the original intrinsic here, so these
  // same_as checks detect whether any replacement actually happened.
  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
      inputs.same_as(n->inputs)) {
    return self;
  }
  n->intrin = TensorIntrin(intrin);
  return Operation(n);
}

void TensorComputeOpNode::PropBoundToInputs(
    const Operation& self,
    arith::Analyzer* analyzer,
    const std::unordered_map<const VarNode*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (size_t i = 0; i < this->inputs.size(); ++i) {
    Tensor t = this->inputs[i];
    Region region = input_regions[i];

    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t j = 0; j < t.ndim(); ++j) {
      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
    }
  }
}

size_t TensorComputeOpNode::num_schedulable_dims() const {
  return schedulable_ndim;
}

/*!
 * \brief Build a buffer_bind_scope attribute for TensorComputeOp lowering.
 *
 * Binds \p buffer to \p tensor over the region given by \p tuple, a flat list
 * of (min, extent) pairs, one pair per dimension. \p nop is used as the
 * attribute's (placeholder) body.
 */
static Stmt BindTensorComputeBuffer(const Buffer& buffer,
                                    const Tensor& tensor,
                                    const Array<PrimExpr>& tuple,
                                    const Stmt& nop) {
  Array<ObjectRef> bind_spec{buffer, tensor};
  return AttrStmtNode::make(
      bind_spec, tir::attr::buffer_bind_scope,
      CallNode::make(DataType::Handle(),
                     tir::intrinsic::tvm_tuple,
                     tuple, CallNode::Intrinsic), nop);
}

/*!
 * \brief Lower this op into a loop nest that invokes the tensor intrinsic.
 *
 * Input buffers of the intrinsic are bound to the recorded input regions.
 * Output buffers are bound to a point along the first schedulable_ndim axes
 * and to the full domain along the remaining axes. The intrinsic's scalar
 * parameters are substituted with the call-site scalar_inputs. Without
 * reduce axes the intrinsic body is emitted directly. Otherwise an init + update
 * pair is emitted, using reduce_init when available and falling back to
 * TransformUpdate(body, reduce_update) when it is not.
 */
Stmt TensorComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);

  // Start bind data.
  Stmt nop = EvaluateNode::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> input_tensors = this->InputTensors();

  // Input binding: intrinsic buffer i <-> region i of input tensor i.
  size_t num_inputs = input_tensors.size();
  for (size_t i = 0; i < num_inputs; ++i) {
    Region region = this->input_regions[i];
    Array<PrimExpr> tuple;
    for (size_t dim = 0; dim < region.size(); ++dim) {
      tuple.push_back(region[dim]->min);
      tuple.push_back(region[dim]->extent);
    }
    input_bind_nest.emplace_back(BindTensorComputeBuffer(
        this->intrin->buffers[i], input_tensors[i], tuple, nop));
  }

  // Output binding: output buffers follow the input buffers in the
  // intrinsic's buffer list.
  for (int i = 0; i < this->num_outputs(); ++i) {
    Array<PrimExpr> tuple;
    for (size_t dim = 0; dim < this->axis.size(); ++dim) {
      IterVar ivar = this->axis[dim];
      if (dim < static_cast<size_t>(this->schedulable_ndim)) {
        // Schedulable axis: a single point at the loop variable.
        tuple.push_back(ivar->var);
        tuple.push_back(1);
      } else {
        // Axis handled inside the intrinsic: cover its whole domain.
        Range dom = ivar->dom;
        tuple.push_back(dom->min);
        tuple.push_back(dom->extent);
      }
    }
    output_bind_nest.emplace_back(BindTensorComputeBuffer(
        this->intrin->buffers[num_inputs + i], stage->op.output(i), tuple, nop));
  }

  // Check variable remap
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  tir::ArgBinder binder(&vmap);

  // Map the expressions passed in the call to the TensorIntrin, to the placeholder
  // variables
  Array<PrimExpr> sp_expr;
  for (const auto& sp : this->intrin->scalar_params) {
    sp_expr.push_back(sp);
  }
  CHECK_EQ(sp_expr.size(), this->scalar_inputs.size())
      << "TensorComputeOp " << this->name << ": the intrinsic declares "
      << sp_expr.size() << " scalar parameters but "
      << this->scalar_inputs.size() << " scalar inputs were supplied";
  // TODO(jdavies-huawei): what name should be used here?
  binder.BindArray(sp_expr, this->scalar_inputs, this->name);

  size_t tloc = stage->leaf_iter_vars.size();
  ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);

  if (this->reduce_axis.size() == 0) {
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(this->intrin->body.defined())
        << "Normal store op for intrin " << this << " is not defined";
    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = tir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = te::Substitute(body, n.main_vmap);
    return MergeNest(nest, body);
  }

  // Need to split reduction: init and update steps.
  CHECK(this->intrin->reduce_update.defined())
      << "Reduction update op is not defined";
  std::vector<std::vector<Stmt> > common(
      n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
  std::vector<std::vector<Stmt> > update_nest(
      n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
  update_nest.emplace_back(MakeIfNest(n.main_predicates));

  if (this->intrin->reduce_init.defined()) {
    // init nest
    std::vector<std::vector<Stmt> > init_nest(
        n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
    init_nest.emplace_back(MakeIfNest(n.init_predicates));
    Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
    // reduce_init may reference the intrinsic's scalar parameters just like
    // body/reduce_update do. Substitute them with the call-site values too,
    // otherwise the placeholder vars would be left free in the init step.
    init = tir::Substitute(init, vmap);
    init = te::Substitute(init, n.init_vmap);
    init = MergeNest(init_nest, init);
    // The update
    Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
    update = MergeNest(input_bind_nest, update);
    update = tir::Substitute(update, vmap);
    update = MergeNest(binder.asserts(), update);
    update = te::Substitute(update, n.main_vmap);
    update = MergeNest(update_nest, update);
    return MergeNest(common, SeqStmt::Flatten(init, update));
  }

  // When init op is not available, use body op for reset in the first iter.
  CHECK(this->intrin->body.defined())
      << "Normal body op is not defined";
  Stmt update = TransformUpdate(stage, dom_map, n,
                                this->intrin->body,
                                this->intrin->reduce_update);
  update = MergeNest(output_bind_nest, update);
  update = MergeNest(input_bind_nest, update);
  update = tir::Substitute(update, vmap);
  update = MergeNest(binder.asserts(), update);
  update = te::Substitute(update, n.main_vmap);
  update = MergeNest(update_nest, update);
  return MergeNest(common, update);
}
}  // namespace te
}  // namespace tvm