Commit 072f8cc7 by Haichen Shen, committed by Leyuan Wang

[Relay/TOPI][Op] Add TopK operator (#3256)

* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix
parent 4204544b
@@ -99,6 +99,8 @@ List of operators
topi.shape
topi.layout_transform
topi.image.resize
topi.argsort
topi.topk
List of schedules
@@ -163,6 +165,8 @@ topi
.. autofunction:: topi.tile
.. autofunction:: topi.shape
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
topi.nn
~~~~~~~
@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
:nosignatures:
tvm.relay.argsort
tvm.relay.topk
**Level 10: Temporary Operators**
@@ -309,6 +310,7 @@ Level 5 Definitions
Level 6 Definitions
-------------------
.. autofunction:: tvm.relay.argsort
.. autofunction:: tvm.relay.topk
Level 10 Definitions
@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
}
};
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
int k;
int axis;
bool is_ascend;
std::string ret_type;
DataType dtype;
TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopKAttrs") {
TVM_ATTR_FIELD(k).set_default(1)
.describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis along which to sort the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both")
.describe("The return type [both, values, indices]."
"both - return both top k data and indices."
"values - return top k data only."
"indices - return top k indices only.");
TVM_ATTR_FIELD(is_ascend).set_default(false)
.describe("Whether to sort in ascending or descending order."
"By default, sort in descending order");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Data type of the output indices.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)
def _mx_topk(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["k"] = attrs.get_int("k", 1)
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
ret_type = attrs.get_str("ret_typ", "indices")
if ret_type == "mask":
raise tvm.error.OpAttributeUnimplemented(
"Attribute ret_type=mask is not supported in topk operator")
new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.topk(inputs[0], **new_attrs)
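# A hypothetical round-trip sketch (not part of this diff) showing how the
# converter above is exercised. Note that MXNet spells the attribute "ret_typ"
# and uses "value" where Relay uses "values"; _mx_topk handles that mapping
# and rejects the unsupported "mask" variant.
import mxnet as mx
from tvm import relay
mx_sym = mx.sym.topk(mx.sym.var("x"), k=2, axis=-1, ret_typ="value", is_ascend=False)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": (3, 5, 6)})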
def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
@@ -914,6 +929,7 @@ _convert_map = {
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output, "LinearRegressionOutput" : _mx_linear_regression_output,
@@ -1082,6 +1082,20 @@ def _softplus():
return _get_relay_op('log')(add_out)
return _impl
def _topk():
def _impl(inputs, attr, params):
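# TopKV2 passes k as the second input (a constant tensor) rather than as an
# attribute, so pop it out of the params dict and forward it as a Relay attribute.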
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
if k < 1:
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented(
'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk',
ignores=['sorted'],
extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
return _impl
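# A hypothetical usage sketch (not part of this diff): tf.math.top_k emits a
# TopKV2 node, which the converter above lowers to Relay topk with
# is_ascend=False and int32 indices; sorted=False is rejected.
import tensorflow as tf
in_data = tf.placeholder("float32", (3, 5, 7), name="in_data")
tf.math.top_k(in_data, k=3, name="TopK")
# The resulting graph can then be imported via relay.frontend.from_tensorflow.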
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
@@ -1271,6 +1285,7 @@ _convert_map = {
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),
@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
-dtype = str(attrs.dtype)
+dtype = attrs.dtype
-return [
-    topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
-                 dtype=dtype, flag=False)
-]
+return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
register_pattern("argsort", OpPattern.OPAQUE)
@register_schedule("topk")
def schedule_topk(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_topk(outs)
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
"""Compute definition of argsort"""
k = get_const_int(attrs.k)
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out
register_pattern("topk", OpPattern.OPAQUE)
@@ -17,8 +17,9 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import TupleWrapper
-def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
+def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indices
having the same shape as the input array that index data in sorted order.
@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Whether to sort in ascending or descending order.
dtype : string, optional
-DType of the output indices.
+The data type of the output indices.
Returns
-------
@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Tensor with same shape as data.
"""
return _make.argsort(data, axis, is_ascend, dtype)
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, which can be one of ("both", "values", "indices").
Parameters
----------
data : relay.Expr
The input data tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis along which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : relay.Expr or List[relay.Expr]
The computed result.
"""
out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
return out
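# A minimal usage sketch (not part of this diff; mirrors the level-6 tests
# below). With ret_type="both" the call returns a TupleWrapper over
# (values, indices), which astuple() turns into a plain Tuple expression.
import numpy as np
import tvm
from tvm import relay
x = relay.var("x", relay.TensorType((20, 100), "float32"))
out = relay.topk(x, k=5, axis=-1, ret_type="both")
func = relay.Function([x], out.astuple())
intrp = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm")
res = intrp.evaluate(func)(np.random.uniform(size=(20, 100)).astype("float32"))
# res[0] holds the top-5 values, res[1] the matching int32 indices.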
@@ -401,7 +401,7 @@ def upsampling(data,
with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale)
-method indicates the algorithm to be used while calculating ghe out value
+method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
-Specifies how out-of-bound indices will behave.
-clip - clip to the range (default)
-wrap - wrap around the indices
+Specifies how out-of-bound indices will behave [clip, wrap].
+clip: clip to the range (default).
+wrap: wrap around the indices.
Returns
-------
@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
-t->max_num_threads = 512;
+t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
@@ -18,9 +18,9 @@
*/
/*!
-* Copyright (c) 2018 by Contributors
-* \file nms.cc
-* \brief Non-maximum suppression operators
+* Copyright (c) 2019 by Contributors
+* \file argsort.cc
+* \brief Argsort operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types,
<< types[0];
return false;
}
-CHECK_EQ(param->dtype, Float(32));
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true;
}
@@ -74,5 +73,6 @@ input array along the given axis.
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("Argsort", ArgsortRel);
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file topk.cc
* \brief TopK operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
bool TopKRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data);
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
axis += ndim;
}
CHECK(axis >= 0 && axis < ndim);
Array<IndexExpr> out_shape;
for (int i = 0; i < ndim; ++i) {
if (i != axis || param->k < 1) {
out_shape.push_back(data->shape[i]);
} else {
out_shape.push_back(param->k);
}
}
auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") {
reporter->Assign(types[1], indices_ty);
} else {
LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
}
return true;
}
Expr MakeTopK(Expr data,
int k,
int axis,
std::string ret_type,
bool is_ascend,
DataType dtype) {
auto attrs = make_node<TopKAttrs>();
attrs->k = k;
attrs->axis = axis;
attrs->ret_type = ret_type;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("topk");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.topk")
.set_body_typed(MakeTopK);
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TopKAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
} // namespace relay
} // namespace tvm
@@ -608,6 +608,45 @@ def test_forward_Crop():
verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
def test_forward_argsort():
def verify(shape, axis, is_ascend, dtype="float32"):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype)
mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 3, 4), axis=0, is_ascend=False)
verify((1, 4, 6), axis=1, is_ascend=True)
verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
def test_forward_topk():
def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
is_ascend=is_ascend, dtype=dtype)
mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
is_ascend=is_ascend, dtype=dtype)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
if isinstance(ref_res, list):
assert len(op_res) == len(ref_res)
for i, t in enumerate(op_res):
tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
else:
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4), k=1, axis=0, ret_type="both")
verify((3, 4), k=1, axis=-1, ret_type="indices")
verify((3, 5, 6), k=2, axis=2, ret_type="value")
verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
if __name__ == '__main__':
test_forward_mlp()
@@ -650,3 +689,5 @@ if __name__ == '__main__':
test_forward_bilinear_resize()
test_forward_rnn_layer()
test_forward_Crop()
test_forward_argsort()
test_forward_topk()
@@ -754,6 +754,24 @@ def test_forward_split():
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
######################################################################
# TopKV2
# ------
def _test_forward_top_k_v2(in_shape, k):
np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
tf.reset_default_graph()
in_data = tf.placeholder("float32", in_shape, name="in_data")
tf.math.top_k(in_data, k, name='TopK')
compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2():
_test_forward_top_k_v2((3,), 1)
_test_forward_top_k_v2((3,), 3)
_test_forward_top_k_v2((3, 5, 7), 3)
######################################################################
# Unstack
# -------
@@ -1704,6 +1722,7 @@ if __name__ == '__main__':
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()
# Activations
test_forward_sigmoid()
@@ -16,18 +16,15 @@
# under the License.
""" Support level6 operator test cases.
"""
-import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
-import topi.testing
def test_argsort():
-def verify_argsort(shape, axis, is_ascend):
+def verify_argsort(shape, axis, is_ascend, dtype):
x = relay.var("x", relay.TensorType(shape, "float32"))
-z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
+z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
-zz = relay.ir_pass.infer_type(z)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
@@ -39,11 +36,58 @@ def test_argsort():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
-tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
+tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
-verify_argsort((2, 3, 4), axis=0, is_ascend=False)
-verify_argsort((1, 4, 6), axis=1, is_ascend=True)
-verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
+for dtype in ["int32", "int64", "float32", "float64"]:
+    verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
+    verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
+    verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype)
def test_topk():
def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100)
x = relay.var("x", relay.TensorType(shape, "float32"))
out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
if isinstance(out, relay.expr.TupleWrapper):
out = out.astuple()
func = relay.Function([x], out)
np_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)
kk = k if k >= 1 else shape[axis]
if axis == 0:
np_indices = np_indices[:kk, :]
np_values = np.zeros(np_indices.shape).astype("float32")
for i in range(shape[1]):
np_values[:, i] = np_data[np_indices[:, i], i]
else:
np_indices = np_indices[:, :kk]
np_values = np.zeros(np_indices.shape).astype("float32")
for i in range(shape[0]):
np_values[i, :] = np_data[i, np_indices[i, :]]
np_indices = np_indices.astype(dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(np_data)
if ret_type == "both":
tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
elif ret_type == "values":
tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
else:
tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, False, dtype)
verify_topk(k, axis, ret_type, True, dtype)
if __name__ == "__main__": if __name__ == "__main__":
test_argsort() test_argsort()
test_topk()
@@ -21,3 +21,4 @@ from . import ssd
from .ssd import *
from .nms import *
from .rcnn import *
from .sort import *
@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
+sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
@@ -36,3 +36,20 @@ def schedule_argsort(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_topk(outs):
"""Schedule for topk operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of topk
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@@ -18,9 +18,10 @@
"""Argsort operator"""
import tvm
from tvm import api
from .util import get_const_tuple
@tvm.target.generic_func
-def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index
data in sorted order.
@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
data : tvm.Tensor
The input tensor.
-valid_count : tvm.Tensor
+valid_count : tvm.Tensor, optional
1-D tensor for valid number of boxes only for ssd.
-axis : optional, int
+axis : int, optional
Axis along which to sort the input tensor.
By default the flattened array is used.
-is_ascend : optional, boolean
+is_ascend : boolean, optional
Whether to sort in ascending or descending order.
-dtype : optional, string
+dtype : string, optional
DType of the output indices.
-flag : optional, boolean
-Whether valid_count is valid.
Returns
-------
out : tvm.Tensor
@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
# An example to use argsort
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
-valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0
is_ascend = False
-flag = False
-out = argsort(data, valid_count, axis, is_ascend, flag)
+out = argsort(data, axis=axis, is_ascend=is_ascend)
np_data = np.random.uniform(dshape)
-np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out)
-f = tvm.build(s, [data, valid_count, out], "llvm")
+f = tvm.build(s, [data, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
-tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
-f(tvm_data, tvm_valid_count, tvm_out)
+f(tvm_data, tvm_out)
"""
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-if flag:
+if valid_count is not None:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
name="argsort_cpu",
tag="argsort_cpu")
return out
@tvm.target.generic_func
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis along which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.Tensor or List[tvm.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_shape = list(get_const_tuple(data.shape))
if k >= 1:
out_shape[axis] = k
out_bufs = []
if ret_type in ["both", "values"]:
out_bufs.append(api.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8))
if ret_type in ["both", "indices"]:
out_bufs.append(api.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8))
out_shapes = [out_shape] * len(out_bufs)
out = tvm.extern(out_shapes,
[data],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_cpu",
tag="topk_cpu")
return out
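# A CPU usage sketch (not part of this diff), in the same style as the argsort
# docstring example above; assumes an LLVM-enabled build.
import tvm
import topi
data = tvm.placeholder((20, 100), name="data", dtype="float32")
with tvm.target.create("llvm"):
    outs = topi.topk(data, k=5, axis=-1, ret_type="both", is_ascend=False, dtype="int64")
    outs = outs if isinstance(outs, list) else [outs]
    s = topi.generic.schedule_topk(outs)
f = tvm.build(s, [data] + outs, "llvm")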
@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None):
-------
ret : tvm.Tensor
"""
if strides is None:
strides = []
return cpp.strided_slice(a, begin, end, strides)
@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
-sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
+sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
@@ -16,23 +16,15 @@
# under the License.
"""Test code for vision package"""
from __future__ import print_function
-import math
import numpy as np
import tvm
import topi
import topi.testing
-from tvm.contrib.pickle_memoize import memoize
-from topi.util import get_const_tuple
-from topi import argsort
def test_argsort():
-dshape = (1, 8)
-valid_count_shape = (2,)
+dshape = (20, 100)
data = tvm.placeholder(dshape, name="data", dtype="float32")
-valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
-np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.argsort(-np_data)
def check_device(device):
ctx = tvm.context(device, 0)
@@ -41,19 +33,77 @@ def test_argsort():
return
print("Running on target: %s" % device)
with tvm.target.create(device):
-out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
+out = topi.argsort(data, axis=-1, is_ascend=False)
s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx)
-tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
-f = tvm.build(s, [data, valid_count, out], device)
-f(tvm_data, tvm_valid_count, tvm_out)
+f = tvm.build(s, [data, out], device)
+f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100)
data_dtype = "float32"
data = tvm.placeholder(shape, name="data", dtype=data_dtype)
np_data = np.random.uniform(size=shape).astype(data_dtype)
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)
kk = k if k >= 1 else shape[axis]
if axis == 0:
np_indices = np_indices[:kk, :]
np_values = np.zeros(np_indices.shape).astype(data_dtype)
for i in range(shape[1]):
np_values[:, i] = np_data[np_indices[:, i], i]
else:
np_indices = np_indices[:, :kk]
np_values = np.zeros(np_indices.shape).astype(data_dtype)
for i in range(shape[0]):
np_values[i, :] = np_data[i, np_indices[i, :]]
np_indices = np_indices.astype(dtype)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
outs = outs if isinstance(outs, list) else [outs]
s = topi.generic.schedule_topk(outs)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_res = []
for t in outs:
tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx))
f = tvm.build(s, [data] + outs, device)
f(tvm_data, *tvm_res)
if ret_type == "both":
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_indices)
elif ret_type == "values":
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
else:
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
def test_topk():
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, True, dtype)
verify_topk(k, axis, ret_type, False, dtype)
if __name__ == "__main__": if __name__ == "__main__":
test_argsort() test_argsort()
test_topk()