/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file Use external Thrust library call
 */

#include <thrust/device_ptr.h>
#include <thrust/sort.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
#include <functional>

namespace tvm {
namespace contrib {

using namespace runtime;

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
                 DLTensor* out_values,
                 DLTensor* out_indices,
                 bool is_ascend,
                 const std::function<int(int)> &get_sort_len) {
  thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
  thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

  int n_values = input->shape[input->ndim - 1];
  int n_iter = 1;
  for (int i = 0; i < input->ndim - 1; ++i) {
    n_iter *= input->shape[i];
  }

  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

  for (int i = 0 ; i < n_iter; ++i) {
    n_values = get_sort_len(i);
    thrust::sequence(indices_ptr, indices_ptr + n_values);
    if (is_ascend) {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
    } else {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                          thrust::greater<DataType>());
    }
    values_ptr += n_values;
    indices_ptr += n_values;
  }
}

void thrust_sort_common(DLTensor* input,
                        DLTensor* values_out,
                        DLTensor* indices_out,
                        bool is_ascend,
                        const std::function<int(int)> &get_sort_len,
                        std::string data_dtype,
                        std::string out_dtype) {
  if (data_dtype == "float32") {
    if (out_dtype == "int32") {
      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "int64") {
      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float32") {
      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float64") {
      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float64") {
    if (out_dtype == "int32") {
      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "int64") {
      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float32") {
      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float64") {
      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int32") {
    if (out_dtype == "int32") {
      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "int64") {
      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float32") {
      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float64") {
      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  }  else if (data_dtype == "int64") {
    if (out_dtype == "int32") {
      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "int64") {
      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float32") {
      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else if (out_dtype == "float64") {
      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 5);
  DLTensor* input = args[0];
  DLTensor* valid_count = args[1];
  DLTensor* values_out = args[2];
  DLTensor* indices_out = args[3];
  bool is_ascend = args[4];

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(indices_out->dtype);

  thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
  auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
                     data_dtype, out_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 4);
  DLTensor* input = args[0];
  DLTensor* values_out = args[1];
  DLTensor* indices_out = args[2];
  bool is_ascend = args[3];

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(indices_out->dtype);

  int n_values = input->shape[input->ndim - 1];
  auto get_sort_len = [=](int i) { return n_values; };
  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
                     data_dtype, out_dtype);
});
}  // namespace contrib
}  // namespace tvm