Commit 292ab02d by Bob.Liu Committed by Tianqi Chen

[FRONTEND][ONNX] fixed operator converter for Split in onnx frontend (#2038)

parent ca0fe22c
......@@ -464,6 +464,23 @@ class Unsqueeze(OnnxOpConverter):
inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1)
return inputs[0]
class Split(OnnxOpConverter):
""" Operator converter for Split.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
attr['indices_or_sections'] = []
index = 0
for i in attr['split'][:-1]:
index += i
attr['indices_or_sections'].append(index)
return AttrCvt(
op_name='split',
ignores=['split'])(inputs, attr, params)
class Slice(OnnxOpConverter):
""" Operator converter for Slice.
"""
......@@ -754,7 +771,7 @@ def _get_convert_map(opset):
'Cast': Cast.get_converter(opset),
'Reshape': Reshape.get_converter(opset),
'Concat': Renamer('concatenate'),
'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
'Split': Split.get_converter(opset),
'Slice': Slice.get_converter(opset),
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
'Gather': Gather.get_converter(opset),
......
......@@ -712,6 +712,41 @@ def test_constantfill():
verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
def verify_split(indata, outdatas, split, axis=0):
indata = np.array(indata).astype(np.float32)
outdatas = [np.array(o).astype(np.float32) for o in outdatas]
node = helper.make_node(
'Split',
inputs=['input'],
outputs=['output_{}'.format(i) for i in range(len(split))],
axis=axis,
split=split
)
graph = helper.make_graph([node],
'split_test',
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output_{}".format(i),
TensorProto.FLOAT, list(outdatas[i].shape))
for i in range(len(split))
])
model = helper.make_model(graph, producer_name='split_test')
for target, ctx in ctx_list():
output_shape = [o.shape for o in outdatas]
output_type = ['float32', 'float32', 'float32']
tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
for o, t in zip(outdatas, tvm_out):
tvm.testing.assert_allclose(o, t)
def test_split():
# 1D
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
# 2D
verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
[[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
if __name__ == '__main__':
# verify_super_resolution_example()
# verify_squeezenet1_1()
......@@ -737,3 +772,4 @@ if __name__ == '__main__':
test_forward_arg_min_max()
test_softmax()
test_constantfill()
test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment