Commit 7ce6a41d by Huilin Qu Committed by Tianqi Chen

Fix gather_nd in Relay (#3442)

* Fix gather_nd in Relay

* Add test cases for gather_nd.
parent 05c528c1
...@@ -2332,7 +2332,7 @@ bool GatherNDRel(const Array<Type>& types, ...@@ -2332,7 +2332,7 @@ bool GatherNDRel(const Array<Type>& types,
return false; return false;
} }
const size_t ndim = data->shape.size(); const size_t ndim = data->shape.size();
const IntImm* mdim = data->shape[0].as<IntImm>(); const IntImm* mdim = indices->shape[0].as<IntImm>();
const size_t kdim = indices->shape.size() - 1; const size_t kdim = indices->shape.size() - 1;
CHECK(size_t(mdim->value) <= ndim) CHECK(size_t(mdim->value) <= ndim)
<< "GatherND: indices shape does satisfy."; << "GatherND: indices shape does satisfy.";
......
...@@ -528,6 +528,8 @@ def test_forward_gather_nd(): ...@@ -528,6 +528,8 @@ def test_forward_gather_nd():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
def test_forward_bilinear_resize(): def test_forward_bilinear_resize():
# add tests including scale_height and scale_width when mxnet is updated to version 1.5 # add tests including scale_height and scale_width when mxnet is updated to version 1.5
......
...@@ -682,7 +682,8 @@ def test_gather_nd(): ...@@ -682,7 +682,8 @@ def test_gather_nd():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
if __name__ == "__main__": if __name__ == "__main__":
test_arange() test_arange()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment