/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

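/*!
 *
 * \file util.cc
 *
 * \brief Utility functions for Relay: collection of free/bound/all variables
 *        and type variables, expression reference counting, constant checks,
 *        and type substitution.
 */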
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>

#include "../transforms/pass_util.h"

namespace tvm {
namespace relay {

36 37
template<typename T>
struct InsertionSet {
38
  std::unordered_set<T, ObjectHash, ObjectEqual> set;
39 40 41 42 43 44 45 46 47 48
  std::vector<T> data;
  void Insert(const T& t) {
    if (set.count(t) == 0) {
      set.insert(t);
      data.push_back(t);
    }
  }
};

class TypeVarTVisitor : public TypeVisitor {
49
 public:
50 51 52
  TypeVarTVisitor(
      InsertionSet<TypeVar>* type_vars,
      InsertionSet<TypeVar>* bound_type_vars)
53
    : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
54

55
  void VisitType_(const TypeVarNode* tp) final {
56
    TypeVar var = GetRef<TypeVar>(tp);
57
    type_vars_->Insert(var);
58 59 60 61
  }

  void VisitType_(const FuncTypeNode* f) final {
    for (auto type_param : f->type_params) {
62 63
      type_vars_->Insert(type_param);
      bound_type_vars_->Insert(type_param);
64
    }
65 66
    TypeVisitor::VisitType_(f);
  }
67

68
 private:
69 70
  InsertionSet<TypeVar>* type_vars_;
  InsertionSet<TypeVar>* bound_type_vars_;
71
};
72

73
class TypeVarEVisitor : private ExprVisitor {
74
 public:
75
  explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
  Array<TypeVar> CollectFree() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      if (bound_type_vars_.set.count(v) == 0) {
        ret.push_back(v);
      }
    }
    return ret;
  }

  Array<TypeVar> CollectBound() {
    Array<TypeVar> ret;
    for (const auto& v : bound_type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  Array<TypeVar> CollectAll() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
101 102
  }

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
  Array<TypeVar> Free(const Expr& expr) {
    VisitExpr(expr);
    return CollectFree();
  }

  Array<TypeVar> Free(const Type& type) {
    VisitType(type);
    return CollectFree();
  }

  Array<TypeVar> Bound(const Expr& expr) {
    VisitExpr(expr);
    return CollectBound();
  }

  Array<TypeVar> Bound(const Type& type) {
    VisitType(type);
    return CollectBound();
  }

  Array<TypeVar> All(const Expr& expr) {
    VisitExpr(expr);
    return CollectAll();
  }

  Array<TypeVar> All(const Type& type) {
    VisitType(type);
    return CollectAll();
131 132
  }

133
  void VisitExpr_(const FunctionNode* f) final {
134
    for (const auto& tp : f->type_params) {
135 136
      type_vars_.Insert(tp);
      bound_type_vars_.Insert(tp);
137
    }
138
    ExprVisitor::VisitExpr_(f);
139 140
  }

141 142
  void VisitExpr_(const ConstructorNode* cn) final {
    // for constructors, type vars will be bound in the module
143
    auto data = mod_->LookupTypeDef(cn->belong_to);
144 145 146 147 148 149 150
    for (const auto& tv : data->type_vars) {
      type_vars_.Insert(tv);
      bound_type_vars_.Insert(tv);
    }
    ExprVisitor::VisitExpr_(cn);
  }

151
  void VisitType(const Type& t) final {
152
    TypeVarTVisitor(&type_vars_, &bound_type_vars_)
153
        .VisitType(t);
154 155
  }

156
 private:
157 158
  InsertionSet<TypeVar> type_vars_;
  InsertionSet<TypeVar> bound_type_vars_;
159
  const IRModule& mod_;
160 161
};

162
class VarVisitor : protected ExprVisitor, protected PatternVisitor {
163
 public:
164
  Array<Var> Free(const Expr& expr) {
165
    this->VisitExpr(expr);
166 167 168 169 170 171 172
    Array<Var> ret;
    for (const auto& v : vars_.data) {
      if (bound_vars_.set.count(v) == 0) {
        ret.push_back(v);
      }
    }
    return ret;
173 174
  }

雾雨魔理沙 committed
175
  Array<Var> Collect() {
176 177 178
    Array<Var> ret;
    for (const auto& v : bound_vars_.data) {
      ret.push_back(v);
179
    }
180 181 182
    return ret;
  }

雾雨魔理沙 committed
183 184 185 186 187 188 189 190 191 192
  Array<Var> Bound(const Expr& expr) {
    this->VisitExpr(expr);
    return Collect();
  }

  Array<Var> Bound(const Pattern& pat) {
    this->VisitPattern(pat);
    return Collect();
  }

193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
  Array<Var> All(const Expr& expr) {
    this->VisitExpr(expr);
    Array<Var> ret;
    for (const auto& v : vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  void MarkBounded(const Var& v) {
    bound_vars_.Insert(v);
    vars_.Insert(v);
  }

  void VisitExpr_(const VarNode* var) final {
    vars_.Insert(GetRef<Var>(var));
209
  }
210 211 212

  void VisitExpr_(const FunctionNode* op) final {
    for (const auto& param : op->params) {
213
      MarkBounded(param);
214 215 216 217 218
    }
    VisitExpr(op->body);
  }

  void VisitExpr_(const LetNode* op) final {
219
    MarkBounded(op->var);
220 221 222 223
    VisitExpr(op->value);
    VisitExpr(op->body);
  }

224 225 226 227 228 229 230 231
  void VisitPattern(const Pattern& p) final {
    PatternVisitor::VisitPattern(p);
  }

  void VisitPattern_(const PatternVarNode* op) final {
    MarkBounded(op->var);
  }

232
 private:
233 234
  InsertionSet<Var> vars_;
  InsertionSet<Var> bound_vars_;
235 236
};

237
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
238
  return TypeVarEVisitor(mod).Free(expr);
239 240
}

241
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
242
  return TypeVarEVisitor(mod).Free(type);
243 244
}

245
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
246
  return TypeVarEVisitor(mod).Bound(expr);
247 248
}

249
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
250
  return TypeVarEVisitor(mod).Bound(type);
251 252
}

253
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
254
  return TypeVarEVisitor(mod).All(expr);
255 256
}

257
tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
258
  return TypeVarEVisitor(mod).All(type);
259 260
}

/*! \brief Vars referenced in \p expr that are not bound within it. */
tvm::Array<Var> FreeVars(const Expr& expr) {
  VarVisitor collector;
  return collector.Free(expr);
}

/*! \brief Vars bound (function params, lets, pattern vars) inside \p expr. */
tvm::Array<Var> BoundVars(const Expr& expr) {
  VarVisitor collector;
  return collector.Bound(expr);
}

/*! \brief Vars bound by the pattern \p pat. */
tvm::Array<Var> BoundVars(const Pattern& pat) {
  VarVisitor collector;
  return collector.Bound(pat);
}

/*! \brief Every var occurring in \p expr, free or bound. */
tvm::Array<Var> AllVars(const Expr& expr) {
  VarVisitor collector;
  return collector.All(expr);
}

277
TVM_REGISTER_GLOBAL("relay.analysis.free_vars")
278
.set_body_typed(FreeVars);
279

280
TVM_REGISTER_GLOBAL("relay.analysis.bound_vars")
281
  .set_body([](TVMArgs args, TVMRetValue* ret) {
282
      ObjectRef x = args[0];
283
      if (x.as<ExprNode>()) {
雾雨魔理沙 committed
284 285 286 287
        *ret = BoundVars(Downcast<Expr>(x));
      } else {
        *ret = BoundVars(Downcast<Pattern>(x));
      }
288 289
    });

290
TVM_REGISTER_GLOBAL("relay.analysis.all_vars")
291
.set_body_typed(AllVars);
292

293
TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars")
294
.set_body([](TVMArgs args, TVMRetValue* ret) {
295
    ObjectRef x = args[0];
296
    IRModule mod = args[1];
297
    if (x.as<TypeNode>()) {
298
      *ret = FreeTypeVars(Downcast<Type>(x), mod);
299
    } else {
300
      *ret = FreeTypeVars(Downcast<Expr>(x), mod);
301 302 303
    }
  });

304
TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars")
305
  .set_body([](TVMArgs args, TVMRetValue* ret) {
306
      ObjectRef x = args[0];
307
      IRModule mod = args[1];
308
      if (x.as<TypeNode>()) {
309
        *ret = BoundTypeVars(Downcast<Type>(x), mod);
310
      } else {
311
        *ret = BoundTypeVars(Downcast<Expr>(x), mod);
312 313 314
      }
    });

315
TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
316
  .set_body([](TVMArgs args, TVMRetValue* ret) {
317
      ObjectRef x = args[0];
318
      IRModule mod = args[1];
319
      if (x.as<TypeNode>()) {
320
        *ret = AllTypeVars(Downcast<Type>(x), mod);
321
      } else {
322
        *ret = AllTypeVars(Downcast<Expr>(x), mod);
323 324 325
      }
    });

326 327 328 329 330
/*!
 * \brief Get reference counter of each internal ExprNode in body.
 * \param body The body expression.
 * \return The reference count mapping.
 */
331
std::unordered_map<const Object*, size_t>
332
GetExprRefCount(const Expr& body) {
333
  class ExprRefCounter : private MixedModeVisitor {
334
   public:
335
    std::unordered_map<const Object*, size_t>
336 337 338 339 340 341 342 343
    Get(const Expr& body) {
      this->VisitExpr(body);
      return std::move(this->visit_counter_);
    }
  };
  return ExprRefCounter().Get(body);
}

344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
  CHECK(tensor->strides == nullptr);
  CHECK_EQ(tensor->byte_offset, 0);
  const T* data = static_cast<const T*>(tensor->data);
  int64_t num_elems = 1;
  for (int i = 0; i < tensor->ndim; ++i) {
    num_elems *= tensor->shape[i];
  }

  for (int64_t i = 0; i < num_elems; i++) {
    if (*data < value) {
      return false;
    }
    data++;
  }
  return true;
}

364 365 366 367 368 369
// Cache the operators that are checked recursively to reduce lookup overhead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");

370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393
bool IsAllPositiveConstant(const Expr& expr) {
  // peel through a few common transform ops.
  if (const auto* constant = expr.as<ConstantNode>()) {
    const auto& tensor = constant->data;
    const auto& dtype = tensor->dtype;
    if (dtype.lanes != 1) {
      return false;
    } else if (dtype.code == kDLFloat && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<float>(tensor, 0);
    } else if (dtype.code == kDLFloat && dtype.bits == 64) {
      return IsNDArrayAllGreaterEqual<double>(tensor, 0);
    } else if (dtype.code == kDLInt && dtype.bits == 8) {
      return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
    } else if (dtype.code == kDLInt && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
    } else if (dtype.code == kDLUInt && dtype.bits == 8) {
      return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
    } else if (dtype.code == kDLUInt && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
    } else {
      return false;
    }
  } else if (const auto* op = expr.as<CallNode>()) {
    // tail recursion.
394 395 396 397
    if (op->op == expand_dims_op ||
        op->op == reshape_op ||
        op->op == transpose_op ||
        op->op == squeeze_op) {
398 399 400 401 402 403 404 405 406
      return IsAllPositiveConstant(op->args[0]);
    } else {
      return false;
    }
  } else {
    return false;
  }
}

雾雨魔理沙 committed
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
  return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}

Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
  // Single-variable convenience overload: build a one-entry map and delegate.
  tvm::Map<TypeVar, Type> subst_map({{tvar, subst}});
  return TypeSubst(expr, subst_map);
}

Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
  // Substituting inside a type is delegated entirely to Bind.
  Type substituted = Bind(type, subst_map);
  return substituted;
}

Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
  // Rebuilds the expression, applying the substitution to every type it
  // visits, including those reached through match patterns.
  class TypeSubstituter : public ExprMutator, public PatternMutator {
   public:
    explicit TypeSubstituter(const tvm::Map<TypeVar, Type>& subst_map)
        : subst_map_(subst_map) {}

    Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); }

    // Pattern vars are rewritten through the ExprMutator path as well.
    Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }

    Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

    Clause VisitClause(const Clause& c) final {
      // The pattern is rewritten before the right-hand side.
      Pattern lhs = VisitPattern(c->lhs);
      Expr rhs = VisitExpr(c->rhs);
      return Clause(lhs, rhs);
    }

   private:
    const tvm::Map<TypeVar, Type>& subst_map_;
  };

  CHECK(WellFormed(expr));
  Expr substituted = TypeSubstituter(subst_map).VisitExpr(expr);
  // Substituting types must neither introduce nor drop free variables.
  CHECK_EQ(FreeVars(expr).size(), FreeVars(substituted).size());
  CHECK(WellFormed(substituted));
  return substituted;
}

}  // namespace relay
}  // namespace tvm