kind_check.cc 6.25 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30
/*!
 *
 * \file kind_check.cc
 *
 * \brief Check that types are well formed by applying "kinding rules".
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example, tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always
 * contains a data type such as `int`, `float`, `uint`.
 */
Zhi committed
34
#include <tvm/relay/analysis.h>
35
#include <tvm/relay/error.h>
36
#include "../ir/type_functor.h"
37 38 39 40 41 42

namespace tvm {
namespace relay {

using namespace tvm::runtime;

43 44 45
/*!
 * \brief Visitor that computes the kind of a type and reports a fatal error
 *  whenever a nested type does not have the kind its position requires.
 *
 * Each VisitType_ overload returns the kind of the visited node. Composite
 * types validate their children through CheckKindMatches before returning.
 */
struct KindChecker : TypeFunctor<Kind(const Type&)> {
  /*! \brief Module used to resolve ADT definitions and to render errors. */
  const Module& mod;
  /*! \brief Accumulates errors before they are rendered against `mod`. */
  ErrorReporter err_reporter;

  explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {}

  /*!
   * \brief Record `err` and render all collected errors against `mod`.
   *
   * NOTE(review): the name implies this does not return (presumably
   * ErrorReporter::RenderErrors aborts) -- confirm. Visitors below still
   * avoid continuing into code that would dereference a null pointer.
   */
  void ReportFatalError(const Error& err) {
    this->err_reporter.Report(err);
    this->err_reporter.RenderErrors(mod);
  }

  /*!
   * \brief Visit `t` and report an error unless its kind is `expected`.
   * \param t The nested type being checked.
   * \param outer The enclosing type; only used in the error message.
   * \param expected The kind required at this position.
   * \param description Name of the position, e.g. "tuple member".
   */
  void CheckKindMatches(const Type& t, const Type& outer,
                        Kind expected, const std::string& description) {
    Kind k = this->VisitType(t);
    if (k != expected) {
      ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description
                                   << ". Type " << t << " inside " << outer
                                   << " is of kind " << k
                                   << " but was expected to be "
                                   << expected));
    }
  }

  // Incomplete types and (global) type variables carry their kind explicitly.
  Kind VisitType_(const IncompleteTypeNode* op) override {
    return op->kind;
  }

  Kind VisitType_(const TypeVarNode* op) override {
    return op->kind;
  }

  Kind VisitType_(const GlobalTypeVarNode* op) override {
    return op->kind;
  }

  // Tensor types have no nested types to check.
  Kind VisitType_(const TensorTypeNode* op) override {
    return Kind::kType;
  }

  Kind VisitType_(const TupleTypeNode* op) override {
    // tuples should only contain normal types
    TupleType tt = GetRef<TupleType>(op);
    for (const Type& t : op->fields) {
      CheckKindMatches(t, tt, Kind::kType, "tuple member");
    }
    return Kind::kType;
  }

  Kind VisitType_(const FuncTypeNode* op) override {
    // Func types should only take normal types for arguments
    // and only return a normal type. They should also have
    // well-formed constraints
    FuncType ft = GetRef<FuncType>(op);
    for (const Type& t : op->arg_types) {
      CheckKindMatches(t, ft, Kind::kType, "function type parameter");
    }

    CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type");

    for (const TypeConstraint& tc : op->type_constraints) {
      CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint");
    }

    return Kind::kType;
  }

  Kind VisitType_(const RefTypeNode* op) override {
    // ref types should only contain normal types
    RefType rt = GetRef<RefType>(op);
    CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
    return Kind::kType;
  }

  Kind VisitType_(const TypeRelationNode* op) override {
    // arguments to type relation should be normal types
    TypeRelation rel = GetRef<TypeRelation>(op);
    for (const Type& t : op->args) {
      CheckKindMatches(t, rel, Kind::kType, "argument to type relation");
    }
    return Kind::kConstraint;
  }

  Kind VisitType_(const TypeCallNode* op) override {
    // type call func should be a global type var, args should be type
    TypeCall tc = GetRef<TypeCall>(op);
    const auto* gtv = op->func.as<GlobalTypeVarNode>();
    if (gtv == nullptr) {
      ReportFatalError(RELAY_ERROR("The callee in " << tc
                                   << " is not a global type var, but is " << op->func));
      // Do not fall through: the arity check below dereferences `gtv`.
      return Kind::kType;
    }

    CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");

    for (const Type& t : op->args) {
      CheckKindMatches(t, tc, Kind::kType, "type call argument");
    }

    // finally we need to check the module to check the number of type params
    auto var = GetRef<GlobalTypeVar>(gtv);
    auto data = mod->LookupDef(var);
    if (data->type_vars.size() != op->args.size()) {
      ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << " arguments for " << tc
                                   << "; got " << op->args.size()));
    }
    return Kind::kType;
  }

  Kind VisitType_(const TypeDataNode* op) override {
    // Constructors can reference the header var, but no other GlobalTypeVars.
    // In theory, a TypeData could be nested, so the header scope
    // should be tracked recursively, but it is unclear that we need
    // to support it.
    TypeData td = GetRef<TypeData>(op);
    CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header");

    for (const auto& var : op->type_vars) {
      CheckKindMatches(var, td, Kind::kType, "ADT type var");
    }

    for (const auto& con : op->constructors) {
      if (!con->belong_to.same_as(op->header)) {
        // Stream the TypeData reference `td`: streaming the raw `op` pointer
        // would print an address instead of the type definition.
        ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
                                     << " but " << td << " has header " << op->header));
      }

      for (const Type& t : con->inputs) {
        CheckKindMatches(t, td, Kind::kType, "ADT constructor input");
      }
    }
    return Kind::kTypeData;
  }

  /*! \brief Entry point: return the kind of `t`, checking nested types recursively. */
  Kind Check(const Type& t) {
    return this->VisitType(t);
  }
};

180 181
/*!
 * \brief Compute the kind of `t`, reporting a fatal error if any nested type
 *  appears in a position that requires a different kind.
 * \param t The type to check.
 * \param mod Module used to resolve ADT definitions referenced by `t`.
 * \return The kind of `t`.
 */
Kind KindCheck(const Type& t, const Module& mod) {
  KindChecker kc(mod);
  return kc.Check(t);
}

Zhi committed
185
// Packed-function binding for KindCheck. It is called either with one argument
// (a type) or with two (a type and the module used to resolve ADTs). With a
// single argument, the type is checked against a freshly created empty module.
TVM_REGISTER_API("relay._analysis.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = KindCheck(args[0], ModuleNode::make({}, {}));
    } else {
      *ret = KindCheck(args[0], args[1]);
    }
  });
193

194 195
}  // namespace relay
}  // namespace tvm