Commit 072f8cc7 by Haichen Shen Committed by Leyuan Wang

[Relay/TOPI][Op] Add TopK operator (#3256)

* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix
parent 4204544b
......@@ -99,6 +99,8 @@ List of operators
topi.shape
topi.layout_transform
topi.image.resize
topi.argsort
topi.topk
List of schedules
......@@ -163,6 +165,8 @@ topi
.. autofunction:: topi.tile
.. autofunction:: topi.shape
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
topi.nn
~~~~~~~
......
......@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
:nosignatures:
tvm.relay.argsort
tvm.relay.topk
**Level 10: Temporary Operators**
......@@ -309,6 +310,7 @@ Level 5 Definitions
Level 6 Definitions
-------------------
.. autofunction:: tvm.relay.argsort
.. autofunction:: tvm.relay.topk
Level 10 Definitions
......
......@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
}
};
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
int k;
int axis;
bool is_ascend;
std::string ret_type;
DataType dtype;
TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
TVM_ATTR_FIELD(k).set_default(1)
.describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis along which to sort the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both")
.describe("The return type [both, values, indices]."
"both - return both top k data and indices."
"values - return top k data only."
"indices - return top k indices only.");
TVM_ATTR_FIELD(is_ascend).set_default(false)
.describe("Whether to sort in ascending or descending order."
"By default, sort in descending order");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Data type of the output indices.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
......@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)
def _mx_topk(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["k"] = attrs.get_int("k", 1)
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
ret_type = attrs.get_str("ret_typ", "indices")
if ret_type == "mask":
raise tvm.error.OpAttributeUnimplemented(
"Attribute ret_type=mask is not supported in topk operator")
new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.topk(inputs[0], **new_attrs)
def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
......@@ -914,6 +929,7 @@ _convert_map = {
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
......
......@@ -1082,6 +1082,20 @@ def _softplus():
return _get_relay_op('log')(add_out)
return _impl
def _topk():
def _impl(inputs, attr, params):
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
if k < 1:
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented(
'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk',
ignores=['sorted'],
extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
return _impl
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
......@@ -1271,6 +1285,7 @@ _convert_map = {
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),
......
......@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = str(attrs.dtype)
return [
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
dtype=dtype, flag=False)
]
dtype = attrs.dtype
return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
register_pattern("argsort", OpPattern.OPAQUE)
@register_schedule("topk")
def schedule_topk(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_topk(outs)
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
"""Compute definition of argsort"""
k = get_const_int(attrs.k)
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out
register_pattern("topk", OpPattern.OPAQUE)
......@@ -17,8 +17,9 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import TupleWrapper
def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
......@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
The data type of the output indices.
Returns
-------
......@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Tensor with same shape as data.
"""
return _make.argsort(data, axis, is_ascend, dtype)
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, can be one of ("both", "values", "indices").
Parameters
----------
data : relay.Expr
The input data tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis long which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : relay.Expr or List[relay.Expr]
The computed result.
"""
out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
return out
......@@ -401,7 +401,7 @@ def upsampling(data,
with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating ghe out value
method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
......
......@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Specifies how out-of-bound indices will behave [clip, wrap].
clip: clip to the range (default).
wrap: wrap around the indices.
Returns
-------
......
......@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 512;
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
......
......@@ -34,14 +34,14 @@ namespace contrib {
using namespace runtime;
template<typename DType>
bool CompareAscend(const std::pair<int32_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) {
bool CompareAscend(const std::pair<int64_t, DType>& lhs,
const std::pair<int64_t, DType>& rhs) {
return lhs.second < rhs.second;
}
template<typename DType>
bool CompareDescend(const std::pair<int32_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) {
bool CompareDescend(const std::pair<int64_t, DType>& lhs,
const std::pair<int64_t, DType>& rhs) {
return lhs.second > rhs.second;
}
......@@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
}
});
template<typename DataType, typename OutType>
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
auto data_ptr = static_cast<DataType *>(input->data);
auto out_ptr = static_cast<OutType *>(output->data);
std::vector<std::pair<int64_t, DataType> > sorter;
int axis_mul_before = 1;
int axis_mul_after = 1;
for (int i = 0; i < input->ndim; ++i) {
if (i < axis) {
axis_mul_before *= input->shape[i];
} else if (i > axis) {
axis_mul_after *= input->shape[i];
}
}
for (int i = 0 ; i < axis_mul_before; ++i) {
for (int j = 0 ; j < axis_mul_after; ++j) {
sorter.clear();
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < input->shape[axis]; ++k) {
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, data_ptr[full_idx]));
}
if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
}
for (int64_t k = 0; k < input->shape[axis]; ++k) {
out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
}
}
}
}
// Argsort implemented C library sort.
// Return indices of sorted tensor.
......@@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
DLTensor *output = args[1];
int32_t axis = args[2];
bool is_ascend = args[3];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
std::vector<std::pair<float, float>> sorter;
int64_t axis_mul_before = 1;
int64_t axis_mul_after = 1;
if (axis < 0) {
axis = input->ndim + axis;
}
// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32.";
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32.";
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype);
auto out_dtype = TVMType2String(output->dtype);
if (data_dtype == "float32") {
if (out_dtype == "int32") {
argsort<float, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<float, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<float, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<float, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
argsort<double, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<double, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<double, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<double, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
argsort<int32_t, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<int32_t, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<int32_t, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<int32_t, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
argsort<int64_t, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<int64_t, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<int64_t, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<int64_t, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});
template<typename DataType, typename IndicesType>
void topk(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
int k,
int axis,
bool is_ascend) {
DataType* data_ptr = static_cast<DataType *>(input->data);
DataType* values_ptr = (out_values == nullptr) ? nullptr :
static_cast<DataType *>(out_values->data);
IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr :
static_cast<IndicesType *>(out_indices->data);
std::vector<std::pair<int64_t, DataType> > sorter;
int axis_mul_before = 1;
int axis_mul_after = 1;
for (int i = 0; i < input->ndim; ++i) {
if (i < axis) {
axis_mul_before *= input->shape[i];
......@@ -150,26 +244,123 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
axis_mul_after *= input->shape[i];
}
}
if (k < 1) {
k = input->shape[axis];
}
int32_t current_sort_num = input->shape[axis];
for (int64_t i = 0 ; i < axis_mul_before; ++i) {
for (int64_t j = 0 ; j < axis_mul_after; ++j) {
for (int i = 0 ; i < axis_mul_before; ++i) {
for (int j = 0 ; j < axis_mul_after; ++j) {
sorter.clear();
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < current_sort_num; ++k) {
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
int64_t dst_base_idx = i * k * axis_mul_after + j;
for (int64_t kk = 0; kk < input->shape[axis]; ++kk) {
int64_t full_idx = src_base_idx + kk * axis_mul_after;
sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx]));
}
if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<float *>(output->data) + base_idx + k * axis_mul_after)
= k < static_cast<float>(sorter.size()) ? sorter[k].first : k;
int64_t cnt = k > 0 ? k : input->shape[axis];
for (int64_t kk = 0; kk < cnt; ++kk) {
if (indices_ptr != nullptr) {
indices_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<IndicesType>(sorter[kk].first);
}
if (values_ptr != nullptr) {
values_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<DataType>(sorter[kk].second);
}
}
}
}
}
// Argsort implemented C library sort.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// sort_num specify the number of elements to be sorted.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* input = args[0];
DLTensor* values_out = nullptr;
DLTensor* indices_out = nullptr;
int k = args[args.num_args - 4];
int axis = args[args.num_args - 3];
std::string ret_type = args[args.num_args - 2];
bool is_ascend = args[args.num_args - 1];
if (ret_type == "both") {
values_out = args[1];
indices_out = args[2];
} else if (ret_type == "values") {
values_out = args[1];
} else if (ret_type == "indices") {
indices_out = args[1];
} else {
LOG(FATAL) << "Unsupported ret type: " << ret_type;
}
if (axis < 0) {
axis = input->ndim + axis;
}
CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype);
auto out_dtype = (indices_out == nullptr) ? "int64" : TVMType2String(indices_out->dtype);
if (data_dtype == "float32") {
if (out_dtype == "int32") {
topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<float, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<float, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<double, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<double, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});
......
......@@ -18,9 +18,9 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nms.cc
* \brief Non-maximum suppression operators
* Copyright (c) 2019 by Contributors
* \file argsort.cc
* \brief Argsort operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
......@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types,
<< types[0];
return false;
}
CHECK_EQ(param->dtype, Float(32));
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true;
}
......@@ -74,5 +73,6 @@ input array along the given axis.
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("Argsort", ArgsortRel);
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file topk.cc
* \brief TopK operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
bool TopKRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data);
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
axis += ndim;
}
CHECK(axis >= 0 && axis < ndim);
Array<IndexExpr> out_shape;
for (int i = 0; i < ndim; ++i) {
if (i != axis || param->k < 1) {
out_shape.push_back(data->shape[i]);
} else {
out_shape.push_back(param->k);
}
}
auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") {
reporter->Assign(types[1], indices_ty);
} else {
LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
}
return true;
}
Expr MakeTopK(Expr data,
int k,
int axis,
std::string ret_type,
bool is_ascend,
DataType dtype) {
auto attrs = make_node<TopKAttrs>();
attrs->k = k;
attrs->axis = axis;
attrs->ret_type = ret_type;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("topk");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.topk")
.set_body_typed(MakeTopK);
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TopKAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
} // namespace relay
} // namespace tvm
......@@ -608,6 +608,45 @@ def test_forward_Crop():
verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
def test_forward_argsort():
def verify(shape, axis, is_ascend, dtype="float32"):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype)
mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 3, 4), axis=0, is_ascend=False)
verify((1, 4, 6), axis=1, is_ascend=True)
verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
def test_forward_topk():
def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
is_ascend=is_ascend, dtype=dtype)
mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
is_ascend=is_ascend, dtype=dtype)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
if isinstance(ref_res, list):
assert len(op_res) == len(ref_res)
for i, t in enumerate(op_res):
tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
else:
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4), k=1, axis=0, ret_type="both")
verify((3, 4), k=1, axis=-1, ret_type="indices")
verify((3, 5, 6), k=2, axis=2, ret_type="value")
verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
if __name__ == '__main__':
test_forward_mlp()
......@@ -650,3 +689,5 @@ if __name__ == '__main__':
test_forward_bilinear_resize()
test_forward_rnn_layer()
test_forward_Crop()
test_forward_argsort()
test_forward_topk()
......@@ -754,6 +754,24 @@ def test_forward_split():
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
######################################################################
# TopKV2
# ------
def _test_forward_top_k_v2(in_shape, k):
np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
tf.reset_default_graph()
in_data = tf.placeholder("float32", in_shape, name="in_data")
tf.math.top_k(in_data, k, name='TopK')
compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2():
_test_forward_top_k_v2((3,), 1)
_test_forward_top_k_v2((3,), 3)
_test_forward_top_k_v2((3, 5, 7), 3)
_test_forward_top_k_v2((3, 5, 7), 3)
#######################################################################
# Unstack
# -------
......@@ -1704,6 +1722,7 @@ if __name__ == '__main__':
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()
# Activations
test_forward_sigmoid()
......
......@@ -16,18 +16,15 @@
# under the License.
""" Support level6 operator test cases.
"""
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing
def test_argsort():
def verify_argsort(shape, axis, is_ascend):
def verify_argsort(shape, axis, is_ascend, dtype):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
zz = relay.ir_pass.infer_type(z)
z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
......@@ -39,11 +36,58 @@ def test_argsort():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
verify_argsort((2, 3, 4), axis=0, is_ascend=False)
verify_argsort((1, 4, 6), axis=1, is_ascend=True)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
for dtype in ["int32", "int64", "float32", "float64"]:
verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype)
def test_topk():
def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100)
x = relay.var("x", relay.TensorType(shape, "float32"))
out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
if isinstance(out, relay.expr.TupleWrapper):
out = out.astuple()
func = relay.Function([x], out)
np_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)
kk = k if k >= 1 else shape[axis]
if axis == 0:
np_indices = np_indices[:kk, :]
np_values = np.zeros(np_indices.shape).astype("float32")
for i in range(shape[1]):
np_values[:, i] = np_data[np_indices[:, i], i]
else:
np_indices = np_indices[:, :kk]
np_values = np.zeros(np_indices.shape).astype("float32")
for i in range(shape[0]):
np_values[i, :] = np_data[i, np_indices[i, :]]
np_indices = np_indices.astype(dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(np_data)
if ret_type == "both":
tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
elif ret_type == "values":
tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
else:
tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, False, dtype)
verify_topk(k, axis, ret_type, True, dtype)
if __name__ == "__main__":
test_argsort()
test_topk()
......@@ -21,3 +21,4 @@ from . import ssd
from .ssd import *
from .nms import *
from .rcnn import *
from .sort import *
......@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
......
......@@ -19,19 +19,48 @@
import tvm
from tvm import api
from topi.sort import argsort
from topi.math import identity
from ..sort import argsort, topk
from ..math import identity
from ..transform import strided_slice
from .. import generic
from .. import tag
def _schedule_sort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if tag.is_injective(op.tag):
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
for out in outs:
traverse(out.op)
return s
def sort_ir(data, output, axis, is_ascend):
def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
Buffer of input data.
Buffer of input data. Data will be sorted in place.
output : Buffer
Output buffer of indicies of sorted tensor with same shape as data.
......@@ -47,14 +76,12 @@ def sort_ir(data, output, axis, is_ascend):
stmt : Stmt
The result IR statement.
"""
size = 1
axis_mul_before = 1
axis_mul_after = 1
shape = data.shape
if axis < 0:
axis = len(shape) + axis
for i, value in enumerate(shape, 0):
size *= value
if i < axis:
axis_mul_before *= value
elif i > axis:
......@@ -62,52 +89,62 @@ def sort_ir(data, output, axis, is_ascend):
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data)
output = ib.buffer_ptr(output)
values_out = ib.buffer_ptr(values_out)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
nthread_tx = max_threads
nthread_bx = size // max_threads + 1
nthread_bx = shape[axis] // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "virtual_thread", nthread_bx)
tid = bx * nthread_tx + tx
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local")
is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
if indices_out is not None:
temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j:
current_sort_num = shape[axis]
base_idx = i * shape[axis] * axis_mul_after + j
with ib.if_scope(tid < shape[axis]):
output[base_idx + tid * axis_mul_after] = tid.astype("float32")
values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
if indices_out is not None:
indices_out[base_idx + tid * axis_mul_after] = \
tvm.generic.cast(tid, indices_out.dtype)
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j:
current_sort_num = shape[axis]
base_idx = i * shape[axis] * axis_mul_after + j
# OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
with ib.if_scope(tvm.all(is_ascend == 1, \
2 * tid + (k % 2) + 1 < current_sort_num, \
data[offset] > data[offset + axis_mul_after])):
temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after]
data[offset + axis_mul_after] = temp_data[0]
temp_index[0] = output[offset]
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
with ib.if_scope(tvm.all(is_ascend == 0, \
2 * tid + (k % 2) + 1 < current_sort_num, \
data[offset] < data[offset + axis_mul_after])):
temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after]
data[offset + axis_mul_after] = temp_data[0]
temp_index[0] = output[offset]
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
if is_ascend:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
values_out[offset] > values_out[offset + axis_mul_after])
else:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
values_out[offset] < values_out[offset + axis_mul_after])
with ib.if_scope(cond):
temp_data[0] = values_out[offset]
values_out[offset] = values_out[offset + axis_mul_after]
values_out[offset + axis_mul_after] = temp_data[0]
if indices_out is not None:
temp_index[0] = indices_out[offset]
indices_out[offset] = indices_out[offset + axis_mul_after]
indices_out[offset + axis_mul_after] = temp_index[0]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()
def sort_nms_ir(data, valid_count, output, axis, is_ascend):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
......@@ -197,7 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
return ib.get()
@argsort.register(["cuda", "gpu"])
def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
......@@ -206,26 +243,27 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
data: tvm.Tensor
The input array.
valid_count : tvm.Tensor
valid_count : tvm.Tensor, optional
The number of valid elements to be sorted.
axis : int
axis : int, optional
Axis long which to sort the input tensor.
is_ascend : boolean
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
flag : boolean
Whether this argsort is used in nms operator
dtype : string, optional
DType of the output indices.
Returns
-------
out : tvm.Tensor
The output of this function.
"""
sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
if valid_count is not None:
sorted_data = identity(data)
if flag:
sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
data_alignment=8)
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
......@@ -239,16 +277,15 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
name="argsort_nms_gpu",
tag="argsort_nms_gpu")
else:
out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
out = tvm.extern([data.shape],
[sorted_data],
value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
out = tvm.extern([data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(
ins[0], outs[0], axis, is_ascend),
dtype=dtype,
in_buffers=[sorted_data_buf],
out_buffers=[out_buf],
ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
out_buffers=[value_buf, indices_buf],
name="argsort_gpu",
tag="argsort_gpu")
tag="argsort_gpu")[1]
return out
@generic.schedule_argsort.register(["cuda", "gpu"])
......@@ -266,17 +303,99 @@ def schedule_argsort(outs):
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if tag.is_broadcast(op.tag):
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return _schedule_sort(outs)
return s
@topk.register(["cuda", "gpu"])
def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis long which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.Tensor or List[tvm.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
ndim = len(data.shape)
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
if ret_type == "values":
output = tvm.extern([data.shape],
[data],
lambda ins, outs: sort_ir(
ins[0], outs[0], axis, is_ascend),
out_buffers=[values_buf],
name="topk_gpu",
tag="topk_gpu")
else:
output = tvm.extern([data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(
ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
out_buffers=[values_buf, indices_buf],
name="topk_gpu",
tag="topk_gpu")
if k < 1:
if ret_type == "indices":
return output[1]
return output
beg = [0] * ndim
end = []
for i in range(ndim):
if i == axis:
end.append(k)
else:
end.append(data.shape[i])
if ret_type == "both":
values_out, indices_out = output
values_out = strided_slice(values_out, beg, end)
indices_out = strided_slice(indices_out, beg, end)
output = [values_out, indices_out]
elif ret_type == "values":
output = [strided_slice(output, beg, end)]
else: # ret_type == "indices"
indices_out = output[1]
output = [strided_slice(indices_out, beg, end)]
return output
@generic.schedule_topk.register(["cuda", "gpu"])
def schedule_topk(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _schedule_sort(outs)
......@@ -36,3 +36,20 @@ def schedule_argsort(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_topk(outs):
"""Schedule for topk operator.
Parameters
----------
outs: Array of Tensor
The indices that would sort an input array along
the given axis.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
......@@ -18,9 +18,10 @@
"""Argsort operator"""
import tvm
from tvm import api
from .util import get_const_tuple
@tvm.target.generic_func
def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index
data in sorted order.
......@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
data : tvm.Tensor
The input tensor.
valid_count : tvm.Tensor
valid_count : tvm.Tensor, optional
1-D tensor for valid number of boxes only for ssd.
axis : optional, int
axis : int, optional
Axis along which to sort the input tensor.
By default the flattened array is used.
is_ascend : optional, boolean
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : optional, string
dtype : string, optional
DType of the output indices.
flag : optional, boolean
Whether valid_count is valid.
Returns
-------
out : tvm.Tensor
......@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
# An example to use argsort
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0
is_ascend = False
flag = False
out = argsort(data, valid_count, axis, is_ascend, flag)
out = argsort(data, axis=axis, is_ascend=is_ascend)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
f = tvm.build(s, [data, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
f(tvm_data, tvm_out)
"""
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
if flag:
if valid_count is not None:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
......@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
name="argsort_cpu",
tag="argsort_cpu")
return out
@tvm.target.generic_func
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis long which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.Tensor or List[tvm.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_shape = list(get_const_tuple(data.shape))
if k >= 1:
out_shape[axis] = k
out_bufs = []
if ret_type in ["both", "values"]:
out_bufs.append(api.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8))
if ret_type in ["both", "indices"]:
out_bufs.append(api.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8))
out_shapes = [out_shape] * len(out_bufs)
out = tvm.extern(out_shapes,
[data],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_cpu",
tag="topk_cpu")
return out
......@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None):
-------
ret : tvm.Tensor
"""
if strides is None:
strides = []
return cpp.strided_slice(a, begin, end, strides)
......
......@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
......
......@@ -16,23 +16,15 @@
# under the License.
"""Test code for vision package"""
from __future__ import print_function
import math
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi import argsort
def test_argsort():
dshape = (1, 8)
valid_count_shape = (2,)
dshape = (20, 100)
data = tvm.placeholder(dshape, name="data", dtype="float32")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.argsort(-np_data)
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -41,19 +33,77 @@ def test_argsort():
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
out = topi.argsort(data, axis=-1, is_ascend=False)
s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
f = tvm.build(s, [data, out], device)
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100)
data_dtype = "float32"
data = tvm.placeholder(shape, name="data", dtype=data_dtype)
np_data = np.random.uniform(size=shape).astype(data_dtype)
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)
kk = k if k >= 1 else shape[axis]
if axis == 0:
np_indices = np_indices[:kk, :]
np_values = np.zeros(np_indices.shape).astype(data_dtype)
for i in range(shape[1]):
np_values[:, i] = np_data[np_indices[:, i], i]
else:
np_indices = np_indices[:, :kk]
np_values = np.zeros(np_indices.shape).astype(data_dtype)
for i in range(shape[0]):
np_values[i, :] = np_data[i, np_indices[i, :]]
np_indices = np_indices.astype(dtype)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
outs = outs if isinstance(outs, list) else [outs]
s = topi.generic.schedule_topk(outs)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_res = []
for t in outs:
tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx))
f = tvm.build(s, [data] + outs, device)
f(tvm_data, *tvm_res)
if ret_type == "both":
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_indices)
elif ret_type == "values":
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
else:
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
def test_topk():
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, True, dtype)
verify_topk(k, axis, ret_type, False, dtype)
if __name__ == "__main__":
test_argsort()
test_topk()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment