Commit 2b045c56 by Haichen Shen Committed by Tianqi Chen

[TEST][FLAKY] Fix flaky test on topk and quantize pass (#3362)

* fix flaky test

* fix flaky quantize pass
parent 7bf2ff23
......@@ -80,12 +80,12 @@ def test_topk():
tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
else:
tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
np.random.seed(0)
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, False, dtype)
verify_topk(k, axis, ret_type, True, dtype)
verify_topk(k, axis, ret_type, True, "int64")
verify_topk(k, axis, ret_type, False, "float32")
if __name__ == "__main__":
......
......@@ -75,6 +75,8 @@ def test_quantize_pass():
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
np.random.seed(42)
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
......@@ -95,6 +97,5 @@ def test_quantize_pass():
if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
test_quantize_pass()
......@@ -96,12 +96,12 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
check_device(device)
def test_topk():
np.random.seed(0)
for k in [0, 1, 5]:
for axis in [0, -1, 1]:
for ret_type in ["both", "values", "indices"]:
for dtype in ["int64", "float32"]:
verify_topk(k, axis, ret_type, True, dtype)
verify_topk(k, axis, ret_type, False, dtype)
verify_topk(k, axis, ret_type, True, "int64")
verify_topk(k, axis, ret_type, False, "float32")
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment