Unverified Commit 430cb899 by Wang Yucheng Committed by GitHub

[Torch] Add support for split (#5174)

* [Torch] Add support for split

* fix

* fix test class
parent c97e41b0
......@@ -105,6 +105,36 @@ def _slice():
return _op.transform.strided_slice(data, begin, end, strides)
return _impl
def _split():
def _impl(inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
dim = int(inputs[2])
split_index = split_size
indices = []
while split_index < _infer_shape(data)[dim]:
indices.append(split_index)
split_index += split_size
return _op.split(data, indices, dim)
return _impl
def _split_with_sizes():
def _impl(inputs, inputs_types):
data = inputs[0]
dim = int(inputs[2])
split_index = 0
indices = []
sections = _infer_shape(inputs[1])
for i in range(len(sections) - 1):
split_index += sections[i]
indices.append(split_index)
return _op.split(data, indices, dim)
return _impl
def _select():
def _impl(inputs, input_types):
data = inputs[0]
......@@ -886,6 +916,8 @@ _convert_map = {
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
......@@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])
func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
return _module.IRModule.from_expr(func), tvm_params
......@@ -379,6 +379,29 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)
def test_forward_split():
torch.set_grad_enabled(False)
input_shape = [4, 10]
class Split(Module):
def __init__(self, split_size_or_sections, dim):
super(Split, self).__init__()
self.split_size_or_sections = split_size_or_sections
self.dim = dim
def forward(self, *args):
return torch.split(args[0], self.split_size_or_sections, self.dim)
input_data = torch.rand(input_shape).float()
verify_model(Split(2, 0).float().eval(),
input_data=input_data)
verify_model(Split(3, 1).float().eval(),
input_data=input_data)
verify_model(Split(4, 1).float().eval(),
input_data=input_data)
verify_model(Split([2, 3, 5], 1).float().eval(),
input_data=input_data)
def test_forward_avgpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
......@@ -1077,6 +1100,7 @@ if __name__ == "__main__":
test_forward_expand()
test_forward_pow()
test_forward_chunk()
test_forward_split()
test_upsample()
test_to()
test_adaptive_pool3d()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment