sort.cc 2.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
/*!
 *  Copyright (c) 2017 by Contributors
 * \file sort.cc
 * \brief Argsort using the C++ standard library sort.
 */

#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Ordering predicate used for an ascending sort of (index, value) pairs.
 * \tparam DType type of the value stored in the second member of each pair.
 * \param lhs left-hand (index, value) pair.
 * \param rhs right-hand (index, value) pair.
 * \return true when the value of lhs is strictly less than the value of rhs.
 *         The index members do not take part in the comparison.
 */
template<typename DType>
bool CompareAscend(const std::pair<int32_t, DType>& lhs,
                   const std::pair<int32_t, DType>& rhs) {
  const DType& left_value = lhs.second;
  const DType& right_value = rhs.second;
  return left_value < right_value;
}

/*!
 * \brief Ordering predicate used for a descending sort of (index, value) pairs.
 * \tparam DType type of the value stored in the second member of each pair.
 * \param lhs left-hand (index, value) pair.
 * \param rhs right-hand (index, value) pair.
 * \return true when the value of lhs is strictly greater than the value of rhs.
 *         The index members do not take part in the comparison.
 */
template<typename DType>
bool CompareDescend(const std::pair<int32_t, DType>& lhs,
                    const std::pair<int32_t, DType>& rhs) {
  const DType& left_value = lhs.second;
  const DType& right_value = rhs.second;
  return left_value > right_value;
}


// Argsort implemented with the C++ standard library stable sort.
// Writes into `output` the indices that sort `input` along `axis`.
//
// Packed arguments:
//   args[0] input      - tensor to sort; its dtype must be float32.
//   args[1] sort_num   - read as int32; for every 1-D slice along the sort
//                        axis it gives how many leading elements of that
//                        slice take part in the sort.
//   args[2] output     - written as int32; addressed with the same offsets
//                        as input.
//   args[3] axis       - axis to sort along. A negative value counts from the
//                        end (-1 is the last axis).
//   args[4] is_descend - sort in descending order when true, ascending
//                        otherwise.
//
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk, sort_num should have dimension of
// (d0, d1, ..., d(k-1), d(k+1), ..., d(n-1)).
//
// For each slice, the first sort_num output positions hold the indices of the
// sorted leading elements (equal values keep their original relative order);
// every remaining position p simply holds p.
//
// All three tensors are addressed as densely packed row-major buffers;
// strides are not consulted.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  DLTensor *input = args[0];
  DLTensor *sort_num = args[1];
  DLTensor *output = args[2];
  int32_t axis = args[3];
  bool is_descend = args[4];

  // Validate the dtype before the buffer is reinterpreted as float.
  // Currently only supports input dtype to be float32
  // (type code 2 is the floating point code in DLPack).
  auto dtype = input->dtype;
  CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
      "to be float32.";
  CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
      "to be float32.";

  // Map a negative axis onto [0, ndim) and reject anything still outside
  // that range, so that input->shape[axis] below is always a valid access.
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  CHECK_LT(-1, axis) << "Axis out of boundary for "
      "input ndim " << input->ndim;
  CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
      "input ndim " << input->ndim;

  auto data_ptr = static_cast<float *>(input->data);
  auto sort_num_ptr = static_cast<int32_t *>(sort_num->data);
  auto output_ptr = static_cast<int32_t *>(output->data);

  // Number of slices is axis_mul_before * axis_mul_after; consecutive
  // elements of one slice are axis_mul_after entries apart.
  const int64_t axis_len = input->shape[axis];
  int64_t axis_mul_before = 1;
  int64_t axis_mul_after = 1;
  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }

  // Reused across slices: (index within slice, value) pairs.
  std::vector<std::pair<int32_t, float>> sorter;

  for (int64_t i = 0; i < axis_mul_before; ++i) {
    for (int64_t j = 0; j < axis_mul_after; ++j) {
      const int64_t base_idx = i * axis_len * axis_mul_after + j;

      // Clamp the requested count to [0, axis_len]: a larger value would
      // read past the end of the slice, a negative one sorts nothing.
      const int64_t requested = sort_num_ptr[i * axis_mul_after + j];
      const int64_t sorted_count =
          std::max<int64_t>(0, std::min<int64_t>(requested, axis_len));

      sorter.clear();
      sorter.reserve(static_cast<size_t>(sorted_count));
      for (int64_t k = 0; k < sorted_count; ++k) {
        sorter.emplace_back(static_cast<int32_t>(k),
                            data_ptr[base_idx + k * axis_mul_after]);
      }
      std::stable_sort(sorter.begin(), sorter.end(),
                       is_descend ? CompareDescend<float>
                                  : CompareAscend<float>);

      // Positions beyond the sorted prefix keep their own index.
      for (int64_t k = 0; k < axis_len; ++k) {
        output_ptr[base_idx + k * axis_mul_after] =
            k < sorted_count ? sorter[k].first : static_cast<int32_t>(k);
      }
    }
  }
});

}  // namespace contrib
}  // namespace tvm