/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use standard C library call.
 */

#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Ordering predicate for ascending sorts of (index, value) pairs.
 * \param lhs Left-hand (index, value) pair.
 * \param rhs Right-hand (index, value) pair.
 * \return true when the value of lhs is strictly less than the value of rhs.
 *         Only the value (second member) is compared; the index is ignored.
 */
template<typename DType>
bool CompareAscend(const std::pair<int64_t, DType>& lhs,
                   const std::pair<int64_t, DType>& rhs) {
  const DType& left_value = lhs.second;
  const DType& right_value = rhs.second;
  return left_value < right_value;
}

/*!
 * \brief Ordering predicate for descending sorts of (index, value) pairs.
 * \param lhs Left-hand (index, value) pair.
 * \param rhs Right-hand (index, value) pair.
 * \return true when the value of lhs is strictly greater than the value of rhs.
 *         Only the value (second member) is compared; the index is ignored.
 */
template<typename DType>
bool CompareDescend(const std::pair<int64_t, DType>& lhs,
                    const std::pair<int64_t, DType>& rhs) {
  const DType& left_value = lhs.second;
  const DType& right_value = rhs.second;
  return left_value > right_value;
}

// Argsort for NMS, implemented with std::stable_sort.
// Writes the indices that sort the input tensor along one axis.
// A negative axis counts back from the last dimension.
// sort_num specifies, per slice, the number of leading elements to be sorted.
// If the input tensor has dimensions (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and the sort axis is dk, sort_num should have dimensions
// (d0, d1, ..., d(k-1), d(k+1), ..., d(n-1)).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") .set_body([](TVMArgs args, TVMRetValue *ret) { DLTensor *input = args[0]; DLTensor *sort_num = args[1]; DLTensor *output = args[2]; int32_t axis = args[3]; bool is_ascend = args[4]; auto dtype = input->dtype; auto data_ptr = static_cast<float *>(input->data); auto sort_num_ptr = static_cast<int32_t *>(sort_num->data); std::vector<std::pair<int32_t, float>> sorter; int64_t axis_mul_before = 1; int64_t axis_mul_after = 1; if (axis < 0) { axis = input->ndim + axis; } // Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " "to be float32."; CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " "to be float32."; CHECK_LT(axis, input->ndim) << "Axis out of boundary for " "input ndim " << input->ndim; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { axis_mul_before *= input->shape[i]; } else if (i > axis) { axis_mul_after *= input->shape[i]; } } for (int64_t i = 0 ; i < axis_mul_before; ++i) { for (int64_t j = 0 ; j < axis_mul_after; ++j) { sorter.clear(); int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; for (int64_t k = 0; k < current_sort_num; ++k) { int64_t full_idx = base_idx + k * axis_mul_after; sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } if (is_ascend) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>); } else { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>); } for (int32_t k = 0; k < input->shape[axis]; ++k) { *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after) = k < static_cast<int32_t>(sorter.size()) ? 
sorter[k].first : k; } } } }); template<typename DataType, typename OutType> void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { auto data_ptr = static_cast<DataType *>(input->data); auto out_ptr = static_cast<OutType *>(output->data); std::vector<std::pair<int64_t, DataType> > sorter; int axis_mul_before = 1; int axis_mul_after = 1; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { axis_mul_before *= input->shape[i]; } else if (i > axis) { axis_mul_after *= input->shape[i]; } } for (int i = 0 ; i < axis_mul_before; ++i) { for (int j = 0 ; j < axis_mul_after; ++j) { sorter.clear(); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; for (int64_t k = 0; k < input->shape[axis]; ++k) { int64_t full_idx = base_idx + k * axis_mul_after; sorter.emplace_back(std::make_pair(k, data_ptr[full_idx])); } if (is_ascend) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>); } else { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>); } for (int64_t k = 0; k < input->shape[axis]; ++k) { out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first); } } } } // Argsort implemented C library sort. // Return indices of sorted tensor. // By default, the last axis will be used to sort. // sort_num specify the number of elements to be sorted. // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). 
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") .set_body([](TVMArgs args, TVMRetValue *ret) { DLTensor *input = args[0]; DLTensor *output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; if (axis < 0) { axis = input->ndim + axis; } CHECK_LT(axis, input->ndim) << "Axis out of boundary for " "input ndim " << input->ndim; auto data_dtype = TVMType2String(input->dtype); auto out_dtype = TVMType2String(output->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { argsort<float, int32_t>(input, output, axis, is_ascend); } else if (out_dtype == "int64") { argsort<float, int64_t>(input, output, axis, is_ascend); } else if (out_dtype == "float32") { argsort<float, float>(input, output, axis, is_ascend); } else if (out_dtype == "float64") { argsort<float, double>(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { argsort<double, int32_t>(input, output, axis, is_ascend); } else if (out_dtype == "int64") { argsort<double, int64_t>(input, output, axis, is_ascend); } else if (out_dtype == "float32") { argsort<double, float>(input, output, axis, is_ascend); } else if (out_dtype == "float64") { argsort<double, double>(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { argsort<int32_t, int32_t>(input, output, axis, is_ascend); } else if (out_dtype == "int64") { argsort<int32_t, int64_t>(input, output, axis, is_ascend); } else if (out_dtype == "float32") { argsort<int32_t, float>(input, output, axis, is_ascend); } else if (out_dtype == "float64") { argsort<int32_t, double>(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { argsort<int64_t, int32_t>(input, output, axis, is_ascend); } else if (out_dtype == "int64") { 
argsort<int64_t, int64_t>(input, output, axis, is_ascend); } else if (out_dtype == "float32") { argsort<int64_t, float>(input, output, axis, is_ascend); } else if (out_dtype == "float64") { argsort<int64_t, double>(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } }); template<typename DataType, typename IndicesType> void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, bool is_ascend) { DataType* data_ptr = static_cast<DataType *>(input->data); DataType* values_ptr = (out_values == nullptr) ? nullptr : static_cast<DataType *>(out_values->data); IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : static_cast<IndicesType *>(out_indices->data); std::vector<std::pair<int64_t, DataType> > sorter; int axis_mul_before = 1; int axis_mul_after = 1; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { axis_mul_before *= input->shape[i]; } else if (i > axis) { axis_mul_after *= input->shape[i]; } } if (k < 1) { k = input->shape[axis]; } for (int i = 0 ; i < axis_mul_before; ++i) { for (int j = 0 ; j < axis_mul_after; ++j) { sorter.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { int64_t full_idx = src_base_idx + kk * axis_mul_after; sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); } if (is_ascend) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>); } else { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>); } int64_t cnt = k > 0 ? 
k : input->shape[axis]; for (int64_t kk = 0; kk < cnt; ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = static_cast<IndicesType>(sorter[kk].first); } if (values_ptr != nullptr) { values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast<DataType>(sorter[kk].second); } } } } } // Argsort implemented C library sort. // Return indices of sorted tensor. // By default, the last axis will be used to sort. // sort_num specify the number of elements to be sorted. // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") .set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* input = args[0]; DLTensor* values_out = nullptr; DLTensor* indices_out = nullptr; int k = args[args.num_args - 4]; int axis = args[args.num_args - 3]; std::string ret_type = args[args.num_args - 2]; bool is_ascend = args[args.num_args - 1]; if (ret_type == "both") { values_out = args[1]; indices_out = args[2]; } else if (ret_type == "values") { values_out = args[1]; } else if (ret_type == "indices") { indices_out = args[1]; } else { LOG(FATAL) << "Unsupported ret type: " << ret_type; } if (axis < 0) { axis = input->ndim + axis; } CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; auto data_dtype = TVMType2String(input->dtype); auto out_dtype = (indices_out == nullptr) ? 
"int64" : TVMType2String(indices_out->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float32") { topk<float, float>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float64") { topk<float, double>(input, values_out, indices_out, k, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float32") { topk<double, float>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float64") { topk<double, double>(input, values_out, indices_out, k, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float32") { topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float64") { topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "float32") { topk<int64_t, float>(input, values_out, indices_out, 
k, axis, is_ascend); } else if (out_dtype == "float64") { topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } }); } // namespace contrib } // namespace tvm