Unverified Commit 1562eaeb by Yizhi Liu Committed by GitHub

[TOPI] Fix flaky testcase for floor div (#4382)

* [TOPI] Fix flaky testcase for floor div

* avoid check at 0.0
parent 0cfa3a80
......@@ -62,6 +62,18 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
assert(isinstance(C, tvm.expr.Expr))
return
def gen_operand(shape, low, high, ctx):
if shape is None:
npy = float(np.random.uniform(low=low, high=high))
if dtype.startswith('int'):
npy = int(npy)
nd = npy
else:
npy = np.random.uniform(low=low, high=high,
size=shape).astype(dtype)
nd = tvm.nd.array(npy, ctx)
return npy, nd
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
......@@ -71,27 +83,24 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + ftopi.__name__)
if lhs_shape is None:
lhs_npy = float(np.random.uniform(low=lhs_min, high=lhs_max))
if dtype.startswith('int'):
lhs_npy = int(lhs_npy)
lhs_nd = lhs_npy
else:
lhs_npy = np.random.uniform(low=lhs_min, high=lhs_max,
size=lhs_shape).astype(A.dtype)
lhs_nd = tvm.nd.array(lhs_npy, ctx)
if rhs_shape is None:
rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max))
if dtype.startswith('int'):
rhs_npy = int(rhs_npy)
rhs_nd = rhs_npy
else:
rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max,
size=rhs_shape).astype(A.dtype)
rhs_nd = tvm.nd.array(rhs_npy, ctx)
lhs_npy, lhs_nd = gen_operand(lhs_shape, lhs_min, lhs_max, ctx)
rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)
if fnumpy == np.floor_divide:
# avoid check too close to X.5 and X.0
# FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
# However the result is somehow incorrect - need to further investigate.
# And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
if mask.any():
lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy
lhs_npy = lhs_npy.astype(dtype)
lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
out_npy = fnumpy(lhs_npy, rhs_npy)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
......@@ -145,7 +154,7 @@ def test_floor_divide():
verify_broadcast_binary_ele(
(), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
(2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
def test_maximum_minmum():
verify_broadcast_binary_ele(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment