/*!
 *  Copyright (c) 2018 by Contributors
 * \file pad.cc
 * \brief Implementation of operator pad
 */
#include <tvm/ir_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn.h>
#include <vector>
#include "../layout.h"
#include "../op_common.h"

namespace tvm {
namespace relay {

// relay.nn.pad
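// Register the PadAttrs node type so the pad attributes can be created and
// reflected through TVM's node system.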
TVM_REGISTER_NODE_TYPE(PadAttrs);

bool PadRel(const Array<Type>& types,
            int num_inputs,
            const Attrs& attrs,
            const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const PadAttrs* param = attrs.as<PadAttrs>();
  CHECK(param != nullptr);

  // check that there is exactly one pad width pair per input dimension
  CHECK(data->shape.size() == param->pad_width.size())
    << "There should be as many pad width pairs as shape dimensions "
    << "but the shape has " << data->shape.size() << " dimensions "
    << "and there are " << param->pad_width.size() << " pad width pairs.";

  // each pad width entry must be a pair of non-negative integers
  std::vector<IndexExpr> oshape;
  for (size_t i = 0; i < param->pad_width.size(); i++) {
    CHECK(param->pad_width[i].size() == 2)
      << "Each pad width element should be a pair but at index " << i
      << " there are " << param->pad_width[i].size() << " elements.";

    auto width1 = as_const_int(param->pad_width[i][0]);
    auto width2 = as_const_int(param->pad_width[i][1]);
    CHECK(width1 != nullptr)
      << "Pad width at index " << i << " must be a constant integer.";
    CHECK(width2 != nullptr)
      << "Pad width at index " << i << " must be a constant integer.";

    CHECK(*width1 >= 0)
      << "Pad width elements should be non-negative but the first pad width at "
      << "index " << i << " is " << *width1 << ".";
    CHECK(*width2 >= 0)
      << "Pad width elements should be non-negative but the second pad width at "
      << "index " << i << " is " << *width2 << ".";

    auto padding = make_const(data->shape[i].type(), *width1 + *width2);
    oshape.push_back(data->shape[i] + padding);
  }

  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
                                                  data->dtype));
  return true;
}
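
// Worked example for the shape relation above (illustrative only): each
// output dimension is the input dimension plus the two pad widths for that
// axis, so an input of shape (1, 3, 28, 28) with
// pad_width = ((0, 0), (0, 0), (2, 2), (2, 2)) yields an output of shape
// (1, 3, 32, 32).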

Array<Tensor> PadCompute(const Attrs& attrs,
                         const Array<Tensor>& inputs,
                         const Type& out_type,
                         const Target& target) {
  const auto* param = attrs.as<PadAttrs>();
  CHECK(param != nullptr);

  auto pad_width = param->pad_width;
  CHECK(pad_width.size() == inputs[0].ndim() &&
        pad_width[0].size() == 2)
    << "pad_width must provide one (before, after) pair per input dimension, "
    << "but got " << pad_width.size() << " pairs for a "
    << inputs[0].ndim() << "-D input.";
  Array<IndexExpr> pad_before;
  for (size_t i = 0; i < pad_width.size(); ++i) {
    pad_before.push_back(pad_width[i][0]);
  }
  Array<IndexExpr> pad_after;
  for (size_t i = 0; i < pad_width.size(); ++i) {
    pad_after.push_back(pad_width[i][1]);
  }
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
                                  tvm::make_const(out_ttype->dtype, param->pad_value)) };
}
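
// Illustrative example for the compute rule above: with
// pad_width = ((1, 2), (3, 4)) the loops split the pairs into
// pad_before = (1, 3) and pad_after = (2, 4), and topi::pad fills the added
// border with pad_value cast to the output dtype.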

// Handler used by the frontend FFI to create a call to the padding op
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
  auto attrs = make_node<PadAttrs>();
  attrs->pad_value = pad_value;
  attrs->pad_width = std::move(pad_width);
  static const Op& op = Op::Get("nn.pad");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.pad")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
  });
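
// Illustrative frontend usage (an assumption about the Python API, not part
// of this file): the packed function registered above backs a call such as
//   relay.nn.pad(data, pad_width=((0, 0), (1, 1), (2, 2)), pad_value=0.0)
// whose arguments are unpacked and forwarded to MakePad.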

RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad an n-D tensor with a constant value.

The amount of padding on each side of every axis is given by ``pad_width``
and the fill value by ``pad_value``.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PadAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);

}  // namespace relay
}  // namespace tvm