kind_check.cc 2.75 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file kind_check.cc
 *
 * \brief Check that types are well formed by applying "kinding rules".
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example, tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always
 * contains a data type such as `int`, `float`, `uint`.
 */
#include <tvm/relay/pass.h>
17
#include "../ir/type_functor.h"
18 19 20 21 22

namespace tvm {
namespace relay {

using namespace tvm::runtime;
23
using Kind = TypeVarNode::Kind;
24

25
struct KindChecker : TypeVisitor {
26 27 28 29
  bool valid;

  KindChecker() : valid(true) {}

30 31
  // checks if t is an incomplete node of kind k or a type param of kind k
  bool MatchKind(const Type& t, Kind k) {
32
    if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
33 34 35
      return tv->kind == k;
    }

36
    if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
37 38 39 40 41 42 43 44 45 46 47
      return tp->kind == k;
    }

    return false;
  }

  bool IsTypeKind(const Type& t) {
    if (MatchKind(t, Kind::kType)) {
      return true;
    }

48
    return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
49 50 51 52 53 54 55 56 57 58 59 60 61 62
  }

  void VisitType_(const TupleTypeNode* op) override {
    // tuples should only contain normal types
    for (const Type& t : op->fields) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }
  }

  void VisitType_(const FuncTypeNode* op) override {
63 64 65
    // Func types should only take normal types for arguments
    // and only return a normal type. They should also have
    // well-formed constraints
66 67 68 69 70 71 72 73
    for (const Type& t : op->arg_types) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }

74 75 76 77 78 79 80
    for (const TypeConstraint& tc : op->type_constraints) {
      this->VisitType(tc);
      if (!valid) {
        return;
      }
    }

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
    this->VisitType(op->ret_type);
    valid = valid && IsTypeKind(op->ret_type);
  }

  void VisitType_(const TypeRelationNode* op) override {
    // arguments to type relation should be normal types
    for (const Type& t : op->args) {
      this->VisitType(t);
      valid = valid && IsTypeKind(t);
      if (!valid) {
        return;
      }
    }
  }

96
  bool Check(const Type& t) {
97 98 99 100 101
    this->VisitType(t);
    return valid;
  }
};

102
bool KindCheck(const Type& t, const Module& mod) {
103 104 105 106
  KindChecker kc;
  return kc.Check(t);
}

107
// Registers KindCheck under the name "relay._ir_pass.check_kind".
// When called with exactly one argument, the module is supplied by
// ModuleNode::make({}); with any other argument count, args[1] is
// forwarded as the module.
TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() != 1) {
      *ret = KindCheck(args[0], args[1]);
      return;
    }
    *ret = KindCheck(args[0], ModuleNode::make({}));
  });
115

116 117
}  // namespace relay
}  // namespace tvm