/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file multibox_op.cc
 * \brief Multibox related operators
 */
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs);

bool MultiboxPriorRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  const MultiBoxPriorAttrs* param = attrs.as<MultiBoxPriorAttrs>();
  const auto& dshape = data->shape;
  CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: "
      "[batch, channel, height, width]";
  IndexExpr in_height = dshape[2];
  IndexExpr in_width = dshape[3];
  int num_sizes = static_cast<int>(param->sizes.size());
  int num_ratios = static_cast<int>(param->ratios.size());

  // since input sizes are same in each batch, we could share MultiBoxPrior
  std::vector<IndexExpr> oshape(
    {1, in_height * in_width * (num_sizes + num_ratios - 1), 4});

  // assign output type
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}


Expr MakeMultiBoxPrior(Expr data,
                       Array<IndexExpr> sizes,
                       Array<IndexExpr> ratios,
                       Array<IndexExpr> steps,
                       Array<IndexExpr> offsets,
                       bool clip) {
  auto attrs = make_node<MultiBoxPriorAttrs>();
  attrs->sizes = std::move(sizes);
  attrs->ratios = std::move(ratios);
  attrs->steps = std::move(steps);
  attrs->offsets = std::move(offsets);
  attrs->clip = clip;
  static const Op& op = Op::Get("vision.multibox_prior");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
.set_body_typed(MakeMultiBoxPrior);


RELAY_REGISTER_OP("vision.multibox_prior")
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
.set_attrs_type<MultiBoxPriorAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
.add_type_rel("MultiBoxPrior", MultiboxPriorRel);

TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);

bool MultiBoxTransformLocRel(const Array<Type>& types,
                             int num_inputs,
                             const Attrs& attrs,
                             const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);

  const auto* cls_prob = types[0].as<TensorTypeNode>();
  const auto* loc_pred = types[1].as<TensorTypeNode>();
  const auto* anchor = types[2].as<TensorTypeNode>();
  CHECK(cls_prob != nullptr && loc_pred != nullptr && anchor != nullptr);

  const auto& cls_shape = cls_prob->shape;
  const auto& loc_shape = loc_pred->shape;
  const auto& anchor_shape = anchor->shape;

  CHECK_EQ(cls_shape.size(), 3U)
      << "The dimension of class probability should be 3, but received "
      << cls_shape.size();
  CHECK_EQ(loc_shape.size(), 2U)
      << "The dimension of location prediction should be 2, but received "
      << loc_shape.size();
  CHECK_EQ(anchor_shape.size(), 3U)
      << "The dimension of anchor should be 3, but received "
      << anchor_shape.size();

  CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1]))
      << "Number of anchors mismatch found";
  CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1]))
      << "# anchors mismatch with # loc.";
  CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0.";
  CHECK(reporter->AssertEQ(anchor_shape[2], 4));

  std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
  std::vector<IndexExpr> oshape1({cls_shape[0]});
  std::vector<Type> fields;
  fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
  fields.push_back(TensorTypeNode::make(oshape1, Int(32)));

  // assign output type
  reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}

Expr MakeMultiBoxTransformLoc(Expr cls_prob,
                              Expr loc_pred,
                              Expr anchor,
                              bool clip,
                              double threshold,
                              Array<IndexExpr> variances) {
  auto attrs = make_node<MultiBoxTransformLocAttrs>();
  attrs->clip = std::move(clip);
  attrs->threshold = std::move(threshold);
  attrs->variances = std::move(variances);
  static const Op& op = Op::Get("vision.multibox_transform_loc");
  return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
.set_body_typed(MakeMultiBoxTransformLoc);

RELAY_REGISTER_OP("vision.multibox_transform_loc")
.describe(R"doc("Location transformation for multibox detection."
)doc" TVM_ADD_FILELINE)
.set_attrs_type<MultiBoxTransformLocAttrs>()
.set_num_inputs(3)
.add_argument("cls_prob", "Tensor", "Class probabilities.")
.add_argument("loc_pred", "Tensor", "Location regression predictions.")
.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
.set_support_level(5);

}  // namespace relay
}  // namespace tvm