Unverified Commit b637840b by Matthew Brookhart Committed by GitHub

Add TopK to ONNX Frontend (#5441)

* Add TopK to ONNX Frontend

* respond to review comments
parent 2dbe6261
......@@ -1470,6 +1470,22 @@ class NonZero(OnnxOpConverter):
output = AttrCvt(op_name='argwhere')(inputs, attr, params)
return _op.transpose(output, axes=(1, 0))
class TopK(OnnxOpConverter):
"""Operator converter for TopK
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if len(inputs) != 2:
raise ValueError("Expect 2 input only")
axis = attr.get("axis", -1)
largest = attr.get("largest", 1)
if largest == 0:
raise ValueError("TVM only supports finding TopK largest elements")
K = int(infer_value(inputs[1], params).asnumpy()[0])
return _op.topk(inputs[0], k=K, axis=axis)
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -1573,8 +1589,11 @@ def _get_convert_map(opset):
'ReduceProd': ReduceProd.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'
#defs/sorting
'ArgMax': ArgMax.get_converter(opset),
'ArgMin': ArgMin.get_converter(opset),
'TopK': TopK.get_converter(opset),
# defs/tensor
'Cast': Cast.get_converter(opset),
......
......@@ -2330,6 +2330,43 @@ def test_nonzero():
result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
verify_nonzero(input_data, result, dtype=np.int64)
def test_topk():
def verify_topk(input_dims, K, axis=-1):
output_dims = list(input_dims)
output_dims[axis] = K
node = helper.make_node('TopK',
inputs=['X', 'K'],
outputs=['Values', 'Indicies'],
axis=axis)
graph = helper.make_graph([node],
"topk_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])
model = helper.make_model(graph, producer_name='topk_test')
indata = np.random.uniform(-10, 10, input_dims).astype(np.float32)
onnx_out = get_onnxruntime_output(model, [indata, k])
for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
output_dtype=['float32', 'int64'])
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
for n in [12, 32]:
for shape in [[n], [n, n], [n, n, n]]:
for k in [1, 5, 10]:
verify_topk(shape, k)
verify_topk([n, n, n], 5, 0)
verify_topk([n, n, n], 5, 1)
verify_topk([n, n, n], 5, 2)
if __name__ == '__main__':
test_flatten()
......@@ -2392,3 +2429,4 @@ if __name__ == '__main__':
test_lstm()
test_resize()
test_nonzero()
test_topk()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment