Unverified Commit 8290eaba by Neo Chien Committed by GitHub

[TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort (#4891)

* [TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort

* upadate test function of argsort like topk

* Shuffle index and get data from shuffled index

* Replace the random.uniform with np.arange
parent f47c38db
...@@ -21,11 +21,26 @@ import tvm ...@@ -21,11 +21,26 @@ import tvm
import topi import topi
import topi.testing import topi.testing
def test_argsort():
def verify_argsort(axis, is_ascend):
dshape = (20, 100) dshape = (20, 100)
data = tvm.placeholder(dshape, name="data", dtype="float32") data_dtype = "float32"
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) data = tvm.placeholder(dshape, name="data", dtype=data_dtype)
np_result = np.argsort(-np_data)
perm = np.arange(dshape[0] * dshape[1], dtype=data_dtype)
np.random.shuffle(perm)
np_data = perm.reshape(dshape)
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)
if axis == 0:
np_indices = np_indices[:dshape[axis], :]
else:
np_indices = np_indices[:, :dshape[axis]]
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -33,18 +48,19 @@ def test_argsort(): ...@@ -33,18 +48,19 @@ def test_argsort():
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
out = topi.argsort(data, axis=-1, is_ascend=False) out = topi.argsort(data, axis=axis, is_ascend=is_ascend)
s = topi.generic.schedule_argsort(out) s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx) tvm_data = tvm.nd.array(np_data, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), ctx)
f = tvm.build(s, [data, out], device) f = tvm.build(s, [data, out], device)
f(tvm_data, tvm_out) f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
check_device(device) check_device(device)
def verify_topk(k, axis, ret_type, is_ascend, dtype): def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100) shape = (20, 100)
data_dtype = "float32" data_dtype = "float32"
...@@ -95,6 +111,14 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): ...@@ -95,6 +111,14 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
check_device(device) check_device(device)
def test_argsort():
np.random.seed(0)
for axis in [0, -1, 1]:
verify_argsort(axis, True)
verify_argsort(axis, False)
def test_topk(): def test_topk():
np.random.seed(0) np.random.seed(0)
for k in [0, 1, 5]: for k in [0, 1, 5]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment