Commit 3455c8a5 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Incorporate TypeRelations into more tests (#1792)

parent 106991d2
......@@ -21,6 +21,7 @@ Kind = ty.Kind
TypeParam = ty.TypeParam
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
# Expr
Constant = expr.Constant
......
......@@ -158,3 +158,26 @@ class IncompleteType(Type):
def __init__(self, kind):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
......@@ -45,8 +45,7 @@ struct KindChecker : TypeVisitor<> {
return true;
}
return t.as<TensorTypeNode>() || t.as<BaseTensorTypeNode>()
|| t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
}
void VisitType_(const TupleTypeNode* op) override {
......@@ -61,8 +60,9 @@ struct KindChecker : TypeVisitor<> {
}
void VisitType_(const FuncTypeNode* op) override {
// func types should only take normal types for arguments
// and only return a normal type
// Func types should only take normal types for arguments
// and only return a normal type. They should also have
// well-formed constraints
for (const Type& t : op->arg_types) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
......@@ -71,6 +71,13 @@ struct KindChecker : TypeVisitor<> {
}
}
for (const TypeConstraint& tc : op->type_constraints) {
this->VisitType(tc);
if (!valid) {
return;
}
}
this->VisitType(op->ret_type);
valid = valid && IsTypeKind(op->ret_type);
}
......
......@@ -2,7 +2,7 @@ import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind
def test_tuple_kinds():
def test_tuple_kind():
# only contain type kinds
tp = relay.TypeParam('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
......@@ -12,6 +12,7 @@ def test_tuple_kinds():
tup_ty = relay.TupleType(fields)
assert check_kind(tup_ty)
def test_func_kind():
# only contain type kinds
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
......@@ -21,15 +22,29 @@ def test_func_kind():
dtype = 'float32'
tensor_type = relay.TensorType(shape, dtype)
tr = relay.TypeRelation(None, tvm.convert([tensor_type, tp1]) , 1, None)
type_params = tvm.convert([tp1, tp2])
type_constraints = tvm.convert([])
type_constraints = tvm.convert([tr])
arg_types = tvm.convert([tp1, tensor_type])
ret_type = relay.TupleType(tvm.convert([tp2, tensor_type]))
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert check_kind(tf)
def test_invalid_tuple_kinds():
def test_relation_kind():
# only have type kinds for arguments
tp = relay.TypeParam('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
args = tvm.convert([tf, tt, tp])
tr = relay.TypeRelation(None, args, 2, None)
assert check_kind(tr)
def test_invalid_tuple_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
......@@ -38,6 +53,7 @@ def test_invalid_tuple_kinds():
tup_ty = relay.TupleType(fields)
assert not check_kind(tup_ty)
def test_invalid_func_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
......@@ -51,16 +67,29 @@ def test_invalid_func_kind():
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert not check_kind(tf)
def test_invalid_relation_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None)
assert not check_kind(tr)
def test_func_with_invalid_ret_type():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_arg_types():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_tuple():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
......@@ -69,6 +98,18 @@ def test_func_with_invalid_tuple():
tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([]))
assert not check_kind(tf)
def test_func_with_invalid_relation():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr]))
assert not check_kind(tf)
def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
......@@ -77,3 +118,17 @@ def test_tuple_with_invalid_func():
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
assert not check_kind(tup_ty)
if __name__ == "__main__":
test_tuple_kind()
test_func_kind()
test_relation_kind()
test_invalid_tuple_kind()
test_invalid_func_kind()
test_invalid_relation_kind()
test_func_with_invalid_ret_type()
test_func_with_invalid_arg_types()
test_func_with_invalid_tuple()
test_func_with_invalid_relation()
test_tuple_with_invalid_func()
......@@ -58,6 +58,21 @@ def test_tuple_type():
assert tup_ty.fields == fields
def test_type_relation():
tp = relay.TypeParam('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp])
num_inputs = 2
func = None
attrs = None
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
......@@ -158,6 +173,8 @@ if __name__ == "__main__":
test_tensor_type()
test_type_param()
test_func_type()
test_tuple_type()
test_type_relation()
test_constant()
test_tuple()
test_local_var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment