Commit 8a2f10e0 by Neo Chien Committed by Tianqi Chen

[Relay][Frontend][ONNX] operator support: Tile (#3941)

* [Relay][Frontend][ONNX] operator support: Tile

* Trigger notification
parent d7a09150
...@@ -885,6 +885,18 @@ class And(Elemwise): ...@@ -885,6 +885,18 @@ class And(Elemwise):
return _op.logical_and(inputs[0], inputs[1]) return _op.logical_and(inputs[0], inputs[1])
class Tile(Elemwise):
"""Operator converter for Tile
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'repeats' not in attr:
raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
'for operator Tile.')
reps = attr.pop('repeats') # The number of times repeating the tensor data.
return _op.tile(inputs[0], reps)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1002,7 +1014,8 @@ def _get_convert_map(opset): ...@@ -1002,7 +1014,8 @@ def _get_convert_map(opset):
'Sign': Sign.get_converter(opset), 'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset), 'Equal': Equal.get_converter(opset),
'Not': Not.get_converter(opset), 'Not': Not.get_converter(opset),
'And': And.get_converter(opset) 'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset)
} }
......
...@@ -1205,6 +1205,27 @@ def test_and(): ...@@ -1205,6 +1205,27 @@ def test_and():
verify_and(indata=[x, y], dtype=bool) verify_and(indata=[x, y], dtype=bool)
def verify_tile(indata, outdata, **kwargs):
node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
graph = helper.make_graph([node],
'tile_test',
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='tile_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_tile():
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
z = np.tile(x, repeats)
verify_tile(x, z, repeats=repeats)
if __name__ == '__main__': if __name__ == '__main__':
test_flatten() test_flatten()
test_reshape() test_reshape()
...@@ -1250,3 +1271,4 @@ if __name__ == '__main__': ...@@ -1250,3 +1271,4 @@ if __name__ == '__main__':
test_sign() test_sign()
test_not() test_not()
test_and() test_and()
test_tile()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment