Commit df6f54ac by Steven S. Lyubomirsky Committed by Jared Roesch

Add type solver unit tests for unifying quantified funcs (one bug found) (#3947)

parent ba4d081c
......@@ -224,6 +224,49 @@ def test_backward_solving_after_child_update():
assert solver.Resolve(t5) == tup_concrete
def test_unify_quantified_funcs():
solver = make_solver()
a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
ft1 = relay.FuncType([a, b], c, [a, b, c])
ft2 = relay.FuncType([a, a], a, [a])
unified = solver.Unify(ft1, ft2)
assert unified == ft2
ft3 = relay.FuncType([a], a, [a])
ft4 = relay.FuncType([b], c, [b, c])
unified = solver.Unify(ft3, ft4)
assert unified == ft3
def test_unify_quantified_func_and_concrete():
solver = make_solver()
a, b = relay.TypeVar('a'), relay.TypeVar('b')
ft1 = relay.FuncType([a], b, [a, b])
ft2 = relay.FuncType([b], relay.TupleType([]), [b])
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_unify_quantified_funcs_nesting():
solver = make_solver()
a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
ft2 = relay.FuncType([a, relay.TupleType([a, a])], relay.TupleType([a, a, a]), [a])
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_unify_quantified_funcs_var_order():
solver = make_solver()
a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
ft2 = relay.FuncType([a, relay.TupleType([a, c])], relay.TupleType([a, a, c]), [a, c])
# unified = solver.Unify(ft1, ft2) # crashes here but it shouldn't
# assert unified == ft2
@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_tuple_unification():
solver = make_solver()
......@@ -284,6 +327,16 @@ def test_incompatible_typecall_args_unification():
solver.Unify(tc1, tc2)
@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_quantified_func_unification():
solver = make_solver()
a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
ft1 = relay.FuncType([a, b], c, [a, b, c])
ft2 = relay.FuncType([b, c], relay.TupleType([a]), [a, b, c])
solver.Unify(ft1, ft2)
if __name__ == "__main__":
test_bcast()
test_backward_solving()
......@@ -294,7 +347,12 @@ if __name__ == "__main__":
test_unify_vars_under_tuples()
test_recursive_backward_solving()
test_backward_solving_after_child_update()
test_unify_quantified_funcs()
test_unify_quantified_func_and_concrete()
test_unify_quantified_funcs_nesting()
test_unify_quantified_funcs_var_order()
test_incompatible_tuple_unification()
test_bad_recursive_unification()
test_incompatible_typecall_var_unification()
test_incompatible_typecall_args_unification()
test_incompatible_quantified_func_unification()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment