Unverified Commit e1ebf062 by Neo Chien Committed by GitHub

[Relay][Frontend][ONNX] operator support NonZero (#5073)

* [Relay][Frontend][ONNX] operator support: NonZero

* update

* Solve the build fail

* solve the build fail

* Replace ctx_list with tvm.cpu()
parent 38118bef
......@@ -1444,6 +1444,18 @@ class Resize(OnnxOpConverter):
return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
class NonZero(OnnxOpConverter):
"""Operator converter for NonZero
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
if len(inputs) > 1:
raise ValueError("Expect 1 input only")
output = AttrCvt(op_name='argwhere')(inputs, attr, params)
return _op.transpose(output, axes=(1, 0))
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -1573,6 +1585,7 @@ def _get_convert_map(opset):
'Where': Where.get_converter(opset),
'Or': Or.get_converter(opset),
'Resize': Resize.get_converter(opset),
'NonZero': NonZero.get_converter(opset),
}
......
......@@ -30,21 +30,38 @@ from tvm.relay.testing.config import ctx_list
import scipy
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
""" Generic function to execute and get tvm output"""
target = 'llvm'
def get_input_data_shape_dict(graph_def, input_data):
if isinstance(input_data, list):
input_names = {}
shape_dict = {}
dtype_dict = {}
for i, _ in enumerate(input_data):
input_names[i] = graph_def.graph.input[i].name
shape_dict[input_names[i]] = input_data[i].shape
dtype_dict[input_names[i]] = input_data[i].dtype
else:
input_names = graph_def.graph.input[0].name
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
return input_names, shape_dict
def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
""" Generic function to execute and get tvm output with vm executor"""
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)
mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
indata = tvm.nd.array(input_data)
result = ex.evaluate()(indata)
return result.asnumpy()
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
""" Generic function to execute and get tvm output"""
target = 'llvm'
input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
with relay.build_config(opt_level=1):
......@@ -2209,6 +2226,35 @@ def test_resize():
verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")
def test_nonzero():
def verify_nonzero(indata, outdata, dtype):
node = helper.make_node('NonZero',
inputs=['X'],
outputs=['Y'],)
graph = helper.make_graph([node],
"nonzero_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))])
model = helper.make_model(graph, producer_name='nonzero_test')
onnx_out = get_onnxruntime_output(model, indata, dtype)
for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 1], [0, 0, 1]]
verify_nonzero(input_data, result, dtype=np.int64)
input_data = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]], dtype=np.int64)
result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
verify_nonzero(input_data, result, dtype=np.int64)
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -2269,3 +2315,4 @@ if __name__ == '__main__':
test_pooling()
test_lstm()
test_resize()
test_nonzero()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment