Commit 81b42bc3 by Pariksheet Pinjari Committed by Yizhi Liu

Split_indices negative axis added (#1595)

parent 60769b77
......@@ -475,6 +475,11 @@ inline Array<Tensor> split_sections(const Tensor& x,
int axis,
std::string name = "tensor",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
CHECK_GT(num_sections, 0) << "Slice count must be > 0";
......
......@@ -340,6 +340,7 @@ def test_concatenate():
def test_split():
verify_split((2, 12, 3), 3, 1)
verify_split((2, 12, 3), 3, -1)
verify_split((2, 12, 3), [2, 4], 1)
verify_split((10, 12, 24), [5, 7, 9], -1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment