Commit 3ff2d958 by Xingjian Shi Committed by Tianqi Chen

try to fix test (#784)

try to fix

fix
parent 8a4e1d49
......@@ -54,7 +54,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
with tvm.target.create(device):
s = topi.generic.schedule_reduce(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="sum")
foo = tvm.build(s, [A, B], device, name=type)
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
......@@ -74,6 +74,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
foo(data_tvm, out_tvm)
if type == "argmax" or type == "argmin":
out_tvm_indices = out_tvm.asnumpy()
if keepdims:
out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
if axis is None:
out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
else:
other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
out_tvm_val = in_npy_map[sel_indices]
if type == "argmax":
np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
elif type == "argmin":
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment