Commit 072f8cc7 by Haichen Shen Committed by Leyuan Wang

[Relay/TOPI][Op] Add TopK operator (#3256)

* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix
parent 4204544b
...@@ -99,6 +99,8 @@ List of operators ...@@ -99,6 +99,8 @@ List of operators
topi.shape topi.shape
topi.layout_transform topi.layout_transform
topi.image.resize topi.image.resize
topi.argsort
topi.topk
List of schedules List of schedules
...@@ -163,6 +165,8 @@ topi ...@@ -163,6 +165,8 @@ topi
.. autofunction:: topi.tile .. autofunction:: topi.tile
.. autofunction:: topi.shape .. autofunction:: topi.shape
.. autofunction:: topi.layout_transform .. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -172,6 +172,7 @@ This level enables additional math and transform operators. ...@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
:nosignatures: :nosignatures:
tvm.relay.argsort tvm.relay.argsort
tvm.relay.topk
**Level 10: Temporary Operators** **Level 10: Temporary Operators**
...@@ -309,6 +310,7 @@ Level 5 Definitions ...@@ -309,6 +310,7 @@ Level 5 Definitions
Level 6 Definitions Level 6 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.argsort .. autofunction:: tvm.relay.argsort
.. autofunction:: tvm.relay.topk
Level 10 Definitions Level 10 Definitions
......
...@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> { ...@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
} }
}; };
/*! \brief Attributes of the topk operator. */
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
  /*! \brief Number of top elements to select. */
  int k;
  /*! \brief Axis along which the input tensor is sorted. */
  int axis;
  /*! \brief Sort in ascending order when true, descending otherwise. */
  bool is_ascend;
  /*! \brief Which outputs to produce: "both", "values" or "indices". */
  std::string ret_type;
  /*! \brief Data type of the output indices. */
  DataType dtype;

  // The type key is spelled "TopKAttrs" so that it agrees with the key the
  // "topk" operator is registered with ("relay.attrs.TopKAttrs").
  TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopKAttrs") {
    TVM_ATTR_FIELD(k).set_default(1)
      .describe("Number of top elements to select");
    TVM_ATTR_FIELD(axis).set_default(-1)
      .describe("Axis along which to sort the input tensor.");
    // Adjacent literals are concatenated verbatim, so each sentence carries
    // its own trailing space.
    TVM_ATTR_FIELD(ret_type).set_default("both")
      .describe("The return type [both, values, indices]. "
                "both - return both top k data and indices. "
                "values - return top k data only. "
                "indices - return top k indices only.");
    TVM_ATTR_FIELD(is_ascend).set_default(false)
      .describe("Whether to sort in ascending or descending order. "
                "By default, sort in descending order");
    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
      .describe("Data type of the output indices.");
  }
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_ #endif // TVM_RELAY_ATTRS_ALGORITHM_H_
...@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs): ...@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs) return _op.argsort(inputs[0], **new_attrs)
def _mx_topk(inputs, attrs):
    """Convert an MXNet ``topk`` symbol into a call to the Relay ``topk`` op.

    Parameters
    ----------
    inputs : list
        Converted inputs of the symbol; exactly one is expected.
    attrs : object
        Attribute accessor of the MXNet symbol (queried through
        ``get_int`` / ``get_bool`` / ``get_str``).

    Returns
    -------
    The result of ``_op.topk`` applied to the single input.

    Raises
    ------
    tvm.error.OpAttributeUnimplemented
        If the symbol requests ``ret_typ="mask"``.
    """
    assert len(inputs) == 1
    new_attrs = {}
    new_attrs["k"] = attrs.get_int("k", 1)
    new_attrs["axis"] = attrs.get_int("axis", -1)
    # MXNet's topk sorts in descending order unless is_ascend is set, so the
    # fallback for a missing attribute must be False.
    new_attrs["is_ascend"] = attrs.get_bool("is_ascend", False)
    ret_type = attrs.get_str("ret_typ", "indices")
    if ret_type == "mask":
        raise tvm.error.OpAttributeUnimplemented(
            "Attribute ret_type=mask is not supported in topk operator")
    # MXNet spells the option "value"; the Relay op expects "values".
    new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
    return _op.topk(inputs[0], **new_attrs)
def _mx_rnn_param_concat(inputs, _): def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op # We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs] return [inputs]
...@@ -914,6 +929,7 @@ _convert_map = { ...@@ -914,6 +929,7 @@ _convert_map = {
"shape_array" : _mx_shape_array, "shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding, "Embedding" : _mx_embedding,
"argsort" : _mx_argsort, "argsort" : _mx_argsort,
"topk" : _mx_topk,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output, "LinearRegressionOutput" : _mx_linear_regression_output,
......
...@@ -1082,6 +1082,20 @@ def _softplus(): ...@@ -1082,6 +1082,20 @@ def _softplus():
return _get_relay_op('log')(add_out) return _get_relay_op('log')(add_out)
return _impl return _impl
def _topk():
    """Return the converter used for the TensorFlow ``TopKV2`` operator."""
    def _impl(inputs, attr, params):
        # The second input holds k. It is removed from the input list and its
        # entry is removed from params, then read as a plain Python int.
        k_node = inputs.pop(1)
        k = int(params.pop(k_node.name_hint).asnumpy())
        if k < 1:
            raise tvm.error.OpAttributeInvalid(
                'Attribute k must be positive in operator TopKV2')
        if attr['sorted'] is False:
            raise tvm.error.OpAttributeUnimplemented(
                'Attribute sorted=False is not supported in operator TopKV2')
        # 'sorted' is dropped; k, the descending order and the int32 index
        # type are passed as extra attributes.
        extra_attrs = {'k': k, 'is_ascend': False, 'dtype': 'int32'}
        converter = AttrCvt(op_name='topk',
                            ignores=['sorted'],
                            extras=extra_attrs)
        return converter(inputs, attr)
    return _impl
def _logical(name): def _logical(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr) return AttrCvt(op_name=name)(inputs, attr)
...@@ -1271,6 +1285,7 @@ _convert_map = { ...@@ -1271,6 +1285,7 @@ _convert_map = {
'Sum' : _sum(), 'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'), 'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(), 'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(), 'Transpose' : _transpose(),
'Unpack' : _unpack(), 'Unpack' : _unpack(),
......
...@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target): ...@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort""" """Compute definition of argsort"""
axis = get_const_int(attrs.axis) axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend)) is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = str(attrs.dtype) dtype = attrs.dtype
return [ return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
dtype=dtype, flag=False)
]
register_pattern("argsort", OpPattern.OPAQUE) register_pattern("argsort", OpPattern.OPAQUE)
@register_schedule("topk")
def schedule_topk(_, outs, target):
    """Schedule definition of topk"""
    # The generic topk schedule is looked up inside the target scope.
    with target:
        return topi.generic.schedule_topk(outs)
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
    """Compute definition of topk"""
    # Unpack the operator attributes into plain Python values.
    k = get_const_int(attrs.k)
    axis = get_const_int(attrs.axis)
    ret_type = attrs.ret_type
    is_ascend = bool(get_const_int(attrs.is_ascend))
    dtype = attrs.dtype
    out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
    # Always hand back a list, whether topi.topk produced a list or a single
    # value.
    out = out if isinstance(out, list) else [out]
    return out
register_pattern("topk", OpPattern.OPAQUE)
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
"""Classic algorithm operation""" """Classic algorithm operation"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _make from . import _make
from ..expr import TupleWrapper
def argsort(data, axis=-1, is_ascend=1, dtype="float32"): def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies """Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order. having same shape as an input array that index data in sorted order.
...@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): ...@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Whether to sort in ascending or descending order. Whether to sort in ascending or descending order.
dtype : string, optional dtype : string, optional
DType of the output indices. The data type of the output indices.
Returns Returns
------- -------
...@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): ...@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Tensor with same shape as data. Tensor with same shape as data.
""" """
return _make.argsort(data, axis, is_ascend, dtype) return _make.argsort(data, axis, is_ascend, dtype)
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
    """Get the top k elements in an input tensor along the given axis.

    ret_type specifies the return type, can be one of ("both", "values", "indices").

    Parameters
    ----------
    data : relay.Expr
        The input data tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : relay.Expr or TupleWrapper
        The computed result. When ret_type is "both" the call is wrapped in a
        TupleWrapper of size 2 (values, indices); otherwise the call
        expression itself is returned.
    """
    out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
    # "both" yields a two-field tuple; wrap it so the fields can be indexed.
    if ret_type == "both":
        return TupleWrapper(out, 2)
    return out
...@@ -401,7 +401,7 @@ def upsampling(data, ...@@ -401,7 +401,7 @@ def upsampling(data,
with data of shape (n, c, h, w) with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale) out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating ghe out value method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters Parameters
......
...@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"): ...@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used. the flattened input array is used.
mode : str, optional mode : str, optional
Specifies how out-of-bound indices will behave. Specifies how out-of-bound indices will behave [clip, wrap].
clip - clip to the range (default) clip: clip to the range (default).
wrap - wrap around the indices wrap: wrap around the indices.
Returns Returns
------- -------
......
...@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLGPU; t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda")); t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 512; t->max_num_threads = 1024;
t->thread_warp_size = 32; t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") { } else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl // For now assume rocm schedule for opencl
......
...@@ -34,14 +34,14 @@ namespace contrib { ...@@ -34,14 +34,14 @@ namespace contrib {
using namespace runtime; using namespace runtime;
template<typename DType> template<typename DType>
bool CompareAscend(const std::pair<int32_t, DType>& lhs, bool CompareAscend(const std::pair<int64_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) { const std::pair<int64_t, DType>& rhs) {
return lhs.second < rhs.second; return lhs.second < rhs.second;
} }
template<typename DType> template<typename DType>
bool CompareDescend(const std::pair<int32_t, DType>& lhs, bool CompareDescend(const std::pair<int64_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) { const std::pair<int64_t, DType>& rhs) {
return lhs.second > rhs.second; return lhs.second > rhs.second;
} }
...@@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") ...@@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
} }
}); });
/*!
 * \brief For every 1-D slice of `input` along `axis`, write the positions of
 *        its elements in sorted order into the matching slice of `output`.
 *
 * Elements are addressed as if the tensor were stored compactly in row-major
 * order; strides are not consulted. The sort is stable.
 *
 * \tparam DataType Element type of the input buffer.
 * \tparam OutType  Element type of the output buffer; each index is converted
 *                  to it with static_cast.
 * \param input     Tensor whose elements are compared.
 * \param output    Tensor receiving the sorted positions; indexed exactly like
 *                  `input`.
 * \param axis      Axis to sort along; used directly as an index into
 *                  `input->shape`, so it must already be non-negative.
 * \param is_ascend Sort in ascending order when true, descending otherwise.
 */
template<typename DataType, typename OutType>
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
  auto data_ptr = static_cast<DataType *>(input->data);
  auto out_ptr = static_cast<OutType *>(output->data);
  std::vector<std::pair<int64_t, DataType> > sorter;

  // Element counts before/after the axis are products of 64-bit extents and
  // may not fit in int, so they are accumulated (and iterated) in 64 bits.
  int64_t axis_mul_before = 1;
  int64_t axis_mul_after = 1;
  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }

  const int64_t axis_len = input->shape[axis];
  // Every slice has the same length, so one allocation serves all of them.
  if (axis_len > 0) {
    sorter.reserve(static_cast<size_t>(axis_len));
  }

  for (int64_t i = 0; i < axis_mul_before; ++i) {
    for (int64_t j = 0; j < axis_mul_after; ++j) {
      sorter.clear();
      // Offset of the first element of the current slice.
      const int64_t base_idx = i * axis_len * axis_mul_after + j;
      for (int64_t k = 0; k < axis_len; ++k) {
        const int64_t full_idx = base_idx + k * axis_mul_after;
        sorter.emplace_back(k, data_ptr[full_idx]);
      }
      if (is_ascend) {
        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
      } else {
        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
      }
      for (int64_t k = 0; k < axis_len; ++k) {
        out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
      }
    }
  }
}
// Argsort implemented C library sort. // Argsort implemented C library sort.
// Return indices of sorted tensor. // Return indices of sorted tensor.
...@@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") ...@@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
DLTensor *output = args[1]; DLTensor *output = args[1];
int32_t axis = args[2]; int32_t axis = args[2];
bool is_ascend = args[3]; bool is_ascend = args[3];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
std::vector<std::pair<float, float>> sorter;
int64_t axis_mul_before = 1;
int64_t axis_mul_after = 1;
if (axis < 0) { if (axis < 0) {
axis = input->ndim + axis; axis = input->ndim + axis;
} }
// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32.";
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32.";
CHECK_LT(axis, input->ndim) << "Axis out of boundary for " CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim; "input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype);
auto out_dtype = TVMType2String(output->dtype);
if (data_dtype == "float32") {
if (out_dtype == "int32") {
argsort<float, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<float, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<float, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<float, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
argsort<double, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<double, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<double, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<double, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
argsort<int32_t, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<int32_t, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<int32_t, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<int32_t, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
argsort<int64_t, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<int64_t, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<int64_t, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<int64_t, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});
template<typename DataType, typename IndicesType>
void topk(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
int k,
int axis,
bool is_ascend) {
DataType* data_ptr = static_cast<DataType *>(input->data);
DataType* values_ptr = (out_values == nullptr) ? nullptr :
static_cast<DataType *>(out_values->data);
IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr :
static_cast<IndicesType *>(out_indices->data);
std::vector<std::pair<int64_t, DataType> > sorter;
int axis_mul_before = 1;
int axis_mul_after = 1;
for (int i = 0; i < input->ndim; ++i) { for (int i = 0; i < input->ndim; ++i) {
if (i < axis) { if (i < axis) {
axis_mul_before *= input->shape[i]; axis_mul_before *= input->shape[i];
...@@ -150,26 +244,123 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") ...@@ -150,26 +244,123 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
axis_mul_after *= input->shape[i]; axis_mul_after *= input->shape[i];
} }
} }
if (k < 1) {
k = input->shape[axis];
}
int32_t current_sort_num = input->shape[axis]; for (int i = 0 ; i < axis_mul_before; ++i) {
for (int64_t i = 0 ; i < axis_mul_before; ++i) { for (int j = 0 ; j < axis_mul_after; ++j) {
for (int64_t j = 0 ; j < axis_mul_after; ++j) {
sorter.clear(); sorter.clear();
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < current_sort_num; ++k) { int64_t dst_base_idx = i * k * axis_mul_after + j;
int64_t full_idx = base_idx + k * axis_mul_after; for (int64_t kk = 0; kk < input->shape[axis]; ++kk) {
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); int64_t full_idx = src_base_idx + kk * axis_mul_after;
sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx]));
} }
if (is_ascend) { if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>); std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
} else { } else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>); std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
} }
for (int32_t k = 0; k < input->shape[axis]; ++k) { int64_t cnt = k > 0 ? k : input->shape[axis];
*(static_cast<float *>(output->data) + base_idx + k * axis_mul_after) for (int64_t kk = 0; kk < cnt; ++kk) {
= k < static_cast<float>(sorter.size()) ? sorter[k].first : k; if (indices_ptr != nullptr) {
indices_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<IndicesType>(sorter[kk].first);
}
if (values_ptr != nullptr) {
values_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<DataType>(sorter[kk].second);
}
}
}
}
}
// Top-k implemented with the C++ library stable sort.
// Arguments: the input tensor, then one or two output tensors (values and/or
// indices, depending on ret_type), followed by k, axis, ret_type and is_ascend.
// A negative axis counts from the last dimension of the input.
// If k < 1, every element along the sort axis is returned.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* input = args[0];
DLTensor* values_out = nullptr;
DLTensor* indices_out = nullptr;
int k = args[args.num_args - 4];
int axis = args[args.num_args - 3];
std::string ret_type = args[args.num_args - 2];
bool is_ascend = args[args.num_args - 1];
if (ret_type == "both") {
values_out = args[1];
indices_out = args[2];
} else if (ret_type == "values") {
values_out = args[1];
} else if (ret_type == "indices") {
indices_out = args[1];
} else {
LOG(FATAL) << "Unsupported ret type: " << ret_type;
} }
if (axis < 0) {
axis = input->ndim + axis;
} }
CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype);
auto out_dtype = (indices_out == nullptr) ? "int64" : TVMType2String(indices_out->dtype);
if (data_dtype == "float32") {
if (out_dtype == "int32") {
topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<float, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<float, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<double, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<double, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
} }
}); });
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file nms.cc * \file argsort.cc
* \brief Non-maximum suppression operators * \brief Argsort operators
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h> #include <tvm/relay/attrs/algorithm.h>
...@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types, ...@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types,
<< types[0]; << types[0];
return false; return false;
} }
CHECK_EQ(param->dtype, Float(32));
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype)); reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true; return true;
} }
...@@ -74,5 +73,6 @@ input array along the given axis. ...@@ -74,5 +73,6 @@ input array along the given axis.
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_support_level(6) .set_support_level(6)
.add_type_rel("Argsort", ArgsortRel); .add_type_rel("Argsort", ArgsortRel);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file topk.cc
* \brief TopK operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
/*!
 * \brief Type relation of the topk operator.
 *
 * The output shape equals the input shape, except that the extent of the
 * sort axis becomes `k` when `k >= 1`. The values output keeps the input
 * dtype; the indices output uses the `dtype` attribute.
 *
 * \param types      [data, result].
 * \param num_inputs Number of inputs (unused).
 * \param attrs      Operator attributes; must hold a TopKAttrs.
 * \param reporter   Receives the inferred result type.
 * \return true once the result type has been assigned, false while the
 *         input is not (yet) a tensor type.
 */
bool TopKRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, result]
  const TopKAttrs* param = attrs.as<TopKAttrs>();
  CHECK(param != nullptr) << "TopK: expect attributes of type TopKAttrs";
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    // The input type may not have been inferred yet; report "unresolved"
    // rather than aborting.
    return false;
  }
  int ndim = static_cast<int>(data->shape.size());
  int axis = param->axis;
  if (axis < 0) {
    axis += ndim;
  }
  CHECK(axis >= 0 && axis < ndim)
    << "TopK: axis " << param->axis
    << " is out of boundary for input ndim " << ndim;
  Array<IndexExpr> out_shape;
  for (int i = 0; i < ndim; ++i) {
    // Only the sort axis shrinks, and only when k selects a proper count;
    // k < 1 keeps the full extent.
    if (i != axis || param->k < 1) {
      out_shape.push_back(data->shape[i]);
    } else {
      out_shape.push_back(param->k);
    }
  }
  auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
  auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
  if (param->ret_type == "both") {
    reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
  } else if (param->ret_type == "values") {
    reporter->Assign(types[1], values_ty);
  } else if (param->ret_type == "indices") {
    reporter->Assign(types[1], indices_ty);
  } else {
    LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
  }
  return true;
}
/*!
 * \brief Build a call to the "topk" operator on `data`.
 *
 * The remaining arguments are copied one-to-one into a fresh TopKAttrs node
 * attached to the call.
 */
Expr MakeTopK(Expr data,
              int k,
              int axis,
              std::string ret_type,
              bool is_ascend,
              DataType dtype) {
  auto topk_attrs = make_node<TopKAttrs>();
  topk_attrs->dtype = dtype;
  topk_attrs->is_ascend = is_ascend;
  topk_attrs->ret_type = ret_type;
  topk_attrs->axis = axis;
  topk_attrs->k = k;
  // The operator handle is fetched once and reused by later calls.
  static const Op& topk_op = Op::Get("topk");
  return CallNode::make(topk_op, {data}, Attrs(topk_attrs), {});
}
// Register MakeTopK as the global function "relay.op._make.topk".
TVM_REGISTER_API("relay.op._make.topk")
.set_body_typed(MakeTopK);

// Register the "topk" operator: one input ("data"), support level 6, with
// TopKRel as its type relation.
// NOTE(review): the attrs type key below should agree with the key declared
// in TVM_DECLARE_ATTRS for TopKAttrs — confirm.
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TopKAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
} // namespace relay
} // namespace tvm
...@@ -608,6 +608,45 @@ def test_forward_Crop(): ...@@ -608,6 +608,45 @@ def test_forward_Crop():
verify((5, 32, 40, 40), (5, 32, 25, 25)) verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5)) verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
def test_forward_argsort():
    """Compare the converted MXNet argsort against MXNet's own result."""
    def verify(shape, axis, is_ascend, dtype="float32"):
        x_np = np.random.uniform(size=shape).astype("float32")
        # Reference result computed directly by MXNet.
        ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype)
        mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        # Run the converted graph on every target with both executors.
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((2, 3, 4), axis=0, is_ascend=False)
    verify((1, 4, 6), axis=1, is_ascend=True)
    verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
def test_forward_topk():
    """Compare the converted MXNet topk against MXNet's own result."""
    def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
        x_np = np.random.uniform(size=shape).astype("float32")
        # Reference result computed directly by MXNet; note that MXNet names
        # the argument ``ret_typ``.
        ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
                             is_ascend=is_ascend, dtype=dtype)
        mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
                             is_ascend=is_ascend, dtype=dtype)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        # Run the converted graph on every target with both executors.
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_np)
                # The reference is a list when several outputs are returned;
                # compare them pairwise in that case.
                if isinstance(ref_res, list):
                    assert len(op_res) == len(ref_res)
                    for i, t in enumerate(op_res):
                        tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
                else:
                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((3, 4), k=1, axis=0, ret_type="both")
    verify((3, 4), k=1, axis=-1, ret_type="indices")
    verify((3, 5, 6), k=2, axis=2, ret_type="value")
    verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
    verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
...@@ -650,3 +689,5 @@ if __name__ == '__main__': ...@@ -650,3 +689,5 @@ if __name__ == '__main__':
test_forward_bilinear_resize() test_forward_bilinear_resize()
test_forward_rnn_layer() test_forward_rnn_layer()
test_forward_Crop() test_forward_Crop()
test_forward_argsort()
test_forward_topk()
...@@ -754,6 +754,24 @@ def test_forward_split(): ...@@ -754,6 +754,24 @@ def test_forward_split():
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32') _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
######################################################################
# TopKV2
# ------
def _test_forward_top_k_v2(in_shape, k):
    """Build a TF graph with a single top_k node and compare it with TVM.

    Only the first output of the node ('TopK:0') is compared.
    """
    np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
    # Start from an empty default graph so node names do not collide.
    tf.reset_default_graph()
    in_data = tf.placeholder("float32", in_shape, name="in_data")
    tf.math.top_k(in_data, k, name='TopK')
    compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2():
    """Run the TopKV2 conversion check on 1-D and 3-D inputs."""
    _test_forward_top_k_v2((3,), 1)
    _test_forward_top_k_v2((3,), 3)
    _test_forward_top_k_v2((3, 5, 7), 3)
    # The last case used to repeat the previous one verbatim; use k equal to
    # the full length of the last axis so it covers something new.
    _test_forward_top_k_v2((3, 5, 7), 7)
####################################################################### #######################################################################
# Unstack # Unstack
# ------- # -------
...@@ -1704,6 +1722,7 @@ if __name__ == '__main__': ...@@ -1704,6 +1722,7 @@ if __name__ == '__main__':
test_forward_split() test_forward_split()
test_forward_unstack() test_forward_unstack()
test_forward_tile() test_forward_tile()
test_forward_top_k_v2()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
......
...@@ -16,18 +16,15 @@ ...@@ -16,18 +16,15 @@
# under the License. # under the License.
""" Support level6 operator test cases. """ Support level6 operator test cases.
""" """
import math
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
import topi.testing
def test_argsort(): def test_argsort():
def verify_argsort(shape, axis, is_ascend): def verify_argsort(shape, axis, is_ascend, dtype):
x = relay.var("x", relay.TensorType(shape, "float32")) x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.argsort(x, axis=axis, is_ascend=is_ascend) z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
zz = relay.ir_pass.infer_type(z)
func = relay.Function([x], z) func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32") x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend: if is_ascend:
...@@ -39,11 +36,58 @@ def test_argsort(): ...@@ -39,11 +36,58 @@ def test_argsort():
for kind in ["graph", "debug"]: for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target) intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data) op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
verify_argsort((2, 3, 4), axis=0, is_ascend=False) for dtype in ["int32", "int64", "float32", "float64"]:
verify_argsort((1, 4, 6), axis=1, is_ascend=True) verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype)
def test_topk():
    """Check relay.topk against a NumPy reference on a fixed 2-D input."""
    def verify_topk(k, axis, ret_type, is_ascend, dtype):
        shape = (20, 100)
        x = relay.var("x", relay.TensorType(shape, "float32"))
        out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
        # A TupleWrapper has to be turned into a tuple expression before it
        # can be used as a function body.
        if isinstance(out, relay.expr.TupleWrapper):
            out = out.astuple()
        func = relay.Function([x], out)
        np_data = np.random.uniform(size=shape).astype("float32")
        # Descending order is obtained by sorting the negated data.
        if is_ascend:
            np_indices = np.argsort(np_data, axis=axis)
        else:
            np_indices = np.argsort(-np_data, axis=axis)
        # k < 1 means "keep every element along the axis".
        kk = k if k >= 1 else shape[axis]
        # Truncate the sorted indices to kk entries along the sort axis and
        # gather the matching values.
        if axis == 0:
            np_indices = np_indices[:kk, :]
            np_values = np.zeros(np_indices.shape).astype("float32")
            for i in range(shape[1]):
                np_values[:, i] = np_data[np_indices[:, i], i]
        else:
            np_indices = np_indices[:, :kk]
            np_values = np.zeros(np_indices.shape).astype("float32")
            for i in range(shape[0]):
                np_values[i, :] = np_data[i, np_indices[i, :]]
        np_indices = np_indices.astype(dtype)

        # Run on every target with both executors and compare the outputs
        # selected by ret_type.
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(np_data)
                if ret_type == "both":
                    tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
                    tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
                elif ret_type == "values":
                    tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
                else:
                    tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
    for k in [0, 1, 5]:
        for axis in [0, -1, 1]:
            for ret_type in ["both", "values", "indices"]:
                for dtype in ["int64", "float32"]:
                    verify_topk(k, axis, ret_type, False, dtype)
                    verify_topk(k, axis, ret_type, True, dtype)
if __name__ == "__main__": if __name__ == "__main__":
test_argsort() test_argsort()
test_topk()
...@@ -21,3 +21,4 @@ from . import ssd ...@@ -21,3 +21,4 @@ from . import ssd
from .ssd import * from .ssd import *
from .nms import * from .nms import *
from .rcnn import * from .rcnn import *
from .sort import *
...@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, ...@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
score_axis = score_index score_axis = score_index
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype, sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8) "sort_tensor_buf", data_alignment=8)
......
...@@ -19,19 +19,48 @@ ...@@ -19,19 +19,48 @@
import tvm import tvm
from tvm import api from tvm import api
from topi.sort import argsort from ..sort import argsort, topk
from topi.math import identity from ..math import identity
from ..transform import strided_slice
from .. import generic from .. import generic
from .. import tag from .. import tag
def _schedule_sort(outs):
    """Schedule shared by the sort-based operators in this file.

    Both schedule_argsort and schedule_topk delegate to this helper.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of the sort-based operator
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sched = tvm.create_schedule([tensor.op for tensor in outs])
    visited = []
    from .injective import _schedule_injective

    def _visit(op):
        # Ops tagged as injective are handed to _schedule_injective; any other
        # op is only walked through to reach its producers.
        if tag.is_injective(op.tag):
            _schedule_injective(op, sched)
        for producer in op.input_tensors:
            # Descend only into producers that have inputs of their own and
            # have not been recorded yet.
            if producer.op.input_tensors and producer.op not in visited:
                _visit(producer.op)
        visited.append(op)

    for out in outs:
        _visit(out.op)
    return sched
def sort_ir(data, output, axis, is_ascend): def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters Parameters
---------- ----------
data: Buffer data: Buffer
Buffer of input data. Buffer of input data. Data will be sorted in place.
output : Buffer output : Buffer
Output buffer of indicies of sorted tensor with same shape as data. Output buffer of indicies of sorted tensor with same shape as data.
...@@ -47,14 +76,12 @@ def sort_ir(data, output, axis, is_ascend): ...@@ -47,14 +76,12 @@ def sort_ir(data, output, axis, is_ascend):
stmt : Stmt stmt : Stmt
The result IR statement. The result IR statement.
""" """
size = 1
axis_mul_before = 1 axis_mul_before = 1
axis_mul_after = 1 axis_mul_after = 1
shape = data.shape shape = data.shape
if axis < 0: if axis < 0:
axis = len(shape) + axis axis = len(shape) + axis
for i, value in enumerate(shape, 0): for i, value in enumerate(shape, 0):
size *= value
if i < axis: if i < axis:
axis_mul_before *= value axis_mul_before *= value
elif i > axis: elif i > axis:
...@@ -62,52 +89,62 @@ def sort_ir(data, output, axis, is_ascend): ...@@ -62,52 +89,62 @@ def sort_ir(data, output, axis, is_ascend):
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data) data = ib.buffer_ptr(data)
output = ib.buffer_ptr(output) values_out = ib.buffer_ptr(values_out)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = size // max_threads + 1 nthread_bx = shape[axis] // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread") bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "virtual_thread", nthread_bx) ib.scope_attr(bx, "virtual_thread", nthread_bx)
tid = bx * nthread_tx + tx tid = bx * nthread_tx + tx
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local") if indices_out is not None:
is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j: with ib.for_range(0, axis_mul_after) as j:
current_sort_num = shape[axis]
base_idx = i * shape[axis] * axis_mul_after + j base_idx = i * shape[axis] * axis_mul_after + j
with ib.if_scope(tid < shape[axis]): with ib.if_scope(tid < shape[axis]):
output[base_idx + tid * axis_mul_after] = tid.astype("float32") values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
if indices_out is not None:
indices_out[base_idx + tid * axis_mul_after] = \
tvm.generic.cast(tid, indices_out.dtype)
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j:
current_sort_num = shape[axis]
base_idx = i * shape[axis] * axis_mul_after + j
# OddEvenTransposeSort # OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k: with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2): with ib.if_scope(tid < (current_sort_num + 1) // 2):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
with ib.if_scope(tvm.all(is_ascend == 1, \ if is_ascend:
2 * tid + (k % 2) + 1 < current_sort_num, \ cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
data[offset] > data[offset + axis_mul_after])): values_out[offset] > values_out[offset + axis_mul_after])
temp_data[0] = data[offset] else:
data[offset] = data[offset + axis_mul_after] cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
data[offset + axis_mul_after] = temp_data[0] values_out[offset] < values_out[offset + axis_mul_after])
temp_index[0] = output[offset] with ib.if_scope(cond):
output[offset] = output[offset + axis_mul_after] temp_data[0] = values_out[offset]
output[offset + axis_mul_after] = temp_index[0] values_out[offset] = values_out[offset + axis_mul_after]
with ib.if_scope(tvm.all(is_ascend == 0, \ values_out[offset + axis_mul_after] = temp_data[0]
2 * tid + (k % 2) + 1 < current_sort_num, \ if indices_out is not None:
data[offset] < data[offset + axis_mul_after])): temp_index[0] = indices_out[offset]
temp_data[0] = data[offset] indices_out[offset] = indices_out[offset + axis_mul_after]
data[offset] = data[offset + axis_mul_after] indices_out[offset + axis_mul_after] = temp_index[0]
data[offset + axis_mul_after] = temp_data[0]
temp_index[0] = output[offset]
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync', ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']), tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0)) tvm.expr.Call.Intrinsic, None, 0))
return ib.get() return ib.get()
def sort_nms_ir(data, valid_count, output, axis, is_ascend): def sort_nms_ir(data, valid_count, output, axis, is_ascend):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
...@@ -197,7 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -197,7 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
return ib.get() return ib.get()
@argsort.register(["cuda", "gpu"]) @argsort.register(["cuda", "gpu"])
def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies """Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order. having same shape as an input array that index data in sorted order.
...@@ -206,26 +243,27 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 ...@@ -206,26 +243,27 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
data: tvm.Tensor data: tvm.Tensor
The input array. The input array.
valid_count : tvm.Tensor valid_count : tvm.Tensor, optional
The number of valid elements to be sorted. The number of valid elements to be sorted.
axis : int axis : int, optional
Axis long which to sort the input tensor. Axis long which to sort the input tensor.
is_ascend : boolean is_ascend : boolean, optional
Whether to sort in ascending or descending order. Whether to sort in ascending or descending order.
flag : boolean dtype : string, optional
Whether this argsort is used in nms operator DType of the output indices.
Returns Returns
------- -------
out : tvm.Tensor out : tvm.Tensor
The output of this function. The output of this function.
""" """
sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) if valid_count is not None:
sorted_data = identity(data) sorted_data = identity(data)
if flag: sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
data_alignment=8)
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
...@@ -239,16 +277,15 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 ...@@ -239,16 +277,15 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
name="argsort_nms_gpu", name="argsort_nms_gpu",
tag="argsort_nms_gpu") tag="argsort_nms_gpu")
else: else:
out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
out = tvm.extern([data.shape], indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
[sorted_data], out = tvm.extern([data.shape, data.shape],
[data],
lambda ins, outs: sort_ir( lambda ins, outs: sort_ir(
ins[0], outs[0], axis, is_ascend), ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
dtype=dtype, out_buffers=[value_buf, indices_buf],
in_buffers=[sorted_data_buf],
out_buffers=[out_buf],
name="argsort_gpu", name="argsort_gpu",
tag="argsort_gpu") tag="argsort_gpu")[1]
return out return out
@generic.schedule_argsort.register(["cuda", "gpu"]) @generic.schedule_argsort.register(["cuda", "gpu"])
...@@ -266,17 +303,99 @@ def schedule_argsort(outs): ...@@ -266,17 +303,99 @@ def schedule_argsort(outs):
s: Schedule s: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs return _schedule_sort(outs)
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if tag.is_broadcast(op.tag):
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s @topk.register(["cuda", "gpu"])
def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    The full input is sorted along `axis` by `sort_ir`; when `k >= 1` the
    sorted result is then cut down to the first `k` entries of that axis
    with `strided_slice`.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result. For k < 1 the output(s) of the extern op are
        returned as-is; otherwise a list of sliced tensors is returned.
    """
    assert ret_type in ["both", "values", "indices"]
    ndim = len(data.shape)
    if axis < 0:
        axis = axis + ndim
    assert 0 <= axis < ndim

    values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
    indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
    if ret_type == "values":
        # Values only: the extern op has a single output and tracks no indices.
        output = tvm.extern([data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend),
                            out_buffers=[values_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    else:
        # "both" and "indices" both need the index output (outs[1]).
        output = tvm.extern([data.shape, data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
                            out_buffers=[values_buf, indices_buf],
                            name="topk_gpu",
                            tag="topk_gpu")

    if k < 1:
        # Every element is requested, so no slicing is needed.
        return output[1] if ret_type == "indices" else output

    # Keep the first k entries along `axis` and everything on the other axes.
    beg = [0] * ndim
    end = [k if dim == axis else data.shape[dim] for dim in range(ndim)]
    if ret_type == "both":
        sorted_values, sorted_indices = output
        top_values = strided_slice(sorted_values, beg, end)
        top_indices = strided_slice(sorted_indices, beg, end)
        return [top_values, top_indices]
    if ret_type == "values":
        return [strided_slice(output, beg, end)]
    # ret_type == "indices"
    return [strided_slice(output[1], beg, end)]
@generic.schedule_topk.register(["cuda", "gpu"])
def schedule_topk(outs):
    """Schedule for topk operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of topk
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    # topk uses the same schedule builder as schedule_argsort.
    return _schedule_sort(outs)
...@@ -36,3 +36,20 @@ def schedule_argsort(outs): ...@@ -36,3 +36,20 @@ def schedule_argsort(outs):
The computation schedule for the op. The computation schedule for the op.
""" """
return _default_schedule(outs, False) return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_topk(outs):
    """Schedule for topk operator.

    Parameters
    ----------
    outs: Array of Tensor
        The output tensors of the topk computation to be scheduled.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
...@@ -18,9 +18,10 @@ ...@@ -18,9 +18,10 @@
"""Argsort operator""" """Argsort operator"""
import tvm import tvm
from tvm import api from tvm import api
from .util import get_const_tuple
@tvm.target.generic_func @tvm.target.generic_func
def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array """Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index of indices having the same shape as an input array that index
data in sorted order. data in sorted order.
...@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): ...@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
data : tvm.Tensor data : tvm.Tensor
The input tensor. The input tensor.
valid_count : tvm.Tensor valid_count : tvm.Tensor, optional
1-D tensor for valid number of boxes only for ssd. 1-D tensor for valid number of boxes only for ssd.
axis : optional, int axis : int, optional
Axis along which to sort the input tensor. Axis along which to sort the input tensor.
By default the flattened array is used. By default the flattened array is used.
is_ascend : optional, boolean is_ascend : boolean, optional
Whether to sort in ascending or descending order. Whether to sort in ascending or descending order.
dtype : optional, string dtype : string, optional
DType of the output indices. DType of the output indices.
flag : optional, boolean
Whether valid_count is valid.
Returns Returns
------- -------
out : tvm.Tensor out : tvm.Tensor
...@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): ...@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
# An example to use argsort # An example to use argsort
dshape = (1, 5, 6) dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0 axis = 0
is_ascend = False is_ascend = False
flag = False out = argsort(data, axis=axis, is_ascend=is_ascend)
out = argsort(data, valid_count, axis, is_ascend, flag)
np_data = np.random.uniform(dshape) np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out) s = topi.generic.schedule_argsort(out)
f = tvm.build(s, [data, valid_count, out], "llvm") f = tvm.build(s, [data, out], "llvm")
ctx = tvm.cpu() ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx) tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out) f(tvm_data, tvm_out)
""" """
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
if flag: if valid_count is not None:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
...@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): ...@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
name="argsort_cpu", name="argsort_cpu",
tag="argsort_cpu") tag="argsort_cpu")
return out return out
@tvm.target.generic_func
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    This generic implementation wraps the packed function
    "tvm.contrib.sort.topk" in an extern op.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result.
    """
    assert ret_type in ["both", "values", "indices"]
    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    # The outputs have the input shape, except that `axis` shrinks to k
    # when a positive k is requested.
    out_shape = list(get_const_tuple(data.shape))
    if k >= 1:
        out_shape[axis] = k

    # Values come first and indices second, matching the order in which the
    # output buffers are forwarded to the packed function.
    out_bufs = []
    if ret_type in ("both", "values"):
        out_bufs.append(api.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8))
    if ret_type in ("both", "indices"):
        out_bufs.append(api.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8))
    out_shapes = [out_shape for _ in out_bufs]

    def _call_topk(ins, outs):
        # Argument order: input, every output buffer, then the attributes.
        packed_args = [ins[0]] + list(outs) + [k, axis, ret_type, is_ascend]
        return tvm.call_packed("tvm.contrib.sort.topk", *packed_args)

    out = tvm.extern(out_shapes,
                     [data],
                     _call_topk,
                     in_buffers=[data_buf],
                     out_buffers=out_bufs,
                     name="topk_cpu",
                     tag="topk_cpu")
    return out
...@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None): ...@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None):
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
if strides is None:
strides = []
return cpp.strided_slice(a, begin, end, strides) return cpp.strided_slice(a, begin, end, strides)
......
...@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, ...@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_axis = score_index score_axis = score_index
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count, out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"), tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"), tvm.const(iou_threshold, dtype="float32"),
......
...@@ -16,23 +16,15 @@ ...@@ -16,23 +16,15 @@
# under the License. # under the License.
"""Test code for vision package""" """Test code for vision package"""
from __future__ import print_function from __future__ import print_function
import math
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi import argsort
def test_argsort(): def test_argsort():
dshape = (1, 8) dshape = (20, 100)
valid_count_shape = (2,)
data = tvm.placeholder(dshape, name="data", dtype="float32") data = tvm.placeholder(dshape, name="data", dtype="float32")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.argsort(-np_data) np_result = np.argsort(-np_data)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -41,19 +33,77 @@ def test_argsort(): ...@@ -41,19 +33,77 @@ def test_argsort():
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False) out = topi.argsort(data, axis=-1, is_ascend=False)
s = topi.generic.schedule_argsort(out) s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx) tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
f = tvm.build(s, [data, valid_count, out], device) f = tvm.build(s, [data, out], device)
f(tvm_data, tvm_valid_count, tvm_out) f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
check_device(device) check_device(device)
def verify_topk(k, axis, ret_type, is_ascend, dtype):
    """Compare topi.topk with a NumPy reference on a fixed (20, 100) input.

    `k < 1` selects the whole axis; `dtype` is the dtype of the returned
    indices; `ret_type` picks which outputs are checked. Devices whose
    context does not exist are skipped.
    """
    shape = (20, 100)
    data_dtype = "float32"
    data = tvm.placeholder(shape, name="data", dtype=data_dtype)

    np_data = np.random.uniform(size=shape).astype(data_dtype)
    # Descending order is obtained by sorting the negated data.
    sort_key = np_data if is_ascend else -np_data
    np_indices = np.argsort(sort_key, axis=axis)
    kk = k if k >= 1 else shape[axis]
    if axis == 0:
        np_indices = np_indices[:kk, :]
        # Gather np_data[idx[r, c], c] for every column c.
        np_values = np_data[np_indices, np.arange(shape[1])]
    else:
        np_indices = np_indices[:, :kk]
        # Gather np_data[r, idx[r, c]] for every row r.
        np_values = np_data[np.arange(shape[0])[:, None], np_indices]
    np_indices = np_indices.astype(dtype)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
            # topk may return a single tensor; normalise to a list.
            if not isinstance(outs, list):
                outs = [outs]
            s = topi.generic.schedule_topk(outs)
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_res = [tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx) for t in outs]
        f = tvm.build(s, [data] + outs, device)
        f(tvm_data, *tvm_res)
        if ret_type == "both":
            expected = [np_values, np_indices]
        elif ret_type == "values":
            expected = [np_values]
        else:
            expected = [np_indices]
        for position, reference in enumerate(expected):
            tvm.testing.assert_allclose(tvm_res[position].asnumpy(), reference)

    for device in ['llvm', 'cuda', 'opencl']:
        check_device(device)
def test_topk():
    """Run verify_topk over every k / axis / ret_type / dtype combination,
    first ascending and then descending."""
    configs = [(k, axis, ret_type, dtype)
               for k in [0, 1, 5]
               for axis in [0, -1, 1]
               for ret_type in ["both", "values", "indices"]
               for dtype in ["int64", "float32"]]
    for k, axis, ret_type, dtype in configs:
        for is_ascend in (True, False):
            verify_topk(k, axis, ret_type, is_ascend, dtype)
if __name__ == "__main__": if __name__ == "__main__":
test_argsort() test_argsort()
test_topk()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment