Commit 211ab978 by ANSHUMAN TRIPATHY Committed by Tianqi Chen

Transpose core dump resolved (#1355)

parent 13362e12
...@@ -89,6 +89,18 @@ inline Tensor transpose(const Tensor& x, ...@@ -89,6 +89,18 @@ inline Tensor transpose(const Tensor& x,
} }
auto axes_val = GetConstIntValues(axes, "axes"); auto axes_val = GetConstIntValues(axes, "axes");
for (size_t i = 0; i < axes_val.size(); ++i) {
int axis = axes_val[i];
if (axes_val[i] < 0) {
axes_val[i] = static_cast<int>(x->shape.size()) + axes_val[i];
}
CHECK((0 <= axes_val[i]) && (axes_val[i] < static_cast<int>(x->shape.size())))
<< "axis=" << axis << " is invalid for the "
<< static_cast<int>(x->shape.size()) << "-dimensional input tensor";
CHECK(1 == std::count(std::begin(axes_val), std::end(axes_val), axes_val[i]))
<< "repeated axis in transpose";
}
Array<Expr> new_shape; Array<Expr> new_shape;
for (size_t i = 0; i < axes_val.size(); ++i) { for (size_t i = 0; i < axes_val.size(); ++i) {
......
...@@ -281,6 +281,7 @@ def test_tranpose(): ...@@ -281,6 +281,7 @@ def test_tranpose():
verify_tranpose((3, 10, 2), (1, 0, 2)) verify_tranpose((3, 10, 2), (1, 0, 2))
verify_tranpose((3, 10, 5), (2, 0, 1)) verify_tranpose((3, 10, 5), (2, 0, 1))
verify_tranpose((3, 10), None) verify_tranpose((3, 10), None)
verify_tranpose((3, 10, 5), (2, -3, 1))
def test_reshape(): def test_reshape():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment