Unverified Commit 25d35421 by Leyuan Wang Committed by GitHub

Add thrust support for nms (#5116)

* add argsort_nms_thrust

* consider valid count in thrust nms sort

* make thrust optional

* typo

* typo

* fix pylint

* address some of the comments

* address more comments

* fix lint

* address more comments

* address more comments
parent 2adcb738
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <functional>
namespace tvm { namespace tvm {
namespace contrib { namespace contrib {
...@@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType> ...@@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input, void thrust_sort(DLTensor* input,
DLTensor* out_values, DLTensor* out_values,
DLTensor* out_indices, DLTensor* out_indices,
bool is_ascend) { bool is_ascend,
const std::function<int(int)> &get_sort_len) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data)); thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data)); thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data)); thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
...@@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input, ...@@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input,
thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr); thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
for (int i = 0 ; i < n_iter; ++i) { for (int i = 0 ; i < n_iter; ++i) {
n_values = get_sort_len(i);
thrust::sequence(indices_ptr, indices_ptr + n_values); thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) { if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr); thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
...@@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input, ...@@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input,
} }
} }
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") void thrust_sort_common(DLTensor* input,
.set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* values_out,
CHECK_GE(args.num_args, 4); DLTensor* indices_out,
DLTensor* input = args[0]; bool is_ascend,
DLTensor* values_out = args[1]; const std::function<int(int)> &get_sort_len,
DLTensor* indices_out = args[2]; std::string data_dtype,
bool is_ascend = args[3]; std::string out_dtype) {
auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);
if (data_dtype == "float32") { if (data_dtype == "float32") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend); thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") { } else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend); thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") { } else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend); thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") { } else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend); thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else { } else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype; LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
} }
} else if (data_dtype == "float64") { } else if (data_dtype == "float64") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend); thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") { } else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend); thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") { } else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend); thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") { } else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend); thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else { } else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype; LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
} }
} else if (data_dtype == "int32") { } else if (data_dtype == "int32") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend); thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") { } else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend); thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") { } else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend); thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") { } else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend); thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else { } else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype; LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
} }
} else if (data_dtype == "int64") { } else if (data_dtype == "int64") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend); thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") { } else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend); thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") { } else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend); thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") { } else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend); thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else { } else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype; LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
} }
} else { } else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype; LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
} }
}
/*!
 * \brief Packed function "tvm.contrib.thrust.sort_nms".
 *
 * Positional arguments: input tensor, valid_count tensor, values output
 * tensor, indices output tensor, is_ascend flag. The sort length used for
 * iteration i is valid_count[i]; everything else is forwarded to
 * thrust_sort_common together with the input / indices dtype strings.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 5);
  DLTensor* input = args[0];
  DLTensor* valid_count = args[1];
  DLTensor* values_out = args[2];
  DLTensor* indices_out = args[3];
  bool is_ascend = args[4];

  // valid_count->data is reinterpreted as int* below. Reject any other
  // element type up front instead of silently reading wrong sort lengths.
  auto valid_count_dtype = DLDataType2String(valid_count->dtype);
  CHECK(valid_count_dtype == "int32")
      << "Unsupported valid_count dtype: " << valid_count_dtype;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(indices_out->dtype);

  thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
  // Per-iteration length callback: element i of valid_count, read through the
  // thrust::device_ptr wrapper. The reference capture is only used during the
  // thrust_sort_common call below, while valid_count_ptr is still in scope.
  auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
                     data_dtype, out_dtype);
});
/*!
 * \brief Packed function "tvm.contrib.thrust.sort".
 *
 * Positional arguments: input tensor, values output tensor, indices output
 * tensor, is_ascend flag. Every iteration sorts the same number of elements:
 * the extent of the input's last dimension.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 4);
  DLTensor* src = args[0];
  DLTensor* sorted_values = args[1];
  DLTensor* sorted_indices = args[2];
  bool ascending = args[3];

  // Extent of the last axis, stored as int exactly like the length callback
  // signature (int -> int) expects.
  int last_dim_extent = src->shape[src->ndim - 1];
  // The iteration index is irrelevant here: the length is constant.
  auto constant_len = [last_dim_extent](int /*iter*/) { return last_dim_extent; };

  auto src_dtype = DLDataType2String(src->dtype);
  auto idx_dtype = DLDataType2String(sorted_indices->dtype);
  thrust_sort_common(src, sorted_values, sorted_indices, ascending, constant_len,
                     src_dtype, idx_dtype);
});
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
...@@ -22,7 +22,7 @@ import tvm ...@@ -22,7 +22,7 @@ import tvm
from tvm import te from tvm import te
from tvm.tir import if_then_else from tvm.tir import if_then_else
from .sort import argsort from .sort import argsort, argsort_thrust
from .. import tag from .. import tag
...@@ -668,6 +668,10 @@ def non_max_suppression(data, valid_count, max_output_size=-1, ...@@ -668,6 +668,10 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = te.compute( score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
sort_tensor = argsort_thrust(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
else:
sort_tensor = argsort( sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False) score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
......
...@@ -24,6 +24,10 @@ from ..math import identity ...@@ -24,6 +24,10 @@ from ..math import identity
from ..transform import strided_slice, transpose from ..transform import strided_slice, transpose
from .. import tag from .. import tag
def swap(arr, axis):
    """Return a new list equal to ``arr`` with ``arr[axis]`` and ``arr[-1]`` exchanged.

    Parameters
    ----------
    arr : list
        Input sequence; it is not modified.

    axis : int
        Index of the element to exchange with the last one. Negative indices
        follow normal Python indexing.

    Returns
    -------
    out : list
        Copy of ``arr`` with the two elements exchanged. When ``axis`` already
        refers to the last element the copy is returned unchanged.
    """
    out = list(arr)
    # Exchange on a copy instead of concatenating slices: the slice form
    # produced a list of the wrong length when axis addressed the last element.
    out[axis], out[-1] = out[-1], out[axis]
    return out
def _schedule_sort(outs): def _schedule_sort(outs):
"""Schedule for argsort operator. """Schedule for argsort operator.
...@@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
return ib.get() return ib.get()
def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
    """Argsort along an axis via the ``tvm.contrib.thrust.sort_nms`` packed function.

    Returns an array of indices, having the same shape as the input array,
    that index data in sorted order.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input array.

    valid_count : tvm.te.Tensor
        The number of valid elements to be sorted.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices. NOTE: this argument is not read by the
        function body; the indices buffer is always declared as "int32".

    Returns
    -------
    out : tvm.te.Tensor
        The indices output of the extern op (second output), transposed back
        when the sort axis was not the last one.
    """
    rank = len(data.shape)
    if axis < 0:
        axis += rank
    sort_on_last_axis = axis == rank - 1

    perm = None
    if not sort_on_last_axis:
        # The packed function is fed data with the sort axis in last position.
        perm = swap(list(range(rank)), axis)
        data = transpose(data, perm)

    shape = data.shape
    value_dtype = data.dtype

    input_buffers = [
        tvm.tir.decl_buffer(shape, value_dtype, "data_buf", data_alignment=8),
        tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
                            "valid_count_buf", data_alignment=4),
    ]
    output_buffers = [
        tvm.tir.decl_buffer(shape, value_dtype, "value_buf", data_alignment=8),
        tvm.tir.decl_buffer(shape, "int32", "indices_buf", data_alignment=8),
    ]

    def _emit_sort_call(ins, outs):
        # Argument order: data, valid_count, values out, indices out, is_ascend.
        return tvm.tir.call_packed(
            "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend)

    results = te.extern([shape, shape],
                        [data, valid_count],
                        _emit_sort_call,
                        in_buffers=input_buffers,
                        out_buffers=output_buffers,
                        dtype=[value_dtype, "int32"],
                        name="nms_argsort_gpu",
                        tag="nms_argsort_gpu")

    if not sort_on_last_axis:
        # The same permutation restores the original axis order.
        results = [transpose(res, perm) for res in results]

    return results[1]
def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies """Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order. having same shape as an input array that index data in sorted order.
...@@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32" ...@@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
The output of this function. The output of this function.
""" """
if valid_count is not None: if valid_count is not None:
# TODO: implement argsort_nms with Thrust out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
out = argsort(data, valid_count, axis, is_ascend, dtype)
else: else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype) out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out return out
...@@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int ...@@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
ndim = len(data.shape) ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis axis = ndim + axis if axis < 0 else axis
def swap(arr):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
if axis != ndim - 1: if axis != ndim - 1:
# Prepare for sorting along axis -1. # Prepare for sorting along axis -1.
axes = swap(list(range(ndim))) axes = swap(list(range(ndim)), axis)
data = transpose(data, axes) data = transpose(data, axes)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
...@@ -483,7 +540,7 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int ...@@ -483,7 +540,7 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
out = [strided_slice(o, beg, end) for o in out] out = [strided_slice(o, beg, end) for o in out]
if axis != ndim - 1: if axis != ndim - 1:
axes = swap(list(range(ndim))) axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out] out = [transpose(o, axes) for o in out]
if ret_type == "values": if ret_type == "values":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment