Commit f607d46c by Yang Chen Committed by Tianqi Chen

relax rtol/atol checks on some onnx tests (#2403)

relax the error constraints on these tests due to likely
FP compuation accuracy issues.
parent e36265bb
......@@ -272,7 +272,7 @@ def test_slice():
_test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, rtol=1e-7, atol=1e-7):
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)
......@@ -290,7 +290,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
tvm.testing.assert_allclose(outdata, tvm_out)
tvm.testing.assert_allclose(outdata, tvm_out, rtol=rtol, atol=atol)
def test_floor():
_test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
......@@ -863,7 +863,7 @@ def test_binary_ops():
dtype = "float32"
out_shape = in_shape
def verify_binary_ops(op, x, y, out_np, broadcast=None):
def verify_binary_ops(op, x, y, out_np, broadcast=None, rtol=1e-7, atol=1e-7):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
else:
......@@ -879,7 +879,7 @@ def test_binary_ops():
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
x = np.random.uniform(size=in_shape).astype(dtype)
y = np.random.uniform(size=in_shape).astype(dtype)
......@@ -890,8 +890,8 @@ def test_binary_ops():
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul",x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None, rtol=1e-5, atol=1e-5)
verify_binary_ops("Div", x, z, x / z, broadcast=True, rtol=1e-5, atol=1e-5)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
def test_single_ops():
......@@ -899,7 +899,7 @@ def test_single_ops():
dtype = "float32"
out_shape = in_shape
def verify_single_ops(op, x, out_np):
def verify_single_ops(op, x, out_np, rtol=1e-7, atol=1e-7):
z = helper.make_node(op, ['in1'], ['out'])
graph = helper.make_graph([z],
'_test',
......@@ -915,8 +915,8 @@ def test_single_ops():
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
verify_single_ops("Abs",x, np.abs(x))
verify_single_ops("Reciprocal",x, 1/x)
verify_single_ops("Sqrt",x, np.sqrt(x))
verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
......@@ -1004,7 +1004,9 @@ def test_LogSoftmax():
{},
'float32',
'LogSoftmax',
{'axis': 1})
{'axis': 1},
rtol=1e-5,
atol=1e-5)
if __name__ == '__main__':
# verify_super_resolution_example()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment