Commit 4e21982b by Wu Zhao Committed by Tianqi Chen

Add test case of argmax for detecting out of bound access (#2234)

parent 374918fa
......@@ -686,6 +686,28 @@ def test_where():
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
def test_argmax():
dshape = (204800, 2)
oshape = (1, 320, 640)
dtype = "float32"
x = sym.Variable("x", shape=dshape, dtype=dtype)
x = sym.reshape(x, shape=(1, 320, 640, 2))
x = sym.transpose(x, axes=(0, 3, 1, 2))
y = sym.argmax(x, axis=1)
target_str = "llvm"
target = tvm.target.create(target_str)
ctx = tvm.context(target_str, 0)
with nnvm.compiler.build_config(opt_level=2):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = graph_runtime.create(graph, lib, ctx)
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
np_reshape = np.reshape(data, (1, 320, 640, 2))
np_transpose = np.transpose(np_reshape, axes=(0, 3, 1, 2))
np_argmax = np.argmax(np_transpose, axis=1)
out = m.get_output(0)
np.testing.assert_allclose(out.asnumpy(), np_argmax, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
test_reshape()
......@@ -707,4 +729,5 @@ if __name__ == "__main__":
test_nms()
test_slice_like()
test_where()
test_argmax()
print(nnvm.compiler.engine.dump())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment