Unverified Commit a60de368 by Samuel Committed by GitHub

[ONNX]GatherNd, Round, IsNaN, IsInf (#5445)

parent 37e57548
...@@ -942,6 +942,14 @@ class Gather(OnnxOpConverter): ...@@ -942,6 +942,14 @@ class Gather(OnnxOpConverter):
extras={'axis': axis})(inputs, {}) extras={'axis': axis})(inputs, {})
class GatherND(OnnxOpConverter):
""" Operator converter for GatherND.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.gather_nd(inputs[0], inputs[1])
class Greater(OnnxOpConverter): class Greater(OnnxOpConverter):
""" Operator logical greater. """ Operator logical greater.
""" """
...@@ -1536,6 +1544,9 @@ def _get_convert_map(opset): ...@@ -1536,6 +1544,9 @@ def _get_convert_map(opset):
'Reciprocal': Reciprocal.get_converter(opset), 'Reciprocal': Reciprocal.get_converter(opset),
'Floor': Renamer('floor'), 'Floor': Renamer('floor'),
'Ceil': Renamer('ceil'), 'Ceil': Renamer('ceil'),
'Round': Renamer('round'),
'IsInf': Renamer('isinf'),
'IsNaN': Renamer('isnan'),
'Sqrt': Renamer('sqrt'), 'Sqrt': Renamer('sqrt'),
'Relu': Renamer('relu'), 'Relu': Renamer('relu'),
'LeakyRelu': Renamer('leaky_relu'), 'LeakyRelu': Renamer('leaky_relu'),
...@@ -1606,6 +1617,7 @@ def _get_convert_map(opset): ...@@ -1606,6 +1617,7 @@ def _get_convert_map(opset):
'DepthToSpace': DepthToSpace.get_converter(opset), 'DepthToSpace': DepthToSpace.get_converter(opset),
'SpaceToDepth': SpaceToDepth.get_converter(opset), 'SpaceToDepth': SpaceToDepth.get_converter(opset),
'Gather': Gather.get_converter(opset), 'Gather': Gather.get_converter(opset),
'GatherND': GatherND.get_converter(opset),
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset), 'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset), 'Pad': Pad.get_converter(opset),
......
...@@ -542,6 +542,70 @@ def test_clip(): ...@@ -542,6 +542,70 @@ def test_clip():
{'min': -1.0, 'max': 1.0}) {'min': -1.0, 'max': 1.0})
def test_round():
_test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {})
def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs):
indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)
y = helper.make_node(opname, ['in'], ['out'], **kwargs)
graph = helper.make_graph([y],
opname+'_test',
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("out",
TensorProto.BOOL, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, dtype)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_isinf():
_test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {})
def test_isnan():
_test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {})
def verify_gather_nd(in_shape, indices, dtype):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int32")
out_np = topi.testing.gather_nd_python(x, indices)
y = helper.make_node("GatherND", ['in', 'indices'], ['out'])
graph = helper.make_graph([y],
'gather_test',
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("indices",
TensorProto.INT32, list(indices.shape))],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='gather_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [x, indices], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out)
def test_gather_nd():
verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32')
verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32')
verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32')
def test_onehot(): def test_onehot():
indices_shape = [10] indices_shape = [10]
indices_array = np.random.randint( indices_array = np.random.randint(
...@@ -2379,11 +2443,15 @@ if __name__ == '__main__': ...@@ -2379,11 +2443,15 @@ if __name__ == '__main__':
test_slice() test_slice()
test_floor() test_floor()
test_ceil() test_ceil()
test_round()
test_isinf()
test_isnan()
test_clip() test_clip()
test_onehot() test_onehot()
test_matmul() test_matmul()
test_batch_matmul() test_batch_matmul()
test_gather() test_gather()
test_gather_nd()
test_lrn() test_lrn()
test_instance_norm() test_instance_norm()
test_upsample() test_upsample()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment