/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file dilation2d.cc
 * \brief Morphological dilation operator
 */
#include <tvm/tir/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
#include "../op_common.h"

namespace tvm {
namespace relay {

// relay.image.dilation2d
// Registers the Dilation2DAttrs node type used by the operator below.
// NOTE(review): presumably this is what makes the attrs node known to the
// object/reflection system -- confirm against the macro definition.
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);

/*!
 * \brief Layout inference hook for dilation2d.
 *
 * The layouts are taken verbatim from the operator attributes; the layouts
 * proposed by the caller are not consulted.
 *
 * \tparam T Attribute node type; must expose data_layout and kernel_layout.
 * \param attrs Operator attributes; must be convertible to T.
 * \param new_in_layouts Unused.
 * \param old_in_layouts Unused.
 * \param old_in_types Unused.
 * \return A two-element array: the input layouts {data_layout, kernel_layout}
 *         followed by the output layouts {data_layout}.
 */
template<typename T>
Array<Array<Layout> > Dilation2DInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type> &old_in_types) {
  const T* params = attrs.as<T>();
  // Fail with a diagnostic instead of dereferencing null when attrs is not a T
  // (same guard as in Dilation2DRel).
  CHECK(params != nullptr);

  return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
                               {params->data_layout}};
}

/*!
 * \brief Build a call to the image.dilation2d operator from positional
 *        arguments. Used by the frontend FFI.
 *
 * The attribute values are packed into a freshly created Dilation2DAttrs node
 * and the two inputs become the arguments of the resulting Call.
 *
 * \param data First call argument (the input tensor expression).
 * \param weight Second call argument (the weight tensor expression).
 * \param strides Stored in Dilation2DAttrs::strides.
 * \param padding Stored in Dilation2DAttrs::padding.
 * \param dilations Stored in Dilation2DAttrs::dilations.
 * \param data_layout Stored in Dilation2DAttrs::data_layout.
 * \param kernel_layout Stored in Dilation2DAttrs::kernel_layout.
 * \param out_dtype Stored in Dilation2DAttrs::out_dtype.
 * \return The Call expression invoking image.dilation2d.
 */
Expr MakeDilation2D(Expr data,
                    Expr weight,
                    Array<IndexExpr> strides,
                    Array<IndexExpr> padding,
                    Array<IndexExpr> dilations,
                    std::string data_layout,
                    std::string kernel_layout,
                    DataType out_dtype) {
  // Function-local static: the operator is looked up on first use only.
  static const Op& dilation_op = Op::Get("image.dilation2d");

  auto node = make_object<Dilation2DAttrs>();

  // Layout / dtype attributes.
  node->data_layout = std::move(data_layout);
  node->kernel_layout = std::move(kernel_layout);
  node->out_dtype = std::move(out_dtype);

  // Geometry attributes.
  node->strides = std::move(strides);
  node->padding = std::move(padding);
  node->dilations = std::move(dilations);

  return Call(dilation_op, {data, weight}, Attrs(node), {});
}

template <typename AttrType>
bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("IHW");

  const AttrType* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  CHECK(trans_in_layout.defined())
      << "Dilation2D only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
  CHECK(trans_kernel_layout.defined())
      << "Dilation2D only support kernel layouts that are convertible from OIHW."
      << " But got " << kernel_layout;

  Layout out_layout(param->data_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
  CHECK(trans_out_layout.defined())
      << "Dilation2D only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  // use weight to infer the conv shape.
  if (weight == nullptr) return false;
  auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
  channels = wshape[0];

  dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0];
  dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1];

  // dilation
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  if (!dshape_nchw[2].as<tir::AnyNode>()) {
    oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
                           param->strides[0]) + 1);
  } else {
    oshape.Set(2, dshape_nchw[2]);
  }

  if (!dshape_nchw[3].as<tir::AnyNode>()) {
    oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
                           param->strides[1]) + 1);
  } else {
    oshape.Set(3, dshape_nchw[3]);
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

// Registers MakeDilation2D as the typed body of the global function
// "relay.op.image._make.dilation2d".
TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d")
.set_body_typed(MakeDilation2D);


// Operator registration for image.dilation2d: two inputs (data, weight),
// attributes of type Dilation2DAttrs, the Dilation2DRel type relation and the
// Dilation2DInferCorrectLayout layout-inference hook.
// The description names the actual attribute fields (`data_layout`,
// `kernel_layout`) that control the tensor layouts.
RELAY_REGISTER_OP("image.dilation2d")
.describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
- **data**: This depends on the `data_layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `data_layout` is `NCHW`.
- **weight**: This depends on the `kernel_layout` parameter. Filter is 3D array of shape
            (in_channels, height, width) if `kernel_layout` is `IHW`.
- **out**:  This depends on the `data_layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `data_layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<Dilation2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
         Dilation2DInferCorrectLayout<Dilation2DAttrs>);

}  // namespace relay
}  // namespace tvm