pass.h 17.2 KB
Newer Older
1 2 3 4
/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/pass.h
 * \brief The set of Relay passes written in C++.
Zhi committed
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
 *
 * This file also implements a pass manager. The pass manager manages a sequence
 * of Relay-to-Relay transformation passes over a particular unit of AST. The
 * design is largely inspired by LLVM's pass manager and modern deep learning
 * frameworks that perform tensor->tensor transformations.
 *
 * The responsibilities of a traditional compiler pass manager usually include:
 *  - Organizing the execution order of optimization passes, though not
 *    necessarily in the optimal sequence.
 *  - Collecting required analysis information and keeping it up-to-date.
 *  - Reducing the effort required for compiler developers to implement new
 *    passes, etc.
 *
 * Similar to LLVM's pass manager, we designed the Relay pass manager to work
 * at different granularities, i.e. module level, function level, and even
 * sequential passes that contain a host of passes.
 *
 * However, we also extend the functionality of the traditional pass manager
 * with the consideration of requirements/convention from deep learning
 * frameworks, such as PyTorch and Gluon. Each pass in the Relay pass
 * manager performs the Relay.Module -> Relay.Module transformation. All
 * different types of passes, including the sequential-level pass object, are
 * essentially pass objects. This design, therefore, effectively provides users
 * a consistent and convenient interface, i.e. Pass, to play with. It offers a
 * means to ease the development and testing of Relay passes. For example, with
 * the pass manager, external users will be able to have custom passes correctly
 * scheduled without having to modify a single handcrafted pass order.
 *
 * In the future we need to describe constraints between passes. For example,
 * we may want to preserve dependencies between different passes and validate
 * them on the completion of a certain pass.
 *
 * We also need to store side information and import the error reporting system.
38 39 40 41
 */
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

Zhi committed
42 43 44
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
45
#include <tvm/relay/expr.h>
46
#include <tvm/relay/module.h>
47
#include <tvm/relay/op_attr_types.h>
Zhi committed
48 49
#include <tvm/relay/type.h>

50
#include <string>
Zhi committed
51
#include <vector>
52 53 54 55

namespace tvm {
namespace relay {

Zhi committed
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
namespace pass {

/*!
 * \brief Managed reference to a PassContextNode.
 *
 * Forward declared here so that PassContextNode::make() can name it as its
 * return type before TVM_DEFINE_NODE_REF below defines it.
 */
class PassContext;

/*!
 * \brief Information that a pass can rely on while it runs, such as
 * analysis results.
 *
 * At present the only member is an error reporter.
 */
class PassContextNode : public RelayNode {
 public:
  /*! \brief Reporter used to notify users why an optimization fails. */
  ErrorReporter err_reporter;

  PassContextNode() = default;

  // No attributes are exposed through reflection; err_reporter is not visited.
  void VisitAttrs(tvm::AttrVisitor* /*v*/) final {}

  /*! \brief Create a PassContext. */
  TVM_DLL static PassContext make();

  static constexpr const char* _type_key = "relay.PassContext";
  TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

/*! \brief Managed reference type for PassContextNode. */
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)

/*!
 * \brief The meta data of a pass.
 *
 * PassInfo can be extended conveniently in the future if more meta information
 * is needed.
 */
class PassInfo;

/*!
 * \brief PassInfoNode contains meta data that will be used to help optimization
 * and analysis.
 */
class PassInfoNode : public RelayNode {
 public:
  /*!
   * \brief The minimal optimization level at which this pass will be enabled.
   *
   * Zero-initialized so that a default-constructed node never holds an
   * indeterminate value (VisitAttrs below reads it).
   */
  int opt_level{0};

  /*! \brief The name of an optimization/analysis pass. */
  std::string name;

  /*! \brief The passes that are required to perform the current pass. */
  tvm::Array<tvm::Expr> required;

  PassInfoNode() = default;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("opt_level", &opt_level);
    v->Visit("name", &name);
    v->Visit("required", &required);
  }

  /*!
   * \brief Create a PassInfo.
   *
   * \param opt_level The minimal optimization level that enables the pass.
   * \param name The name of the pass.
   * \param required The passes that the pass depends on.
   *
   * \return The created PassInfo.
   */
  TVM_DLL static PassInfo make(int opt_level, std::string name,
                               tvm::Array<tvm::Expr> required);

  static constexpr const char* _type_key = "relay.PassInfo";
  TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};

/*! \brief Managed reference type for PassInfoNode. */
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)

class Pass;

/*!
 * \brief PassNode is the base type of differnt types of optimization passes.
 * It is designed as a pure class and implemented by different pass subclasses
 * at different granularity of Relay nodes.
 */
class PassNode : public RelayNode {
 public:
  /*
   * \brief Get the pass information/meta data. */
  virtual PassInfo Info() const = 0;

  /*!
   * \brief Set the context information for a pass.
   *
   * \param pass_ctx The context information for a certain pass.
   */
  virtual void SetContext(const PassContext& pass_ctx) = 0;

  /*!
   * \brief Execute the optimization pass using a functor.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return The updated module.
   */
  virtual Module operator()(const Module& mod) const = 0;

  void VisitAttrs(tvm::AttrVisitor* v) override {}

  static constexpr const char* _type_key = "relay.Pass";
  TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};

class Pass : public NodeRef {
 public:
  Pass() = default;
  explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}

  PassNode* operator->() const {
    return static_cast<PassNode*>(this->node_.get());
  }

  using ContainerType = PassNode;
};

/*
 * \brief Create a module pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the module pass.
 * \param name The name of the module pass.
 * \param required The list of the passes that the module pass is dependent on.
 *
 * \return The created module pass.
 */
Pass CreateModulePass(
    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);

/*
 * \brief Create a function pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the function pass.
 * \param name The name of the function pass.
 * \param required The list of the passes that the function pass is dependent on.
 *
 * \return The created function pass.
 */
Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);
/*
 * \brief Create a sequential pass.
 *
 * \param passes The optimization passes will be performed.
 * \param opt_level The optimization level of the sequential pass.
 * \param name The name of the sequential pass.
 * \param required The list of the passes that the sequential pass is dependent on.
 * \param disabled The disabled passes.
 *
 * \return The created sequential pass.
 */
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
                          int opt_level,
                          const std::string& name,
                          const tvm::Array<tvm::Expr>& required,
                          const tvm::Array<tvm::Expr>& disabled);

}  // namespace pass

224 225
/*!
 * \brief Infer the type of an expression.
226 227 228 229 230
 *
 * The result of type checking is a new expression with unambigous
 * type information filled in, as well as it's checked type field
 * populated with the result type.
 *
231
 * \param expr The expression to type check.
232
 * \param mod The module used for referencing global functions, can be
233
 * None.
234 235 236
 *
 * \return A type checked expression with its checked_type field populated.
 */
237
TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
238

239
/*!
240
 * \brief Infer the type of a function as if it is mapped to var in the mod.
241 242
 *
 * \param f the function.
243
 * \param mod The module used for referencing global functions.
244 245 246
 * \param var The global variable corresponding to the function.
 *
 * \return A type checked Function with its checked_type field populated.
247
 * \note this function mutates mod and is not thread-safe.
248
 */
249 250
TVM_DLL Function InferType(const Function& f, const Module& mod,
                           const GlobalVar& var);
251 252

/*!
253
 * \brief Check that types are well kinded by applying "kinding rules".
254 255 256 257 258 259 260 261 262 263
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always contains
 * a data type such as `int`, `float`, `uint`.
 *
 * \param t The type to check.
264
 * \param mod The global module.
265
 *
266
 * \return The kind of the passed type.
267
 */
268
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284

/*! \brief Compare two expressions for structural equivalence.
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
 *
 *   See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 *   for more details.
 *
 *   \param e1 The left hand expression.
 *   \param e2 The right hand expression.
 *
 *   \return true if equal, otherwise false
 */
285
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
286 287 288 289 290 291 292 293

/*! \brief Compare two types for structural equivalence.
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `forall s, Tensor[f32, s]` is equal to
 * `forall w, Tensor[f32, w]`.
294
 *
295 296 297 298 299 300 301 302
 * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 * for more details.
 *
 * \param t1 The left hand type.
 * \param t2 The right hand type.
 *
 * \return true if equal, otherwise false
 */
303
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
304

305
/*! \brief Check that each Var is only bound once.
306 307 308
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
 *
309 310
 * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
 * although x is not shadowed.
311
 *
312
  * \param expr the expression to check.
313
 *
314
  * \return true iff all Var in expr is bound at most once.
315
 */
316
TVM_DLL bool WellFormed(const Expr& expr);
317

318 319 320 321 322 323 324 325 326
/*! \brief Get all bound variables from expression expr.
 *
 * Bound variables are all variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
327
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
328

329
/*! \brief Get free type parameters from expression expr.
330
 *
331 332
 * Free variables are variables that are not bound by a
 * let or a function parameter in the context.
333
 *
334
 * \param expr the expression.
335
 *
336
 * \return List of free vars, in the PostDFS order in the expression.
337
 */
338
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
339

340 341 342 343 344 345
/*! \brief Get all variables from expression expr.
 *
 * \param expr the expression.
 *
 * \return List of all vars, in the PostDFS order in the expression.
 */
346
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
347

348
/*! \brief Get free TypeVars from expression expr.
349
 *
350 351
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
352
 *
353
 * \param expr the expression.
354
 * \param mod the module.
355
 *
356
 * \return List of free vars, in the PostDFS order visited by expr.
357
 */
358
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
359

360 361 362 363 364 365
/*! \brief Get free TypeVars from type t.
 *
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
 *
 * \param t the type.
366
 * \param mod the module.
367 368 369
 *
 * \return List of free type vars, in the PostDFS order visited by type.
 */
370
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
371 372 373 374 375 376 377

/*! \brief Get all bound type variables from expression expr.
 *
 * Bound variables are all type variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
378
 * \param mod the module.
379 380 381
 *
 * \return List of bound type vars, in the PostDFS order in the expression.
 */
382
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
383 384 385 386 387 388 389

/*! \brief Get all bound type variables from type t.
 *
 * Bound variables are all type variables that are declared in the type.
 * They only have meaning inside that type, and can only be used in it.
 *
 * \param t the type
390
 * \param mod the module.
391 392 393
 *
 * \return List of bound type vars, in the PostDFS order visited by type.
 */
394
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
395 396 397 398

/*! \brief Get all type variables in expression expr.
 *
 * \param expr the expression.
399
 * \param mod the module.
400 401 402
 *
 * \return List of type vars, in the PostDFS order in the expression.
 */
403
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
404 405 406 407

/*! \brief Get all type variables in type t.
 *
 * \param t the type.
408
 * \param mod the module.
409 410 411
 *
 * \return List of type vars, in the PostDFS order visited by type.
 */
412
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
413

414 415
/*! \brief Remove expressions which does not effect the program result.
 *
416 417
 * It will remove let bindings which are not referenced, and branches that will
 * not be entered.
418
 *
419 420 421
 * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
 * the expression does not depend on a. Another example is `if (true) then 1
 * else 2` will be optimized into 1.
422 423 424 425 426
 *
 * \param e the expression to optimize.
 *
 * \return the optimized expression.
 */
427
TVM_DLL Expr DeadCodeElimination(const Expr& e);
428

429 430 431 432 433
/*!
 * \brief Fold constant expressions.
 * \param expr the expression to be optimized.
 * \return The optimized expression.
 */
434
TVM_DLL Expr FoldConstant(const Expr& expr);
435 436 437 438 439 440 441

/*!
 * \brief Fuse operations into expr into seperate functions.
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
 * \return The optimized expression.
 */
442
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level);
443

444 445 446 447 448 449
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
450 451
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
452 453
 * \return The rewritten expression.
 */
454
TVM_DLL Expr ForwardRewrite(const Expr& expr,
455
                    const std::string& rewrite_map_attr_name,
456 457
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
458

459 460 461 462 463 464 465 466 467
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
 * \return The rewritten expression.
 */
468
TVM_DLL Expr ForwardRewrite(const Expr& expr,
469 470 471 472
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

473 474 475 476 477 478 479
/*!
 * \brief Rewrite the annotated program.
 * \param expr The expression.
 * \param fallback_device The fallback device which is the default device for
 *                        operators without annotation.
 * \return The updated program.
 */
480
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
481 482 483 484 485 486

/*!
 * \brief Collect the device mapping information of each expression.
 * \param expr The expression.
 * \return The device mapping.
 */
487
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
488

489 490 491 492 493 494 495 496 497 498 499
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
  /*! \brief Hash a Relay type.
   *
   * Implements structural hashing of a Relay type.
   *
   *  \param type the type to hash.
   *
   *  \return the hash value.
   */
  size_t operator()(const Type& type) const;
500

501 502 503 504 505 506 507 508 509 510
  /*! \brief Hash a Relay expression.
   *
   * Implements structural hashing of a Relay expression.
   *
   * \param expr the expression to hash.
   *
   * \return the hash value.
   */
  size_t operator()(const Expr& expr) const;
};
511

512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
 *
 * It will turn an expression that is in a graph form (with sharing implicit),
 * to an expression with explicit sharing (A-Normal Form).
 *
 * The scope of the root expression is the global scope.

 * The scope of any non root expression is the least common ancestor of all it's scope.
 *
 * Values are ordered by post-DFS order in each scope.
 *
 * \param e the expression to observably share
 *
 * \param mod The module used for referencing global functions, can be
 * None.
 *
 * \return expression in A-Normal Form
 */
雾雨魔理沙 committed
530 531 532 533 534 535 536 537 538 539 540 541
Expr ToANormalForm(const Expr& e, const Module& mod);

/*! \brief Remove let binding and directly share via pointer instead.
 *
 * It will remove all let binding,
 * and turn all of the variable bound by let into direct pointer reference.
 *
 * \param e the expression.
 *
 * \return the expression in graph normal form.
 */
Expr ToGraphNormalForm(const Expr& e);
542

543 544
}  // namespace relay
}  // namespace tvm
545

546
#endif  // TVM_RELAY_PASS_H_