Commit 3d62cf7c by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] More type alpha equality test coverage (#1823)

parent 5bf1cbda
......@@ -25,6 +25,7 @@ TypeParam = ty.TypeParam
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
# Expr
Constant = expr.Constant
......
......@@ -88,11 +88,23 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
void VisitType_(const FuncTypeNode *op, const Type& t2) final {
if (const FuncTypeNode *ta2 = t2.as<FuncTypeNode>()) {
if (op->arg_types.size() != ta2->arg_types.size()) {
if (op->arg_types.size() != ta2->arg_types.size()
|| op->type_params.size() != ta2->type_params.size()
|| op->type_constraints.size() != ta2->type_constraints.size()) {
equal = false;
return;
}
// must visit params first so they are appropriate entered
// into equality map
for (size_t i = 0; i < op->type_params.size(); i++) {
eq_map.Set(op->type_params[i], ta2->type_params[i]);
this->VisitType(op->type_params[i], ta2->type_params[i]);
if (!equal) {
return;
}
}
for (size_t i = 0; i < op->arg_types.size(); i++) {
this->VisitType(op->arg_types[i], ta2->arg_types[i]);
if (!equal) {
......@@ -101,6 +113,16 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
this->VisitType(op->ret_type, ta2->ret_type);
if (!equal) {
return;
}
for (size_t i = 0; i < op->type_constraints.size(); i++) {
this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
......@@ -108,7 +130,24 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
void VisitType_(const TypeRelationNode *tr1, const Type& t2) final {
if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) {
equal = tr1 == tr2;
if (tr1->func != tr2->func
|| tr1->num_inputs != tr2->num_inputs
|| tr1->attrs != tr2->attrs) {
equal = false;
return;
}
if (tr1->args.size() != tr2->args.size()) {
equal = false;
return;
}
for (size_t i = 0; i < tr1->args.size(); i++) {
this->VisitType(tr1->args[i], tr2->args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
......
......@@ -65,8 +65,8 @@ def test_type_relation():
args = tvm.convert([tf, tt, tp])
num_inputs = 2
func = None
attrs = None
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
......
import tvm
from tvm import relay
def test_type_alpha_eq():
t1 = relay.ty.TensorType((3, 4), "float32")
t2 = relay.ty.TensorType((3, 4), "float32")
t3 = relay.ty.TensorType((3, 4, 5), "float32")
def test_tensor_type_alpha_eq():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
t3 = relay.TensorType((3, 4, 5), "float32")
assert t1 == t2
assert t1 != t3
t1 = relay.ty.TensorType((), "float32")
t2 = relay.ty.TensorType((), "float32")
t1 = relay.TensorType((), "float32")
t2 = relay.TensorType((), "float32")
assert t1 == t2
def test_incomplete_type_alpha_eq():
t1 = relay.IncompleteType(relay.Kind.Shape)
t2 = relay.IncompleteType(relay.Kind.Type)
t3 = relay.IncompleteType(relay.Kind.Type)
# only equal when there is pointer equality
assert t2 == t2
assert t1 == t1
assert t1 != t2
assert t2 != t3
def test_type_param_alpha_eq():
t1 = relay.TypeParam("v1", relay.Kind.Type)
t2 = relay.TypeParam("v2", relay.Kind.Shape)
t3 = relay.TypeParam("v3", relay.Kind.Type)
# only pointer equality and eq_map allow equal params
assert t1 == t1
assert t2 == t2
assert t1 != t2 # different kind
assert t1 != t3 # not in eq_map
# function types are the only way to put type params
# in eq map
ft1 = relay.FuncType(tvm.convert([]), t1, tvm.convert([t1]), tvm.convert([]))
ft2 = relay.FuncType(tvm.convert([]), t3, tvm.convert([t3]), tvm.convert([]))
# actually an invalid type because t2 is wrong kind
ft3 = relay.FuncType(tvm.convert([]), t2, tvm.convert([t2]), tvm.convert([]))
assert ft1 == ft2
assert ft1 != ft3 # kinds still do not match
def test_func_type_alpha_eq():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
tp1 = relay.TypeParam("v1", relay.Kind.Type)
tp2 = relay.TypeParam("v2", relay.Kind.Type)
tp3 = relay.TypeParam("v3", relay.Kind.Shape)
tp4 = relay.TypeParam("v3", relay.Kind.Shape)
broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None)
tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None)
tr3 = relay.TypeRelation(identity, tvm.convert([tp1, tp3]), 1, None)
ft = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([tr1]))
translate_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp2, tp4]),
tvm.convert([tr2]))
assert ft == translate_vars
different_args = relay.FuncType(tvm.convert([t1]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([tr1]))
assert ft != different_args
different_order = relay.FuncType(tvm.convert([t2, t1]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([tr1]))
assert ft != different_order
no_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([]))
assert ft != no_rel
more_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
tvm.convert([tp1, tp2, tp3]),
tvm.convert([tr1]))
assert ft != more_vars
all_the_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp1, tp2, tp3, tp4]),
tvm.convert([tr1, tr2]))
assert ft != all_the_vars
different_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([tr3]))
assert ft != different_rel
more_rels = relay.FuncType(tvm.convert([t1, t2]), tp1,
tvm.convert([tp1, tp3]),
tvm.convert([tr1, tr3]))
assert ft != more_rels
def test_tuple_type_alpha_eq():
t1 = relay.TensorType((1, 2, 3), "float32")
t2 = relay.TensorType((1, 2, 3, 4), "float32")
tp1 = relay.TypeParam("v1", relay.Kind.Type)
tp2 = relay.TypeParam("v2", relay.Kind.Type)
tup1 = relay.TupleType(tvm.convert([t1, t2, tp1]))
tup2 = relay.TupleType(tvm.convert([t1, t2, tp1]))
tup3 = relay.TupleType(tvm.convert([t2, t1, tp1]))
tup4 = relay.TupleType(tvm.convert([t1, t2, tp2]))
# as long as types are alpha-equal and in same order,
# tuples should be alpha-equal
assert tup1 == tup2
assert tup1 != tup3
assert tup1 != tup4
def test_type_relation_alpha_eq():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32")
# functions are compared only by pointer equality so
# we need to be sure to use the same pointers
broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
# attrs are also compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
diff_func = relay.TypeRelation(identity, tvm.convert([t1, t2]), 1, attr1)
diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2)
# func, number of args, input count, and order should be the same
assert tr == same
assert tr != diff_func
assert tr != diff_order
assert tr != diff_args
assert tr != diff_attr
assert tr != bigger
assert bigger != diff_num_inputs
if __name__ == "__main__":
test_type_alpha_eq()
test_tensor_type_alpha_eq()
test_incomplete_type_alpha_eq()
test_type_param_alpha_eq()
test_func_type_alpha_eq()
test_tuple_type_alpha_eq()
test_type_relation_alpha_eq()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment