/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file kindchecker.cc
 *
 * \brief Check that types are well formed by applying "kinding rules".
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example, tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always
 * contains a data type such as `int`, `float`, `uint`.
 */
#include <tvm/relay/pass.h>
#include "../ir/type_functor.h"

namespace tvm {
namespace relay {

using namespace tvm::runtime;
using Kind = TypeVarNode::Kind;

/*!
 * \brief Visitor that walks a type and records whether every position that
 * must hold a "normal" type (kind Type) actually does.
 *
 * The result accumulates in `valid`. Once it becomes false, the visit
 * methods below stop descending into further children.
 */
struct KindChecker : TypeVisitor {
  /*! \brief Set to false as soon as a kinding violation is found. */
  bool valid;

  KindChecker() : valid(true) {}

  // checks if t is an incomplete node of kind k or a type param of kind k
  bool MatchKind(const Type& t, Kind k) {
    if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
      return tv->kind == k;
    }

    if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
      return tp->kind == k;
    }

    return false;
  }

  /*!
   * \brief Whether t is a "normal" type, i.e. of kind Type.
   *
   * This holds for tensor, tuple, function and reference types, and for
   * type variables / incomplete types whose declared kind is Type.
   */
  bool IsTypeKind(const Type& t) {
    if (MatchKind(t, Kind::kType)) {
      return true;
    }

    // Reference types are ordinary types (see the RefTypeNode visitor
    // below), so they must be accepted wherever a normal type is required:
    // tuple fields, function arguments/returns, type-relation arguments,
    // and the payload of another reference.
    return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() ||
           t.as<FuncTypeNode>() || t.as<RefTypeNode>();
  }

  void VisitType_(const TupleTypeNode* op) override {
    // tuples should only contain normal types
    for (const Type& t : op->fields) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }
  }

  void VisitType_(const FuncTypeNode* op) override {
    // Func types should only take normal types for arguments
    // and only return a normal type. They should also have
    // well-formed constraints
    for (const Type& t : op->arg_types) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }

    for (const TypeConstraint& tc : op->type_constraints) {
      this->VisitType(tc);
      if (!valid) {
        return;
      }
    }

    this->VisitType(op->ret_type);
    valid = valid && IsTypeKind(op->ret_type);
  }

  void VisitType_(const RefTypeNode* op) override {
    // references should only contain normal types
    this->VisitType(op->value);
    valid = valid && IsTypeKind(op->value);
  }

  void VisitType_(const TypeRelationNode* op) override {
    // arguments to type relation should be normal types
    for (const Type& t : op->args) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }
  }

  /*!
   * \brief Run the check on t.
   * \return true if no kinding violation was found.
   */
  bool Check(const Type& t) {
    this->VisitType(t);
    return valid;
  }
};

/*!
 * \brief Check that type t is well-kinded.
 * \param t The type to check.
 * \param mod The module context. It is currently unused by the check and is
 *        kept for API compatibility.
 * \return true if t is well-kinded.
 */
bool KindCheck(const Type& t, const Module& mod) {
  KindChecker kc;
  return kc.Check(t);
}

// Python entry point. The module argument is optional; an empty module is
// supplied when it is omitted.
TVM_REGISTER_API("relay._ir_pass.check_kind")
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = KindCheck(args[0], ModuleNode::make({}));
    } else {
      *ret = KindCheck(args[0], args[1]);
    }
  });

}  // namespace relay
}  // namespace tvm