Commit 7ce6a41d by Huilin Qu Committed by Tianqi Chen

Fix gather_nd in Relay (#3442)

* Fix gather_nd in Relay

* Add test cases for gather_nd.
parent 05c528c1
......@@ -2332,7 +2332,7 @@ bool GatherNDRel(const Array<Type>& types,
return false;
}
const size_t ndim = data->shape.size();
const IntImm* mdim = data->shape[0].as<IntImm>();
const IntImm* mdim = indices->shape[0].as<IntImm>();
const size_t kdim = indices->shape.size() - 1;
CHECK(size_t(mdim->value) <= ndim)
<< "GatherND: indices shape does satisfy.";
......
......@@ -528,6 +528,8 @@ def test_forward_gather_nd():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
def test_forward_bilinear_resize():
# add tests including scale_height and scale_width when mxnet is updated to version 1.5
......
......@@ -682,7 +682,8 @@ def test_gather_nd():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
if __name__ == "__main__":
test_arange()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment