Unverified Commit 2e8f3a91 by MORITA Kazutaka Committed by GitHub

[TOPI][OP] Use Thrust sort for argsort and topk (#5097)

* [TOPI][OP] Use Thrust sort for argsort and topk

The current GPU sort implementation (odd-even transposition sort) is too slow
when the number of elements is large.  This PR introduces a Thrust-based sort
implementation, which is much faster.

Note that this change requires CMake 3.8 or later, since nvcc has to be used to
compile the Thrust code.

* cmake: make CUDA optional

* allow .cu files to be checked into the repository

* pylint fix and cleanup

* require cmake 3.8 only when thrust is enabled

* fix nvcc compiler error when passing -pthread

* add missing include

* add USE_THRUST option in config.cmake

* retrigger CI

* retrigger CI
parent dbd805c1
...@@ -55,6 +55,7 @@ tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none) ...@@ -55,6 +55,7 @@ tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF) tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF) tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_SORT "Build with sort support" OFF)
...@@ -101,9 +102,11 @@ else(MSVC) ...@@ -101,9 +102,11 @@ else(MSVC)
message("Build in Debug mode") message("Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
else() else()
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O2 -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
if (HIDE_PRIVATE_SYMBOLS) if (HIDE_PRIVATE_SYMBOLS)
message(STATUS "Hide private symbols...") message(STATUS "Hide private symbols...")
set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}")
...@@ -262,6 +265,7 @@ if(NOT MSVC) ...@@ -262,6 +265,7 @@ if(NOT MSVC)
check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14) check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
message(STATUS "Build with c++14") message(STATUS "Build with c++14")
set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 14)
endif() endif()
add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
......
...@@ -201,3 +201,6 @@ set(USE_VTA_FPGA OFF) ...@@ -201,3 +201,6 @@ set(USE_VTA_FPGA OFF)
# Whether to build the example external runtime module # Whether to build the example external runtime module
set(USE_EXAMPLE_EXT_RUNTIME OFF) set(USE_EXAMPLE_EXT_RUNTIME OFF)
# Whether to use Thrust
set(USE_THRUST OFF)
...@@ -55,6 +55,14 @@ if(USE_CUDA) ...@@ -55,6 +55,14 @@ if(USE_CUDA)
endif() endif()
endif(USE_CUBLAS) endif(USE_CUBLAS)
# Optional Thrust-backed runtime sources; this stanza only runs when the
# USE_THRUST option is enabled.
if(USE_THRUST)
message(STATUS "Build with Thrust support")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
# Enable CMake's CUDA language support so the .cu files globbed below can be
# built as part of this project.
enable_language(CUDA)
# Gather every Thrust contrib source and add it to the runtime source list.
file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)
else(USE_CUDA) else(USE_CUDA)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc) list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
endif(USE_CUDA) endif(USE_CUDA)
...@@ -20,6 +20,7 @@ import topi ...@@ -20,6 +20,7 @@ import topi
from tvm.te import SpecializedCondition from tvm.te import SpecializedCondition
from .generic import * from .generic import *
from .. import op as _op from .. import op as _op
from .... import get_global_func
@schedule_injective.register(["cuda", "gpu"]) @schedule_injective.register(["cuda", "gpu"])
def schedule_injective_cuda(attrs, outs, target): def schedule_injective_cuda(attrs, outs, target):
...@@ -328,6 +329,11 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): ...@@ -328,6 +329,11 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_argsort(topi.cuda.argsort), wrap_compute_argsort(topi.cuda.argsort),
wrap_topi_schedule(topi.cuda.schedule_argsort), wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda") name="argsort.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort_thrust.cuda",
plevel=15)
return strategy return strategy
@topk_strategy.register(["cuda", "gpu"]) @topk_strategy.register(["cuda", "gpu"])
...@@ -337,6 +343,11 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): ...@@ -337,6 +343,11 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk), strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
wrap_topi_schedule(topi.cuda.schedule_topk), wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda") name="topk.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk_thrust.cuda",
plevel=15)
return strategy return strategy
@multibox_prior_strategy.register(["cuda", "gpu"]) @multibox_prior_strategy.register(["cuda", "gpu"])
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file Use external Thrust library call
*/
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
namespace tvm {
namespace contrib {
using namespace runtime;
// Sorts `input` along its last axis (axis -1) and produces both the sorted
// values and the positions those values came from.
//
// The whole input is first copied into `out_values`; every slice of length
// shape[ndim - 1] is then sorted in place by key, with `out_indices` carrying
// the 0..n-1 sequence that is permuted alongside the values.
//
// \tparam DataType    element type of `input` and `out_values`
// \tparam IndicesType element type of `out_indices`
// \param input        tensor to sort; its data pointer is used as a device pointer
// \param out_values   receives the sorted values
// \param out_indices  receives, per slice, the original position of each value
// \param is_ascend    true for ascending order, false for descending
//
// NOTE(review): the data pointers are used as-is, so this presumably expects
// compact tensors with a zero byte_offset and ndim >= 1 -- confirm with callers.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
                 DLTensor* out_values,
                 DLTensor* out_indices,
                 bool is_ascend) {
  thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
  thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

  // Shape extents are 64-bit in DLPack. Keep all size arithmetic in int64_t so
  // that neither a single extent nor the total element count is truncated or
  // overflows for large tensors.
  const int64_t n_values = input->shape[input->ndim - 1];
  int64_t n_iter = 1;
  for (int i = 0; i < input->ndim - 1; ++i) {
    n_iter *= input->shape[i];
  }

  // Sorting happens in place on the output buffer, so seed it with the input.
  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

  for (int64_t i = 0; i < n_iter; ++i) {
    // Each slice gets a fresh 0..n_values-1 index sequence to permute.
    thrust::sequence(indices_ptr, indices_ptr + n_values);
    if (is_ascend) {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
    } else {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                          thrust::greater<DataType>());
    }
    // Advance to the next slice along the last axis.
    values_ptr += n_values;
    indices_ptr += n_values;
  }
}
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];
auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});
} // namespace contrib
} // namespace tvm
...@@ -42,6 +42,7 @@ ALLOW_EXTENSION = { ...@@ -42,6 +42,7 @@ ALLOW_EXTENSION = {
"pxi", "pxi",
"pyd", "pyd",
"pyx", "pyx",
"cu",
# relay text format # relay text format
"rly", "rly",
# configurations # configurations
......
...@@ -21,7 +21,7 @@ from tvm import te ...@@ -21,7 +21,7 @@ from tvm import te
from .injective import schedule_injective_from_existing from .injective import schedule_injective_from_existing
from ..math import identity from ..math import identity
from ..transform import strided_slice from ..transform import strided_slice, transpose
from .. import tag from .. import tag
def _schedule_sort(outs): def _schedule_sort(outs):
...@@ -291,6 +291,40 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): ...@@ -291,6 +291,40 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
tag="argsort_gpu")[1] tag="argsort_gpu")[1]
return out return out
def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Return the indices that order ``data`` along ``axis``, using Thrust when possible.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input array.

    valid_count : tvm.te.Tensor, optional
        The number of valid elements to be sorted. When given, the call is
        forwarded to ``argsort`` instead of the Thrust path.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices.

    Returns
    -------
    out : tvm.te.Tensor
        Tensor of indices produced by the selected implementation.
    """
    if valid_count is None:
        # k=0 makes topk_thrust keep every element, i.e. a full argsort.
        return topk_thrust(data, 0, axis, "indices", is_ascend, dtype)

    # TODO: implement argsort_nms with Thrust
    return argsort(data, valid_count, axis, is_ascend, dtype)
def schedule_argsort(outs): def schedule_argsort(outs):
"""Schedule for argsort operator. """Schedule for argsort operator.
...@@ -384,6 +418,82 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): ...@@ -384,6 +418,82 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
return output return output
def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Select the top ``k`` entries of ``data`` along ``axis`` via the Thrust sort.

    The packed function ``tvm.contrib.thrust.sort`` is invoked on the last
    axis, so for any other ``axis`` the tensor is transposed beforehand and
    the results are transposed back afterwards.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type : str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.te.Tensor or List[tvm.te.Tensor]
        The values tensor, the indices tensor, or ``[values, indices]``
        depending on ``ret_type``.
    """
    assert ret_type in ["both", "values", "indices"]
    ndim = len(data.shape)
    if axis < 0:
        axis = ndim + axis
    needs_transpose = axis != ndim - 1

    if needs_transpose:
        # Permutation exchanging `axis` with the last axis. Applying it a
        # second time restores the original layout, so it is reused below.
        order = list(range(ndim))
        perm = order[:axis] + [order[-1]] + order[axis+1:-1] + [order[axis]]
        data = transpose(data, perm)

    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)

    def _thrust_sort_call(ins, outs):
        """Emit the packed call that fills the value and index buffers."""
        return tvm.tir.call_packed(
            "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend)

    results = te.extern([data.shape, data.shape],
                        [data],
                        _thrust_sort_call,
                        in_buffers=[data_buf],
                        out_buffers=[value_buf, indices_buf],
                        name="topk_gpu",
                        tag="topk_gpu")

    if k > 0:
        # Keep only the leading k entries of the (already sorted) last axis.
        begin = [0] * ndim
        end = data.shape[:-1] + [k]
        results = [strided_slice(tensor, begin, end) for tensor in results]

    if needs_transpose:
        results = [transpose(tensor, perm) for tensor in results]

    if ret_type == "values":
        return results[0]
    if ret_type == "indices":
        return results[1]
    return results
def schedule_topk(outs): def schedule_topk(outs):
"""Schedule for argsort operator. """Schedule for argsort operator.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment