Commit be77cf19 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Restore kind checking (#1758)

parent d59320c7
......@@ -15,6 +15,7 @@ Span = base.Span
# Type
Type = ty.Type
TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
......
......@@ -12,3 +12,5 @@ from . import _ir_pass
check_expr = _ir_pass.check_expr
well_formed = _ir_pass.well_formed
check_kind = _ir_pass.check_kind
......@@ -21,7 +21,6 @@ class Type(NodeBase):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
......@@ -95,6 +94,27 @@ class TypeConstraint(Type):
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields: list of tvm.Type
Returns
-------
tuple_type: the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
......
......@@ -95,6 +95,27 @@ class TypeConstraint(Type):
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields: list of tvm.Type
Returns
-------
tuple_type: the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
......
......@@ -20,12 +20,72 @@ namespace tvm {
namespace relay {
using namespace tvm::runtime;
using Kind = TypeParamNode::Kind;
struct KindChecker : TypeVisitor<> {
bool valid;
KindChecker() : valid(true) {}
// checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k;
}
if (const TypeParamNode *tp = t.as<TypeParamNode>()) {
return tp->kind == k;
}
return false;
}
bool IsTypeKind(const Type& t) {
if (MatchKind(t, Kind::kType)) {
return true;
}
return t.as<TensorTypeNode>() || t.as<BaseTensorTypeNode>()
|| t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
}
void VisitType_(const TupleTypeNode* op) override {
// tuples should only contain normal types
for (const Type& t : op->fields) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
}
}
void VisitType_(const FuncTypeNode* op) override {
// func types should only take normal types for arguments
// and only return a normal type
for (const Type& t : op->arg_types) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
}
this->VisitType(op->ret_type);
valid = valid && IsTypeKind(op->ret_type);
}
void VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types
for (const Type& t : op->args) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
}
}
bool Check(const Type &t) {
this->VisitType(t);
return valid;
......@@ -37,5 +97,14 @@ bool KindCheck(const Environment& env, const Type &t) {
return kc.Check(t);
}
TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(EnvironmentNode::make({}), args[0]);
} else {
*ret = KindCheck(args[0], args[1]);
}
});
} // namespace relay
} // namespace tvm
import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind
def test_tuple_kinds():
# only contain type kinds
tp = relay.TypeParam('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
fields = tvm.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields)
assert check_kind(tup_ty)
def test_func_kind():
# only contain type kinds
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Type)
shape = tvm.convert([1, 2, 3])
dtype = 'float32'
tensor_type = relay.TensorType(shape, dtype)
type_params = tvm.convert([tp1, tp2])
type_constraints = tvm.convert([])
arg_types = tvm.convert([tp1, tensor_type])
ret_type = relay.TupleType(tvm.convert([tp2, tensor_type]))
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert check_kind(tf)
def test_invalid_tuple_kinds():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
fields = tvm.convert([tp1, tp2, tp3])
tup_ty = relay.TupleType(fields)
assert not check_kind(tup_ty)
def test_invalid_func_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
type_params = tvm.convert([tp1, tp2, tp3])
type_constraints = tvm.convert([])
arg_types = tvm.convert([tp1, tp2])
ret_type = tp3
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert not check_kind(tf)
def test_func_with_invalid_ret_type():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_arg_types():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_tuple():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([]))
assert not check_kind(tf)
def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
assert not check_kind(tup_ty)
......@@ -28,8 +28,8 @@ def test_tensor_type():
def test_type_param():
tp = relay.TypeParam('name', relay.Kind.Shape)
tp.kind == relay.Kind.Shape
tp.span # TODO allow us to set span
assert tp.kind == relay.Kind.Shape
# assert tp.span # TODO allow us to set span
str(tp)
......@@ -48,6 +48,16 @@ def test_func_type():
str(tf)
def test_tuple_type():
tp = relay.TypeParam('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
fields = tvm.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment