Commit 510cd5ec by Siva Committed by Tianqi Chen

Squeeze bug fix. (#506)

parent 69d5fcab
...@@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, ...@@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
} else { } else {
std::unordered_set<dim_t> axis_checker; std::unordered_set<dim_t> axis_checker;
for (size_t i = 0; i < param.axis.ndim(); ++i) { for (size_t i = 0; i < param.axis.ndim(); ++i) {
int real_axis;
if (param.axis[i] < 0) { if (param.axis[i] < 0) {
int real_axis = param.axis[i] + static_cast<int>(shp.ndim()); real_axis = param.axis[i] + static_cast<int>(shp.ndim());
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0); } else {
axis_checker.insert(real_axis); real_axis = param.axis[i];
} }
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
} }
for (size_t i = 0; i < shp.ndim(); ++i) { for (size_t i = 0; i < shp.ndim(); ++i) {
if (axis_checker.find(i) == axis_checker.end()) { if (axis_checker.find(i) == axis_checker.end()) {
......
...@@ -116,6 +116,24 @@ def test_flatten(): ...@@ -116,6 +116,24 @@ def test_flatten():
sdict = infer_shape(y) sdict = infer_shape(y)
assert(sdict["y"][0] == [10, 200]) assert(sdict["y"][0] == [10, 200])
def test_squeeze():
x = sym.Variable("x", shape=(1, 1, 1, 10))
y = sym.squeeze(x, axis=(1,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [1, 10])
x = sym.Variable("x", shape=(1, 3, 1))
y = sym.squeeze(x, name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])
y = sym.squeeze(x, axis=(0), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3, 1])
y = sym.squeeze(x, axis=(0,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])
# Level 2 # Level 2
def test_conv2d(): def test_conv2d():
...@@ -331,3 +349,4 @@ if __name__ == "__main__": ...@@ -331,3 +349,4 @@ if __name__ == "__main__":
test_reduce() test_reduce()
test_transpose() test_transpose()
test_prelu() test_prelu()
test_squeeze()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment