Commit 24e51abc by Josh Pollock Committed by Tianqi Chen

[Relay] Nullable Type Alpha Equality (#1906)

parent f2b30f9e
...@@ -193,6 +193,12 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -193,6 +193,12 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}; };
bool AlphaEqual(const Type& t1, const Type& t2) { bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined())
return false;
if (!t1.defined())
return true;
TypeAlphaEq aeq; TypeAlphaEq aeq;
aeq.VisitType(t1, t2); aeq.VisitType(t1, t2);
return aeq.equal; return aeq.equal;
...@@ -373,15 +379,11 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -373,15 +379,11 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
private: private:
void MergeVarDecl(const Var& var1, const Var& var2) { void MergeVarDecl(const Var& var1, const Var& var2) {
if (var1->type_annotation.defined() != var2->type_annotation.defined()) { equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation);
equal = false; if (!equal) {
return;
}
if (var1->type_annotation.defined() &&
!AlphaEqual(var1->type_annotation, var2->type_annotation)) {
equal = false;
return; return;
} }
eq_map.Set(var1, var2); eq_map.Set(var1, var2);
} }
}; };
......
...@@ -187,6 +187,25 @@ def test_var_alpha_equal(): ...@@ -187,6 +187,25 @@ def test_var_alpha_equal():
assert alpha_equal(l1, l2) assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3) assert not alpha_equal(l1, l3)
# type annotations
tt1 = relay.TensorType([], "int32")
tt2 = relay.TensorType([], "int32")
tt3 = relay.TensorType([], "int64")
v3 = relay.Var("v3", tt1)
v4 = relay.Var("v4", tt2)
v5 = relay.Var("v5", tt3)
l4 = relay.Let(v3, convert(1), v3)
l5 = relay.Let(v4, convert(1), v4)
l6 = relay.Let(v5, convert(1), v5)
# same annotations
assert alpha_equal(l4, l5)
# different annotations
assert not alpha_equal(l4, l6)
# one null annotation
assert not alpha_equal(l1, l4)
def test_global_var_alpha_equal(): def test_global_var_alpha_equal():
v1 = relay.GlobalVar("v1") v1 = relay.GlobalVar("v1")
...@@ -307,6 +326,14 @@ def test_function_alpha_equal(): ...@@ -307,6 +326,14 @@ def test_function_alpha_equal():
tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3) tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example) assert not alpha_equal(func, tupled_example)
# nullable
no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2])
# both null
assert alpha_equal(no_ret_type, no_ret_type)
# one null
assert not alpha_equal(func, no_ret_type)
assert not alpha_equal(no_ret_type, func)
def test_call_alpha_equal(): def test_call_alpha_equal():
v1 = relay.Var("v1") v1 = relay.Var("v1")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment