Commit a5be8fd3 by Zhi Committed by Tianqi Chen

[Relay] remove redundant test cases in test_op_level4.py (#1905)

parent b7a8af8d
......@@ -13,37 +13,6 @@ def assert_has_type(expr, typ, env=Environment({})):
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def test_cmp_type():
for op in (relay.greater,
relay.greater_equal,
relay.less,
relay.less_equal,
relay.equal,
relay.not_equal):
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
def test_binary_broadcast():
for op in [relay.right_shift,
relay.left_shift,
relay.maximum]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_binary_op():
def check_binary_op(opfunc):
......@@ -138,9 +107,9 @@ def test_where():
if __name__ == "__main__":
test_cmp_type()
test_binary_broadcast()
test_binary_op()
test_binary_broadcast_op()
test_cmp_type()
test_binary_broadcast()
test_where()
test_multibox_prior()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment