Commit 3455c8a5 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Incorporate TypeRelations into more tests (#1792)

parent 106991d2
...@@ -21,6 +21,7 @@ Kind = ty.Kind ...@@ -21,6 +21,7 @@ Kind = ty.Kind
TypeParam = ty.TypeParam TypeParam = ty.TypeParam
TypeConstraint = ty.TypeConstraint TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
# Expr # Expr
Constant = expr.Constant Constant = expr.Constant
......
...@@ -158,3 +158,26 @@ class IncompleteType(Type): ...@@ -158,3 +158,26 @@ class IncompleteType(Type):
def __init__(self, kind): def __init__(self, kind):
self.__init_handle_by_constructor__(_make.IncompleteType, kind) self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
...@@ -45,8 +45,7 @@ struct KindChecker : TypeVisitor<> { ...@@ -45,8 +45,7 @@ struct KindChecker : TypeVisitor<> {
return true; return true;
} }
return t.as<TensorTypeNode>() || t.as<BaseTensorTypeNode>() return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
|| t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
} }
void VisitType_(const TupleTypeNode* op) override { void VisitType_(const TupleTypeNode* op) override {
...@@ -61,8 +60,9 @@ struct KindChecker : TypeVisitor<> { ...@@ -61,8 +60,9 @@ struct KindChecker : TypeVisitor<> {
} }
void VisitType_(const FuncTypeNode* op) override { void VisitType_(const FuncTypeNode* op) override {
// func types should only take normal types for arguments // Func types should only take normal types for arguments
// and only return a normal type // and only return a normal type. They should also have
// well-formed constraints
for (const Type& t : op->arg_types) { for (const Type& t : op->arg_types) {
this->VisitType(t); this->VisitType(t);
valid = valid && IsTypeKind(t); valid = valid && IsTypeKind(t);
...@@ -71,6 +71,13 @@ struct KindChecker : TypeVisitor<> { ...@@ -71,6 +71,13 @@ struct KindChecker : TypeVisitor<> {
} }
} }
for (const TypeConstraint& tc : op->type_constraints) {
this->VisitType(tc);
if (!valid) {
return;
}
}
this->VisitType(op->ret_type); this->VisitType(op->ret_type);
valid = valid && IsTypeKind(op->ret_type); valid = valid && IsTypeKind(op->ret_type);
} }
......
...@@ -2,7 +2,7 @@ import tvm ...@@ -2,7 +2,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import check_kind from tvm.relay.ir_pass import check_kind
def test_tuple_kinds(): def test_tuple_kind():
# only contain type kinds # only contain type kinds
tp = relay.TypeParam('tp', relay.Kind.Type) tp = relay.TypeParam('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
...@@ -12,6 +12,7 @@ def test_tuple_kinds(): ...@@ -12,6 +12,7 @@ def test_tuple_kinds():
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
assert check_kind(tup_ty) assert check_kind(tup_ty)
def test_func_kind(): def test_func_kind():
# only contain type kinds # only contain type kinds
tp1 = relay.TypeParam('tp1', relay.Kind.Type) tp1 = relay.TypeParam('tp1', relay.Kind.Type)
...@@ -21,15 +22,29 @@ def test_func_kind(): ...@@ -21,15 +22,29 @@ def test_func_kind():
dtype = 'float32' dtype = 'float32'
tensor_type = relay.TensorType(shape, dtype) tensor_type = relay.TensorType(shape, dtype)
tr = relay.TypeRelation(None, tvm.convert([tensor_type, tp1]) , 1, None)
type_params = tvm.convert([tp1, tp2]) type_params = tvm.convert([tp1, tp2])
type_constraints = tvm.convert([]) type_constraints = tvm.convert([tr])
arg_types = tvm.convert([tp1, tensor_type]) arg_types = tvm.convert([tp1, tensor_type])
ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) ret_type = relay.TupleType(tvm.convert([tp2, tensor_type]))
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert check_kind(tf) assert check_kind(tf)
def test_invalid_tuple_kinds():
def test_relation_kind():
# only have type kinds for arguments
tp = relay.TypeParam('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
args = tvm.convert([tf, tt, tp])
tr = relay.TypeRelation(None, args, 2, None)
assert check_kind(tr)
def test_invalid_tuple_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
...@@ -38,6 +53,7 @@ def test_invalid_tuple_kinds(): ...@@ -38,6 +53,7 @@ def test_invalid_tuple_kinds():
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
assert not check_kind(tup_ty) assert not check_kind(tup_ty)
def test_invalid_func_kind(): def test_invalid_func_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
...@@ -51,16 +67,29 @@ def test_invalid_func_kind(): ...@@ -51,16 +67,29 @@ def test_invalid_func_kind():
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert not check_kind(tf) assert not check_kind(tf)
def test_invalid_relation_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None)
assert not check_kind(tr)
def test_func_with_invalid_ret_type(): def test_func_with_invalid_ret_type():
tp1 = relay.TypeParam('tp1', relay.Kind.Type) tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape) tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_arg_types(): def test_func_with_invalid_arg_types():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.Type) tp2 = relay.TypeParam('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_tuple(): def test_func_with_invalid_tuple():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
...@@ -69,6 +98,18 @@ def test_func_with_invalid_tuple(): ...@@ -69,6 +98,18 @@ def test_func_with_invalid_tuple():
tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([]))
assert not check_kind(tf) assert not check_kind(tf)
def test_func_with_invalid_relation():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr]))
assert not check_kind(tf)
def test_tuple_with_invalid_func(): def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
...@@ -77,3 +118,17 @@ def test_tuple_with_invalid_func(): ...@@ -77,3 +118,17 @@ def test_tuple_with_invalid_func():
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
assert not check_kind(tup_ty) assert not check_kind(tup_ty)
if __name__ == "__main__":
test_tuple_kind()
test_func_kind()
test_relation_kind()
test_invalid_tuple_kind()
test_invalid_func_kind()
test_invalid_relation_kind()
test_func_with_invalid_ret_type()
test_func_with_invalid_arg_types()
test_func_with_invalid_tuple()
test_func_with_invalid_relation()
test_tuple_with_invalid_func()
...@@ -58,6 +58,21 @@ def test_tuple_type(): ...@@ -58,6 +58,21 @@ def test_tuple_type():
assert tup_ty.fields == fields assert tup_ty.fields == fields
def test_type_relation():
tp = relay.TypeParam('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp])
num_inputs = 2
func = None
attrs = None
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
def test_constant(): def test_constant():
arr = tvm.nd.array(10) arr = tvm.nd.array(10)
const = relay.Constant(arr) const = relay.Constant(arr)
...@@ -158,6 +173,8 @@ if __name__ == "__main__": ...@@ -158,6 +173,8 @@ if __name__ == "__main__":
test_tensor_type() test_tensor_type()
test_type_param() test_type_param()
test_func_type() test_func_type()
test_tuple_type()
test_type_relation()
test_constant() test_constant()
test_tuple() test_tuple()
test_local_var() test_local_var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment