Commit 072f8cc7 by Haichen Shen Committed by Leyuan Wang

[Relay/TOPI][Op] Add TopK operator (#3256)

* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix
parent 4204544b
......@@ -99,6 +99,8 @@ List of operators
topi.shape
topi.layout_transform
topi.image.resize
topi.argsort
topi.topk
List of schedules
......@@ -163,6 +165,8 @@ topi
.. autofunction:: topi.tile
.. autofunction:: topi.shape
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
topi.nn
~~~~~~~
......
......@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
:nosignatures:
tvm.relay.argsort
tvm.relay.topk
**Level 10: Temporary Operators**
......@@ -309,6 +310,7 @@ Level 5 Definitions
Level 6 Definitions
-------------------
.. autofunction:: tvm.relay.argsort
.. autofunction:: tvm.relay.topk
Level 10 Definitions
......
......@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
}
};
/*! \brief Attributes used by the topk operator. */
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
  /*! \brief Number of top elements to select. */
  int k;
  /*! \brief Axis along which the input tensor is sorted. */
  int axis;
  /*! \brief Sort in ascending order when true, descending otherwise. */
  bool is_ascend;
  /*! \brief Which outputs to produce: "both", "values" or "indices". */
  std::string ret_type;
  /*! \brief Data type of the output indices. */
  DataType dtype;

  // The type key follows the struct name so that it agrees with the key the
  // "topk" operator registers through set_attrs_type_key().
  TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopKAttrs") {
    TVM_ATTR_FIELD(k).set_default(1)
      .describe("Number of top elements to select");
    TVM_ATTR_FIELD(axis).set_default(-1)
      .describe("Axis along which to sort the input tensor.");
    // Adjacent literals are concatenated verbatim, so each sentence carries
    // its own trailing space.
    TVM_ATTR_FIELD(ret_type).set_default("both")
      .describe("The return type [both, values, indices]. "
                "both - return both top k data and indices. "
                "values - return top k data only. "
                "indices - return top k indices only.");
    TVM_ATTR_FIELD(is_ascend).set_default(false)
      .describe("Whether to sort in ascending or descending order. "
                "By default, sort in descending order");
    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
      .describe("Data type of the output indices.");
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
......@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)
def _mx_topk(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["k"] = attrs.get_int("k", 1)
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
ret_type = attrs.get_str("ret_typ", "indices")
if ret_type == "mask":
raise tvm.error.OpAttributeUnimplemented(
"Attribute ret_type=mask is not supported in topk operator")
new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.topk(inputs[0], **new_attrs)
def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
......@@ -914,6 +929,7 @@ _convert_map = {
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
......
......@@ -1082,6 +1082,20 @@ def _softplus():
return _get_relay_op('log')(add_out)
return _impl
def _topk():
    """Build the converter for the TensorFlow ``TopKV2`` operator."""
    def _impl(inputs, attr, params):
        # The second input carries k; it is removed from both the input list
        # and the parameter dict, and read back as a Python int.
        k_input = inputs.pop(1)
        k_param = params.pop(k_input.name_hint)
        k = int(k_param.asnumpy())
        if k < 1:
            raise tvm.error.OpAttributeInvalid(
                'Attribute k must be positive in operator TopKV2')
        if attr['sorted'] is False:
            raise tvm.error.OpAttributeUnimplemented(
                'Attribute sorted=False is not supported in operator TopKV2')
        extra_attrs = {'k': k, 'is_ascend': False, 'dtype': 'int32'}
        convert = AttrCvt(op_name='topk',
                          ignores=['sorted'],
                          extras=extra_attrs)
        return convert(inputs, attr)
    return _impl
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
......@@ -1271,6 +1285,7 @@ _convert_map = {
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),
......
......@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = str(attrs.dtype)
return [
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
dtype=dtype, flag=False)
]
dtype = attrs.dtype
return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
register_pattern("argsort", OpPattern.OPAQUE)
@register_schedule("topk")
def schedule_topk(_, outs, target):
    """Schedule definition of topk"""
    with target:
        schedule = topi.generic.schedule_topk(outs)
    return schedule
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
    """Compute definition of topk"""
    num_top = get_const_int(attrs.k)
    sort_axis = get_const_int(attrs.axis)
    return_kind = attrs.ret_type
    ascending = bool(get_const_int(attrs.is_ascend))
    index_dtype = attrs.dtype
    result = topi.topk(inputs[0], num_top, sort_axis, return_kind, ascending, index_dtype)
    # A single output comes back bare; wrap it so a list is always returned.
    if isinstance(result, list):
        return result
    return [result]
register_pattern("topk", OpPattern.OPAQUE)
......@@ -17,8 +17,9 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import TupleWrapper
def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
......@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
The data type of the output indices.
Returns
-------
......@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Tensor with same shape as data.
"""
return _make.argsort(data, axis, is_ascend, dtype)
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
    """Get the top k elements in an input tensor along the given axis.

    ret_type specifies the return type, can be one of ("both", "values", "indices").

    Parameters
    ----------
    data : relay.Expr
        The input data tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : relay.Expr or List[relay.Expr]
        The computed result.
    """
    call = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
    # Only the "both" variant yields a tuple (values, indices).
    if ret_type != "both":
        return call
    return TupleWrapper(call, 2)
......@@ -401,7 +401,7 @@ def upsampling(data,
with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating ghe out value
method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
......
......@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Specifies how out-of-bound indices will behave [clip, wrap].
clip: clip to the range (default).
wrap: wrap around the indices.
Returns
-------
......
......@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 512;
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
......
......@@ -18,9 +18,9 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nms.cc
* \brief Non-maximum suppression operators
* Copyright (c) 2019 by Contributors
* \file argsort.cc
* \brief Argsort operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
......@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types,
<< types[0];
return false;
}
CHECK_EQ(param->dtype, Float(32));
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true;
}
......@@ -74,5 +73,6 @@ input array along the given axis.
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("Argsort", ArgsortRel);
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file topk.cc
* \brief TopK operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
bool TopKRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data);
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
axis += ndim;
}
CHECK(axis >= 0 && axis < ndim);
Array<IndexExpr> out_shape;
for (int i = 0; i < ndim; ++i) {
if (i != axis || param->k < 1) {
out_shape.push_back(data->shape[i]);
} else {
out_shape.push_back(param->k);
}
}
auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") {
reporter->Assign(types[1], indices_ty);
} else {
LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
}
return true;
}
/*!
 * \brief Build a call node for the "topk" operator.
 *
 * Every argument except `data` is stored in a freshly created TopKAttrs node
 * attached to the call.
 */
Expr MakeTopK(Expr data,
              int k,
              int axis,
              std::string ret_type,
              bool is_ascend,
              DataType dtype) {
  auto topk_attrs = make_node<TopKAttrs>();
  topk_attrs->dtype = dtype;
  topk_attrs->is_ascend = is_ascend;
  topk_attrs->ret_type = ret_type;
  topk_attrs->axis = axis;
  topk_attrs->k = k;
  // Looked up once and reused by later calls.
  static const Op& topk_op = Op::Get("topk");
  return CallNode::make(topk_op, {data}, Attrs(topk_attrs), {});
}
// Expose MakeTopK as the packed function "relay.op._make.topk".
TVM_REGISTER_API("relay.op._make.topk")
.set_body_typed(MakeTopK);
// Register the "topk" operator: one input ("data"), support level 6,
// typed by the TopKRel relation.
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TopKAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
} // namespace relay
} // namespace tvm
......@@ -608,6 +608,45 @@ def test_forward_Crop():
verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
def test_forward_argsort():
    """Compare the converted MXNet argsort against MXNet's own result."""
    def verify(shape, axis, is_ascend, dtype="float32"):
        x_np = np.random.uniform(size=shape).astype("float32")
        sort_args = {"axis": axis, "is_ascend": is_ascend, "dtype": dtype}
        ref_res = mx.nd.argsort(mx.nd.array(x_np), **sort_args)
        mx_sym = mx.sym.argsort(mx.sym.var("x"), **sort_args)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = executor.evaluate(new_sym)(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((2, 3, 4), axis=0, is_ascend=False)
    verify((1, 4, 6), axis=1, is_ascend=True)
    verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
def test_forward_topk():
    """Compare the converted MXNet topk against MXNet's own result."""
    def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
        x_np = np.random.uniform(size=shape).astype("float32")
        topk_args = {"k": k, "axis": axis, "ret_typ": ret_type,
                     "is_ascend": is_ascend, "dtype": dtype}
        ref_res = mx.nd.topk(mx.nd.array(x_np), **topk_args)
        mx_sym = mx.sym.topk(mx.sym.var("x"), **topk_args)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = executor.evaluate(new_sym)(x_np)
                if not isinstance(ref_res, list):
                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
                    continue
                # Several outputs: compare them pairwise.
                assert len(op_res) == len(ref_res)
                for idx, tvm_out in enumerate(op_res):
                    tvm.testing.assert_allclose(tvm_out.asnumpy(), ref_res[idx].asnumpy())

    verify((3, 4), k=1, axis=0, ret_type="both")
    verify((3, 4), k=1, axis=-1, ret_type="indices")
    verify((3, 5, 6), k=2, axis=2, ret_type="value")
    verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
    verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
if __name__ == '__main__':
test_forward_mlp()
......@@ -650,3 +689,5 @@ if __name__ == '__main__':
test_forward_bilinear_resize()
test_forward_rnn_layer()
test_forward_Crop()
test_forward_argsort()
test_forward_topk()
......@@ -754,6 +754,24 @@ def test_forward_split():
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
######################################################################
# TopKV2
# ------
def _test_forward_top_k_v2(in_shape, k):
    """Build a one-op TopK graph and compare its first output with TVM."""
    input_values = np.random.uniform(-100, 100, size=in_shape).astype("float32")
    tf.reset_default_graph()
    placeholder = tf.placeholder("float32", in_shape, name="in_data")
    tf.math.top_k(placeholder, k, name='TopK')
    compare_tf_with_tvm([input_values], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2():
    """Run the TopKV2 comparison over a few (shape, k) combinations."""
    # NOTE(review): the last two cases are identical - possibly one of them
    # was meant to use a different shape or k; confirm.
    cases = [
        ((3,), 1),
        ((3,), 3),
        ((3, 5, 7), 3),
        ((3, 5, 7), 3),
    ]
    for in_shape, k in cases:
        _test_forward_top_k_v2(in_shape, k)
#######################################################################
# Unstack
# -------
......@@ -1704,6 +1722,7 @@ if __name__ == '__main__':
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()
# Activations
test_forward_sigmoid()
......
......@@ -16,18 +16,15 @@
# under the License.
""" Support level6 operator test cases.
"""
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing
def test_argsort():
def verify_argsort(shape, axis, is_ascend):
def verify_argsort(shape, axis, is_ascend, dtype):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
zz = relay.ir_pass.infer_type(z)
z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
......@@ -39,11 +36,58 @@ def test_argsort():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
verify_argsort((2, 3, 4), axis=0, is_ascend=False)
verify_argsort((1, 4, 6), axis=1, is_ascend=True)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
for dtype in ["int32", "int64", "float32", "float64"]:
verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype)
def test_topk():
    """Check relay.topk against a NumPy reference on a fixed (20, 100) input."""
    def verify_topk(k, axis, ret_type, is_ascend, dtype):
        shape = (20, 100)
        x = relay.var("x", relay.TensorType(shape, "float32"))
        out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
        if isinstance(out, relay.expr.TupleWrapper):
            out = out.astuple()
        func = relay.Function([x], out)

        np_data = np.random.uniform(size=shape).astype("float32")
        # Descending order is obtained by sorting the negated data.
        sort_key = np_data if is_ascend else -np_data
        order = np.argsort(sort_key, axis=axis)
        # k < 1 selects every element along the axis.
        kk = shape[axis] if k < 1 else k
        if axis == 0:
            order = order[:kk, :]
            np_values = np_data[order, np.arange(shape[1])]
        else:
            order = order[:, :kk]
            np_values = np_data[np.arange(shape[0])[:, None], order]
        np_indices = order.astype(dtype)

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = executor.evaluate(func)(np_data)
                if ret_type == "both":
                    tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
                    tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
                elif ret_type == "values":
                    tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
                else:
                    tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)

    for k in [0, 1, 5]:
        for axis in [0, -1, 1]:
            for ret_type in ["both", "values", "indices"]:
                for dtype in ["int64", "float32"]:
                    for is_ascend in (False, True):
                        verify_topk(k, axis, ret_type, is_ascend, dtype)
if __name__ == "__main__":
test_argsort()
test_topk()
......@@ -21,3 +21,4 @@ from . import ssd
from .ssd import *
from .nms import *
from .rcnn import *
from .sort import *
......@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
......
......@@ -36,3 +36,20 @@ def schedule_argsort(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_topk(outs):
    """Schedule for topk operator.

    Parameters
    ----------
    outs: Array of Tensor
        The output tensors of the topk operator
        (presumably values and/or indices - TODO confirm against callers).

    Returns
    -------
    s: Schedule
      The computation schedule for the op.
    """
    schedule = _default_schedule(outs, False)
    return schedule
......@@ -18,9 +18,10 @@
"""Argsort operator"""
import tvm
from tvm import api
from .util import get_const_tuple
@tvm.target.generic_func
def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index
data in sorted order.
......@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
data : tvm.Tensor
The input tensor.
valid_count : tvm.Tensor
valid_count : tvm.Tensor, optional
1-D tensor for valid number of boxes only for ssd.
axis : optional, int
Axis along which to sort the input tensor.
axis : int, optional
Axis along which to sort the input tensor.
By default the flattened array is used.
is_ascend : optional, boolean
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : optional, string
dtype : string, optional
DType of the output indices.
flag : optional, boolean
Whether valid_count is valid.
Returns
-------
out : tvm.Tensor
......@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
# An example to use argsort
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0
is_ascend = False
flag = False
out = argsort(data, valid_count, axis, is_ascend, flag)
out = argsort(data, axis=axis, is_ascend=is_ascend)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
f = tvm.build(s, [data, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
f(tvm_data, tvm_out)
"""
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
if flag:
if valid_count is not None:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
......@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
name="argsort_cpu",
tag="argsort_cpu")
return out
@tvm.target.generic_func
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    The work is delegated to the packed function ``tvm.contrib.sort.topk``
    through an extern op.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result.
    """
    assert ret_type in ["both", "values", "indices"]
    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    # Output shape: same as the input, with the sorted axis shrunk to k
    # unless k < 1 (keep everything).
    out_shape = list(get_const_tuple(data.shape))
    if k >= 1:
        out_shape[axis] = k

    # Values come first, indices second, each only when requested.
    buffer_specs = []
    if ret_type in ["both", "values"]:
        buffer_specs.append((data.dtype, "value_buf"))
    if ret_type in ["both", "indices"]:
        buffer_specs.append((dtype, "indices_buf"))
    out_bufs = [api.decl_buffer(out_shape, buf_dtype, buf_name, data_alignment=8)
                for buf_dtype, buf_name in buffer_specs]
    out_shapes = [out_shape for _ in out_bufs]

    def _call_topk(ins, outs):
        return tvm.call_packed(
            "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend)

    return tvm.extern(out_shapes,
                      [data],
                      _call_topk,
                      in_buffers=[data_buf],
                      out_buffers=out_bufs,
                      name="topk_cpu",
                      tag="topk_cpu")
......@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None):
-------
ret : tvm.Tensor
"""
if strides is None:
strides = []
return cpp.strided_slice(a, begin, end, strides)
......
......@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
......
......@@ -16,23 +16,15 @@
# under the License.
"""Test code for vision package"""
from __future__ import print_function
import math
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi import argsort
def test_argsort():
dshape = (1, 8)
valid_count_shape = (2,)
dshape = (20, 100)
data = tvm.placeholder(dshape, name="data", dtype="float32")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.argsort(-np_data)
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -41,19 +33,77 @@ def test_argsort():
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
out = topi.argsort(data, axis=-1, is_ascend=False)
s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
f = tvm.build(s, [data, out], device)
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
def verify_topk(k, axis, ret_type, is_ascend, dtype):
    """Check topi.topk against a NumPy reference on a fixed (20, 100) input."""
    shape = (20, 100)
    data_dtype = "float32"
    data = tvm.placeholder(shape, name="data", dtype=data_dtype)

    np_data = np.random.uniform(size=shape).astype(data_dtype)
    # Descending order is obtained by sorting the negated data.
    sort_key = np_data if is_ascend else -np_data
    order = np.argsort(sort_key, axis=axis)
    # k < 1 selects every element along the axis.
    kk = shape[axis] if k < 1 else k
    if axis == 0:
        order = order[:kk, :]
        np_values = np_data[order, np.arange(shape[1])]
    else:
        order = order[:, :kk]
        np_values = np_data[np.arange(shape[0])[:, None], order]
    np_indices = order.astype(dtype)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
            if not isinstance(outs, list):
                outs = [outs]
            s = topi.generic.schedule_topk(outs)
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_res = [tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx) for t in outs]
        f = tvm.build(s, [data] + outs, device)
        f(tvm_data, *tvm_res)
        if ret_type == "both":
            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
            tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_indices)
        elif ret_type == "values":
            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
        else:
            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)

    for device in ['llvm', 'cuda', 'opencl']:
        check_device(device)
def test_topk():
    """Sweep k, axis, return type, index dtype and sort order for topi.topk."""
    for k in [0, 1, 5]:
        for axis in [0, -1, 1]:
            for ret_type in ["both", "values", "indices"]:
                for dtype in ["int64", "float32"]:
                    for is_ascend in (True, False):
                        verify_topk(k, axis, ret_type, is_ascend, dtype)
if __name__ == "__main__":
test_argsort()
test_topk()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment