/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use the external cuDNN utilities to implement 2D convolution packed functions
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;


/*!
 * \brief Packed function "tvm.contrib.cudnn.conv2d.forward": run a 2D
 *        convolution of x with w into y via cudnnConvolutionForward.
 *
 * Packed arguments (positional):
 *   0      mode        value cast to cudnnConvolutionMode_t
 *   1      format      value cast to cudnnTensorFormat_t (used for x and y)
 *   2      algo        value cast to cudnnConvolutionFwdAlgo_t
 *   3, 4   pad_h, pad_w
 *   5, 6   stride_h, stride_w
 *   7, 8   dilation_h, dilation_w
 *   9      x           input DLTensor*  (shape[0..3] are read)
 *   10     w           filter DLTensor* (shape[0..3] are read, described as NCHW)
 *   11     y           output DLTensor* (shape[0..3] are read; y->data is written)
 *
 * The cuDNN element type for every descriptor is derived from x->dtype.
 * All configuration and descriptors live in the calling thread's
 * CuDNNThreadEntry and are overwritten on each call. Nothing is written to ret.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  // Unpack the positional packed-function arguments.
  const int conv_mode = args[0];
  const int layout = args[1];
  const int algo_id = args[2];
  const int pad_h = args[3];
  const int pad_w = args[4];
  const int stride_h = args[5];
  const int stride_w = args[6];
  const int dilation_h = args[7];
  const int dilation_w = args[8];
  DLTensor* input = args[9];
  DLTensor* filter = args[10];
  DLTensor* output = args[11];

  CuDNNThreadEntry* thread_entry = CuDNNThreadEntry::ThreadLocal();
  auto& conv = thread_entry->conv_entry;

  // Record this call's configuration on the thread-local convolution entry.
  conv.mode = static_cast<cudnnConvolutionMode_t>(conv_mode);
  conv.tensor_format = static_cast<cudnnTensorFormat_t>(layout);
  conv.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo_id);
  conv.ctx = input->ctx;
  // The input's dtype decides the cuDNN data type of every descriptor below.
  conv.data_type = CuDNNDataType::DLTypeToCuDNNType(input->dtype);

  // Narrow one extent of a tensor's shape to the int the cuDNN 4d setters take.
  auto extent = [](const DLTensor* t, int axis) {
    return static_cast<int>(t->shape[axis]);
  };

  CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv.conv_desc,
                                             pad_h, pad_w,
                                             stride_h, stride_w,
                                             dilation_h, dilation_w,
                                             conv.mode,
                                             conv.data_type));
  // NOTE(review): the filter layout is fixed to NCHW while x/y use the
  // caller-supplied format; confirm this is intended for non-NCHW formats.
  CUDNN_CALL(cudnnSetFilter4dDescriptor(conv.filter_desc,
                                        conv.data_type,
                                        CUDNN_TENSOR_NCHW,
                                        extent(filter, 0), extent(filter, 1),
                                        extent(filter, 2), extent(filter, 3)));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.input_desc,
                                        conv.tensor_format,
                                        conv.data_type,
                                        extent(input, 0), extent(input, 1),
                                        extent(input, 2), extent(input, 3)));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.output_desc,
                                        conv.tensor_format,
                                        conv.data_type,
                                        extent(output, 0), extent(output, 1),
                                        extent(output, 2), extent(output, 3)));

  // Ask cuDNN how much scratch memory the chosen algorithm needs, then hand
  // that size to the entry (presumably it grows its workspace buffer as needed).
  size_t workspace_bytes = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(thread_entry->handle,
                                                     conv.input_desc,
                                                     conv.filter_desc,
                                                     conv.conv_desc,
                                                     conv.output_desc,
                                                     conv.fwd_algo,
                                                     &workspace_bytes));
  conv.UpdateWorkspace(workspace_bytes);

  // cuDNN computes y = alpha * conv(x, w) + beta * y; here alpha comes from
  // GetConst<1> and beta from GetConst<0> for the active data type.
  CUDNN_CALL(cudnnConvolutionForward(thread_entry->handle,
                                     CuDNNDataType::GetConst<1>(conv.data_type),
                                     conv.input_desc,
                                     input->data,
                                     conv.filter_desc,
                                     filter->data,
                                     conv.conv_desc,
                                     conv.fwd_algo,
                                     conv.workspace,
                                     workspace_bytes,
                                     CuDNNDataType::GetConst<0>(conv.data_type),
                                     conv.output_desc,
                                     output->data));
});


/*!
 * \brief Packed function "tvm.contrib.cudnn.conv2d.output_shape": ask cuDNN
 *        for the output dimensions of a 2D convolution without running it.
 *
 * Packed arguments (positional):
 *   0       format      value cast to cudnnTensorFormat_t (input descriptor)
 *   1, 2    pad_h, pad_w
 *   3, 4    stride_h, stride_w
 *   5, 6    dilation_h, dilation_w
 *   7..10   x_dim0..x_dim3   input extents
 *   11..14  w_dim0..w_dim3   filter extents (described to cuDNN as NCHW)
 *   15      out_shape        pointer to at least 4 ints; receives the four
 *                            values produced by cudnnGetConvolution2dForwardOutputDim
 *
 * All descriptors are described with CUDNN_DATA_FLOAT and cross-correlation
 * mode. Only the shape is computed. The descriptors used are the calling
 * thread's CuDNNThreadEntry ones, and this call overwrites them.
 * Nothing is written to ret.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  int format = args[0];
  int pad_h = args[1];
  int pad_w = args[2];
  int stride_h = args[3];
  int stride_w = args[4];
  int dilation_h = args[5];
  int dilation_w = args[6];
  int x_dim0 = args[7];
  int x_dim1 = args[8];
  int x_dim2 = args[9];
  int x_dim3 = args[10];
  int w_dim0 = args[11];
  int w_dim1 = args[12];
  // Each filter extent has its own slot; previously w_dim2 re-read args[12].
  int w_dim2 = args[13];
  int w_dim3 = args[14];
  void *out_shape = args[15];
  int* out_dims = static_cast<int*>(out_shape);
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // conv desc: use FLOAT to match the input/filter descriptors below, rather
  // than whatever data_type an earlier call left in the thread-local entry.
  CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
                                             pad_h,
                                             pad_w,
                                             stride_h,
                                             stride_w,
                                             dilation_h,
                                             dilation_w,
                                             CUDNN_CROSS_CORRELATION,
                                             CUDNN_DATA_FLOAT));
  // input desc
  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
                                        entry_ptr->conv_entry.tensor_format,
                                        CUDNN_DATA_FLOAT,
                                        x_dim0,
                                        x_dim1,
                                        x_dim2,
                                        x_dim3));
  // filter desc
  CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
                                        CUDNN_DATA_FLOAT,
                                        CUDNN_TENSOR_NCHW,
                                        w_dim0,
                                        w_dim1,
                                        w_dim2,
                                        w_dim3));
  // Write the four output extents into out_shape[0..3].
  CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc,
                                                   entry_ptr->conv_entry.input_desc,
                                                   entry_ptr->conv_entry.filter_desc,
                                                   out_dims,
                                                   out_dims + 1,
                                                   out_dims + 2,
                                                   out_dims + 3));
});

}  // namespace contrib
}  // namespace tvm