/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file Use external cudnn utils function
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

void ConvolutionForward(
  int mode,
  int format,
  int algo,
  int dims,
  const int pad[],
  const int stride[],
  const int dilation[],
  DLTensor* x,
  DLTensor* w,
  DLTensor* y,
  const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set Mode
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Set Algo
  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
  // Set Ctx
  entry_ptr->conv_entry.ctx = x->ctx;
  // Set Data Type
  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
  // Dims includes N and C
  int full_dims = dims + 2;

  std::vector<int> dim(full_dims);
  std::vector<int> tensor_stride(full_dims);

  // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
  // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
  if (dims == 2) {
  // Set Desc
    CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
                                               pad[0],
                                               pad[1],
                                               stride[0],
                                               stride[1],
                                               dilation[0],
                                               dilation[1],
                                               entry_ptr->conv_entry.mode,
                                               entry_ptr->conv_entry.data_type));
    int ni, ci, hi, wi;
    if (entry_ptr->conv_entry.tensor_format ==  CUDNN_TENSOR_NHWC) {
      ni = 0;
      ci = 3;
      hi = 1;
      wi = 2;
    } else {
      ni = 0;
      ci = 1;
      hi = 2;
      wi = 3;
    }

    // Set Filter
    CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
                                          data_type,
                                          entry_ptr->conv_entry.tensor_format,
                                          static_cast<int>(w->shape[ni]),
                                          static_cast<int>(w->shape[ci]),
                                          static_cast<int>(w->shape[hi]),
                                          static_cast<int>(w->shape[wi])));
    // Set Input
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
                                          entry_ptr->conv_entry.tensor_format,
                                          data_type,
                                          static_cast<int>(x->shape[ni]),
                                          static_cast<int>(x->shape[ci]),
                                          static_cast<int>(x->shape[hi]),
                                          static_cast<int>(x->shape[wi])));
    // Set Output
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
                                          entry_ptr->conv_entry.tensor_format,
                                          data_type,
                                          static_cast<int>(y->shape[ni]),
                                          static_cast<int>(y->shape[ci]),
                                          static_cast<int>(y->shape[hi]),
                                          static_cast<int>(y->shape[wi])));
  } else {
    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
                                               dims,
                                               pad,
                                               stride,
                                               dilation,
                                               entry_ptr->conv_entry.mode,
                                               entry_ptr->conv_entry.data_type));

    // Set Filter
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(w->shape[i]);
    }
    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
                                          data_type,
                                          entry_ptr->conv_entry.tensor_format,
                                          full_dims,
                                          dim.data()));
    // Set Input
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(x->shape[i]);
    }
    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
                                          data_type,
                                          full_dims,
                                          dim.data(),
                                          tensor_stride.data()));
    // Set Output
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(y->shape[i]);
    }
    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
                                          data_type,
                                          full_dims,
                                          dim.data(),
                                          tensor_stride.data()));
  }

  if (cudnnGetVersion() > 7000) {
    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
  }

  // Set workspace
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
                                                     entry_ptr->conv_entry.input_desc,
                                                     entry_ptr->conv_entry.filter_desc,
                                                     entry_ptr->conv_entry.conv_desc,
                                                     entry_ptr->conv_entry.output_desc,
                                                     entry_ptr->conv_entry.fwd_algo,
                                                     &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle,
                                     CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
                                     entry_ptr->conv_entry.input_desc,
                                     x->data,
                                     entry_ptr->conv_entry.filter_desc,
                                     w->data,
                                     entry_ptr->conv_entry.conv_desc,
                                     entry_ptr->conv_entry.fwd_algo,
                                     entry_ptr->conv_entry.workspace,
                                     workspace_size,
                                     CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
                                     entry_ptr->conv_entry.output_desc,
                                     y->data));
}


void OutputShape(
  int format,
  int dims,
  const int pad[],
  const int stride[],
  const int dilation[],
  const int x_dim[],
  const int w_dim[],
  void *out_shape,
  const std::string& data_dtype,
  const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

  // Set Data Type
  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Dims includes N and C
  int full_dims = dims + 2;

  // conv desc
  CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
                                             dims,
                                             pad,
                                             stride,
                                             dilation,
                                             CUDNN_CROSS_CORRELATION,
                                             entry_ptr->conv_entry.data_type));

  if (dims == 2 && entry_ptr->conv_entry.tensor_format ==  CUDNN_TENSOR_NHWC) {
    // Set Input
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
                                          entry_ptr->conv_entry.tensor_format,
                                          data_type,
                                          x_dim[0],
                                          x_dim[3],
                                          x_dim[1],
                                          x_dim[2]));

    // filter desc
    CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
                                          data_type,
                                          entry_ptr->conv_entry.tensor_format,
                                          w_dim[0],
                                          w_dim[3],
                                          w_dim[1],
                                          w_dim[2]));

    CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc,
                                                     entry_ptr->conv_entry.input_desc,
                                                     entry_ptr->conv_entry.filter_desc,
                                                     static_cast<int*>(out_shape),
                                                     static_cast<int*>(out_shape) + 3,
                                                     static_cast<int*>(out_shape) + 1,
                                                     static_cast<int*>(out_shape) + 2));
  } else {
    // Set Input
    std::vector<int> tensor_stride(full_dims);
    GetCudnnStride(full_dims, x_dim, tensor_stride.data());
    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
                                          data_type,
                                          full_dims,
                                          x_dim,
                                          tensor_stride.data()));
    // filter desc
    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
                                          data_type,
                                          entry_ptr->conv_entry.tensor_format,
                                          full_dims,
                                          w_dim));

    CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc,
                                                     entry_ptr->conv_entry.input_desc,
                                                     entry_ptr->conv_entry.filter_desc,
                                                     full_dims,
                                                     static_cast<int*>(out_shape)));
  }
}


void FindAlgo(
  int format,
  int dims,
  const int pad[],
  const int stride[],
  const int dilation[],
  const int x_dim[],
  const int w_dim[],
  const int y_dim[],
  const std::string& data_dtype,
  const std::string& conv_dtype,
  TVMRetValue *ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

  // Set Data Type
  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Dims includes N and C
  int full_dims = dims + 2;

  // conv desc
  CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
                                             dims,
                                             pad,
                                             stride,
                                             dilation,
                                             CUDNN_CROSS_CORRELATION,
                                             entry_ptr->conv_entry.data_type));

  std::vector<int> tensor_stride(full_dims);
  // input desc
  GetCudnnStride(full_dims, x_dim, tensor_stride.data());
  CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
                                        data_type,
                                        full_dims,
                                        x_dim,
                                        tensor_stride.data()));
  // filter desc
  CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
                                        data_type,
                                        entry_ptr->conv_entry.tensor_format,
                                        full_dims,
                                        w_dim));

  // output desc
  GetCudnnStride(full_dims, y_dim, tensor_stride.data());
  CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
                                        data_type,
                                        full_dims,
                                        y_dim,
                                        tensor_stride.data()));
  if (cudnnGetVersion() > 7000) {
    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
  }

  int returned_algo_count = 0;
  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
                                                  entry_ptr->conv_entry.input_desc,
                                                  entry_ptr->conv_entry.filter_desc,
                                                  entry_ptr->conv_entry.conv_desc,
                                                  entry_ptr->conv_entry.output_desc,
                                                  CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
                                                  &returned_algo_count,
                                                  perf_results));

  const std::vector<std::string> fwd_algo_names{
      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
      "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
      "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
      "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
      "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
  };

  auto best_algo = perf_results[0].algo;
  LOG(INFO) << "\tCUDNN Found " << returned_algo_count
            << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
              << " - time: " << perf_results[i].time << " ms"
              << ", Memory: " << perf_results[i].memory;
  }

  ret[0] = best_algo;
}


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  int mode = args[0];
  int format = args[1];
  int algo = args[2];
  int pad_v[2], stride_v[2], dilation_v[2];
  for (int i = 0; i < 2; i++) {
      pad_v[i] = args[3 + i];
      stride_v[i] = args[5 + i];
      dilation_v[i] = args[7 + i];
  }
  DLTensor* x = args[9];
  DLTensor* w = args[10];
  DLTensor* y = args[11];
  std::string conv_dtype = args[12];

  ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  int mode = args[0];
  int format = args[1];
  int algo = args[2];
  int pad_v[3], stride_v[3], dilation_v[3];
  for (int i = 0; i < 3; i++) {
      pad_v[i] = args[3 + i];
      stride_v[i] = args[6 + i];
      dilation_v[i] = args[9 + i];
  }
  DLTensor *x = args[12];
  DLTensor *w = args[13];
  DLTensor *y = args[14];
  std::string conv_dtype = args[15];

  ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y,
                     conv_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  int format = args[0];
  int dims = args[1];
  int* pad = static_cast<int*>(static_cast<void*>(args[2]));
  int* stride = static_cast<int*>(static_cast<void*>(args[3]));
  int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
  int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
  int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
  void* out_shape = args[7];
  std::string data_dtype = args[8];
  std::string conv_dtype = args[9];

  OutputShape(format, dims, pad, stride, dilation, x_dim,
              w_dim, out_shape, data_dtype, conv_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  int format = args[0];
  int dims = args[1];
  int* pad = static_cast<int*>(static_cast<void*>(args[2]));
  int* stride = static_cast<int*>(static_cast<void*>(args[3]));
  int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
  int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
  int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
  int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
  std::string data_dtype = args[8];
  std::string conv_dtype = args[9];

  FindAlgo(format, dims, pad, stride, dilation, x_dim,
           w_dim, y_dim, data_dtype, conv_dtype, ret);
});

}  // namespace contrib
}  // namespace tvm