/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/pass.h
 * \brief The set of Relay passes written in C++.
 *
 * This file also declares the interface of a pass manager. The pass manager
 * manages a sequence of Relay-to-Relay transformation passes over a particular
 * unit of AST. The design is largely inspired by LLVM's pass manager and modern
 * deep learning frameworks that perform tensor->tensor transformations.
 *
 * The responsibilities of a traditional compiler pass manager usually involve:
 *  - Organizing the execution order of optimization passes, though not
 * necessarily in the optimal sequence.
 *  - Collecting required analysis information and keeping it up-to-date.
 *  - Reducing the effort required to implement new passes for compiler
 * developers, etc.
 *
 * Similar to LLVM's pass manager, we designed the Relay pass manager to work
 * at different granularities, i.e. module level, function level, and even
 * sequential passes that contain a host of passes.
 *
 * However, we also extend the functionality of the traditional pass manager
 * with the consideration of requirements/conventions from deep learning
 * frameworks, such as PyTorch and Gluon, etc. Each pass in the Relay pass
 * manager performs the Relay.Module -> Relay.Module transformation. All
 * different types of passes, including the sequential-level pass object, are
 * essentially pass objects. This design, therefore, effectively provides users
 * a consistent and convenient interface, i.e. Pass, to play with. It offers a
 * means to ease the development and testing of Relay passes. For example, with
 * the pass manager, external users will be able to have custom passes correctly
 * scheduled without having to modify a single handcrafted pass order.
 *
 * In the future we need to describe constraints between passes. For example,
 * we may want to preserve dependencies between different passes and validate
 * them on the completion of a certain pass.
 *
 * We also need to store side information and import the error reporting system.
 */
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <functional>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

Zhi committed
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
namespace pass {

/*
 * \brief The context of pass.
 */
class PassContext;

/*!
 * \brief PassContextNode contains the information that a pass can rely on, such as
 * analysis results.
 */
class PassContextNode : public RelayNode {
 public:
  /*!
   * \brief The error reporter used to notify users why an optimization fails.
   */
  ErrorReporter err_reporter;

  PassContextNode() = default;

  void VisitAttrs(tvm::AttrVisitor* v) final {
  }

  TVM_DLL static PassContext make();

  static constexpr const char* _type_key = "relay.PassContext";
  TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassContext, PassContextNode)

/*
 * \brief The meta data of a pass.
 *
 * PassInfo can be extended conveniently in the future if more meta information
 * is needed.
 */
class PassInfo;

/*!
 * \brief PassInfoNode contains meta data that will be used to help optimization
 * and analysis.
 */
class PassInfoNode : public RelayNode {
 public:
  /*! \brief The minimal optimization level that this pass will be enabled. */
  int opt_level;

  /*! \brief The name of an optimization/analysis pass. */
  std::string name;

  /*! \brief The passes that are required to perform the current pass. */
  tvm::Array<tvm::Expr> required;

  PassInfoNode() = default;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("opt_level", &opt_level);
    v->Visit("name", &name);
    v->Visit("required", &required);
  }

  TVM_DLL static PassInfo make(int opt_level, std::string name,
                               tvm::Array<tvm::Expr> required);

  static constexpr const char* _type_key = "relay.PassInfo";
  TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)

class Pass;

/*!
 * \brief PassNode is the base type of differnt types of optimization passes.
 * It is designed as a pure class and implemented by different pass subclasses
 * at different granularity of Relay nodes.
 */
class PassNode : public RelayNode {
 public:
  /*
   * \brief Get the pass information/meta data. */
  virtual PassInfo Info() const = 0;

  /*!
   * \brief Set the context information for a pass.
   *
   * \param pass_ctx The context information for a certain pass.
   */
  virtual void SetContext(const PassContext& pass_ctx) = 0;

  /*!
   * \brief Execute the optimization pass using a functor.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return The updated module.
   */
  virtual Module operator()(const Module& mod) const = 0;

  void VisitAttrs(tvm::AttrVisitor* v) override {}

  static constexpr const char* _type_key = "relay.Pass";
  TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};

class Pass : public NodeRef {
 public:
  Pass() = default;
  explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}

  PassNode* operator->() const {
    return static_cast<PassNode*>(this->node_.get());
  }

  using ContainerType = PassNode;
};

/*
 * \brief Create a module pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the module pass.
 * \param name The name of the module pass.
 * \param required The list of the passes that the module pass is dependent on.
 *
 * \return The created module pass.
 */
Pass CreateModulePass(
    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);

/*
 * \brief Create a function pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the function pass.
 * \param name The name of the function pass.
 * \param required The list of the passes that the function pass is dependent on.
 *
 * \return The created function pass.
 */
Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);
/*
 * \brief Create a sequential pass.
 *
 * \param passes The optimization passes will be performed.
 * \param opt_level The optimization level of the sequential pass.
 * \param name The name of the sequential pass.
 * \param required The list of the passes that the sequential pass is dependent on.
 * \param disabled The disabled passes.
 *
 * \return The created sequential pass.
 */
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
                          int opt_level,
                          const std::string& name,
                          const tvm::Array<tvm::Expr>& required,
                          const tvm::Array<tvm::Expr>& disabled);

}  // namespace pass

242 243
/*!
 * \brief Infer the type of an expression.
244 245 246 247 248
 *
 * The result of type checking is a new expression with unambigous
 * type information filled in, as well as it's checked type field
 * populated with the result type.
 *
249
 * \param expr The expression to type check.
250
 * \param mod The module used for referencing global functions, can be
251
 * None.
252 253 254
 *
 * \return A type checked expression with its checked_type field populated.
 */
255
TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
256

257
/*!
258
 * \brief Infer the type of a function as if it is mapped to var in the mod.
259 260
 *
 * \param f the function.
261
 * \param mod The module used for referencing global functions.
262 263 264
 * \param var The global variable corresponding to the function.
 *
 * \return A type checked Function with its checked_type field populated.
265
 * \note this function mutates mod and is not thread-safe.
266
 */
267 268
TVM_DLL Function InferType(const Function& f, const Module& mod,
                           const GlobalVar& var);
269 270

/*!
271
 * \brief Check that types are well kinded by applying "kinding rules".
272 273 274 275 276 277 278 279 280 281
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always contains
 * a data type such as `int`, `float`, `uint`.
 *
 * \param t The type to check.
282
 * \param mod The global module.
283
 *
284
 * \return The kind of the passed type.
285
 */
286
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302

/*! \brief Compare two expressions for structural equivalence.
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
 *
 *   See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 *   for more details.
 *
 *   \param e1 The left hand expression.
 *   \param e2 The right hand expression.
 *
 *   \return true if equal, otherwise false
 */
303
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
304 305 306 307 308 309 310 311

/*! \brief Compare two types for structural equivalence.
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `forall s, Tensor[f32, s]` is equal to
 * `forall w, Tensor[f32, w]`.
312
 *
313 314 315 316 317 318 319 320
 * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 * for more details.
 *
 * \param t1 The left hand type.
 * \param t2 The right hand type.
 *
 * \return true if equal, otherwise false
 */
321
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
322

323
/*! \brief Check that each Var is only bound once.
324 325 326
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
 *
327 328
 * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
 * although x is not shadowed.
329
 *
330
  * \param expr the expression to check.
331
 *
332
  * \return true iff all Var in expr is bound at most once.
333
 */
334
TVM_DLL bool WellFormed(const Expr& expr);
335

336 337 338 339 340 341 342 343 344
/*! \brief Get all bound variables from expression expr.
 *
 * Bound variables are all variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
345
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
346

雾雨魔理沙 committed
347 348 349 350 351 352 353 354 355 356 357
/*! \brief Get all bound variables from pattern pat.
 *
 * Bound variables are all variables that got bound by the pat.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param pat the Pattern.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

358
/*! \brief Get free type parameters from expression expr.
359
 *
360 361
 * Free variables are variables that are not bound by a
 * let or a function parameter in the context.
362
 *
363
 * \param expr the expression.
364
 *
365
 * \return List of free vars, in the PostDFS order in the expression.
366
 */
367
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
368

369 370 371 372 373 374
/*! \brief Get all variables from expression expr.
 *
 * \param expr the expression.
 *
 * \return List of all vars, in the PostDFS order in the expression.
 */
375
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
376

377
/*! \brief Get free TypeVars from expression expr.
378
 *
379 380
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
381
 *
382
 * \param expr the expression.
383
 * \param mod the module.
384
 *
385
 * \return List of free vars, in the PostDFS order visited by expr.
386
 */
387
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
388

389 390 391 392 393 394
/*! \brief Get free TypeVars from type t.
 *
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
 *
 * \param t the type.
395
 * \param mod the module.
396 397 398
 *
 * \return List of free type vars, in the PostDFS order visited by type.
 */
399
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
400 401 402 403 404 405 406

/*! \brief Get all bound type variables from expression expr.
 *
 * Bound variables are all type variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
407
 * \param mod the module.
408 409 410
 *
 * \return List of bound type vars, in the PostDFS order in the expression.
 */
411
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
412 413 414 415 416 417 418

/*! \brief Get all bound type variables from type t.
 *
 * Bound variables are all type variables that are declared in the type.
 * They only have meaning inside that type, and can only be used in it.
 *
 * \param t the type
419
 * \param mod the module.
420 421 422
 *
 * \return List of bound type vars, in the PostDFS order visited by type.
 */
423
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
424 425 426 427

/*! \brief Get all type variables in expression expr.
 *
 * \param expr the expression.
428
 * \param mod the module.
429 430 431
 *
 * \return List of type vars, in the PostDFS order in the expression.
 */
432
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
433 434 435 436

/*! \brief Get all type variables in type t.
 *
 * \param t the type.
437
 * \param mod the module.
438 439 440
 *
 * \return List of type vars, in the PostDFS order visited by type.
 */
441
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
442

443 444
/*! \brief Remove expressions which does not effect the program result.
 *
雾雨魔理沙 committed
445 446
 * It will remove let bindings which are not referenced,
 * and inline let bindings that are only used once.
447
 *
雾雨魔理沙 committed
448 449 450 451
 * For example, this pass should turn `let a = 1 in 2` into `2`,
 * as the value of the expression does not depend on a.
 *
 * As another example, `let a = 1 in a` will be optimized into 1.
452 453 454 455 456
 *
 * \param e the expression to optimize.
 *
 * \return the optimized expression.
 */
457
TVM_DLL Expr DeadCodeElimination(const Expr& e);
458

459 460 461 462 463
/*!
 * \brief Fold constant expressions.
 * \param expr the expression to be optimized.
 * \return The optimized expression.
 */
464
TVM_DLL Expr FoldConstant(const Expr& expr);
465 466 467 468 469 470 471

/*!
 * \brief Fuse operations into expr into seperate functions.
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
 * \return The optimized expression.
 */
472
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level);
473

474 475 476 477 478 479
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
480 481
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
482 483
 * \return The rewritten expression.
 */
484
TVM_DLL Expr ForwardRewrite(const Expr& expr,
485
                    const std::string& rewrite_map_attr_name,
486 487
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
488

489 490 491 492 493 494 495 496 497
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
 * \return The rewritten expression.
 */
498
TVM_DLL Expr ForwardRewrite(const Expr& expr,
499 500 501 502
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

503 504 505 506 507 508 509
/*!
 * \brief Rewrite the annotated program.
 * \param expr The expression.
 * \param fallback_device The fallback device which is the default device for
 *                        operators without annotation.
 * \return The updated program.
 */
510
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
511 512 513 514 515 516

/*!
 * \brief Collect the device mapping information of each expression.
 * \param expr The expression.
 * \return The device mapping.
 */
517
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
518

519 520 521 522 523 524 525 526 527 528 529
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
  /*! \brief Hash a Relay type.
   *
   * Implements structural hashing of a Relay type.
   *
   *  \param type the type to hash.
   *
   *  \return the hash value.
   */
  size_t operator()(const Type& type) const;
530

531 532 533 534 535 536 537 538 539 540
  /*! \brief Hash a Relay expression.
   *
   * Implements structural hashing of a Relay expression.
   *
   * \param expr the expression to hash.
   *
   * \return the hash value.
   */
  size_t operator()(const Expr& expr) const;
};
541

542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
 *
 * It will turn an expression that is in a graph form (with sharing implicit),
 * to an expression with explicit sharing (A-Normal Form).
 *
 * The scope of the root expression is the global scope.

 * The scope of any non root expression is the least common ancestor of all it's scope.
 *
 * Values are ordered by post-DFS order in each scope.
 *
 * \param e the expression to observably share
 *
 * \param mod The module used for referencing global functions, can be
 * None.
 *
 * \return expression in A-Normal Form
 */
雾雨魔理沙 committed
560
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
雾雨魔理沙 committed
561 562 563 564 565 566 567 568 569 570

/*! \brief Remove let binding and directly share via pointer instead.
 *
 * It will remove all let binding,
 * and turn all of the variable bound by let into direct pointer reference.
 *
 * \param e the expression.
 *
 * \return the expression in graph normal form.
 */
雾雨魔理沙 committed
571
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
572

雾雨魔理沙 committed
573 574 575 576 577 578
/*! \brief Aggressive constant propagation/constant folding/inlining.
 * It will do as much computation in compile time as possible.
 * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
 * As a side effect, code size will explode.
 */
Expr PartialEval(const Expr& e);
579 580
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PASS_H_