import tvm from tvm import relay from tvm.relay.ir_pass import check_kind def test_tuple_kind(): # only contain type kinds tp = relay.TypeVar('tp', relay.Kind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) fields = tvm.convert([tp, tf, tt]) tup_ty = relay.TupleType(fields) assert check_kind(tup_ty) def test_func_kind(): # only contain type kinds tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Type) shape = tvm.convert([1, 2, 3]) dtype = 'float32' tensor_type = relay.TensorType(shape, dtype) tr = relay.TypeRelation(None, tvm.convert([tensor_type, tp1]) , 1, None) type_params = tvm.convert([tp1, tp2]) type_constraints = tvm.convert([tr]) arg_types = tvm.convert([tp1, tensor_type]) ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) assert check_kind(tf) def test_relation_kind(): # only have type kinds for arguments tp = relay.TypeVar('tp', relay.Kind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) args = tvm.convert([tf, tt, tp]) tr = relay.TypeRelation(None, args, 2, None) assert check_kind(tr) def test_invalid_tuple_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) fields = tvm.convert([tp1, tp2, tp3]) tup_ty = relay.TupleType(fields) assert not check_kind(tup_ty) def test_invalid_func_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) type_params = tvm.convert([tp1, tp2, tp3]) type_constraints = tvm.convert([]) arg_types = tvm.convert([tp1, tp2]) ret_type = tp3 tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) assert not check_kind(tf) def test_invalid_relation_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) args = tvm.convert([tp1, tp2, tp3]) tr = relay.TypeRelation(None, args, 2, None) assert not check_kind(tr) def test_func_with_invalid_ret_type(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) def test_func_with_invalid_arg_types(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Type) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) def test_func_with_invalid_tuple(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) assert not check_kind(tf) def test_func_with_invalid_relation(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) assert not check_kind(tf) def test_tuple_with_invalid_func(): tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) assert not check_kind(tup_ty) if __name__ == "__main__": test_tuple_kind() test_func_kind() test_relation_kind() test_invalid_tuple_kind() test_invalid_func_kind() test_invalid_relation_kind() test_func_with_invalid_ret_type() test_func_with_invalid_arg_types() test_func_with_invalid_tuple() test_func_with_invalid_relation() test_tuple_with_invalid_func()