/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/transform.h
 *
 * This file implements a pass manager. The pass manager manages a sequence
 * of Relay-to-Relay transformation passes over a particular unit of AST. The
 * design is largely inspired by LLVM's pass manager and modern deep learning
 * frameworks that perform tensor->tensor transformations.
 *
 * The responsibilities of a traditional compiler pass manager usually involve:
 *  - Organizing the execution order of optimization passes, though not
 * necessarily in the optimal sequence.
 *  - Collecting required analysis information and keeping it up-to-date.
 *  - Reducing the effort required to implement new passes for compiler
 * developers, etc.
 *
 * Similar to LLVM's pass manager, we designed the Relay pass manager to work
 * at different granularities, i.e. module level, function level, and even
 * sequential passes that contain a host of passes.
 *
 * However, we also extend the functionality of the traditional pass manager
 * with the consideration of requirements/conventions from deep learning
 * frameworks, such as PyTorch and Gluon, etc. Each pass in the Relay pass
 * manager performs the Relay.Module -> Relay.Module transformation. All
 * different types of passes, including the sequential-level pass object, are
 * essentially pass objects. This design, therefore, effectively provides users
 * a consistent and convenient interface, i.e. Pass, to play with. It offers a
 * means to ease the development and testing of Relay passes. For example, with
 * the pass manager, external users will be able to have custom passes correctly
 * scheduled without having to modify a single handcrafted pass order.
 *
 * In the future we need to describe constraints between passes. For example,
 * we may want to preserve dependencies between different passes and validate
 * them on the completion of a certain pass.
 *
 * We also need to store side information and import the error reporting system.
 */
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_

#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace relay {
namespace transform {

/*!
 * \brief The context of pass (forward declaration; defined below).
 */
class PassContext;

/*!
 * \brief PassContextNode contains the information that a pass can rely on,
 * such as analysis results.
 */
class PassContextNode : public RelayNode {
 public:
  /*!
   * \brief The error reporter used to notify users why an optimization fails.
   */
  ErrorReporter err_reporter;

  /*! \brief The default optimization level. */
  int opt_level{2};

  /*! \brief CPU is the default fallback device for heterogeneous execution. */
  int fallback_device{static_cast<int>(kDLCPU)};

  /*! \brief The list of required passes. */
  tvm::Array<tvm::Expr> required_pass;
  /*! \brief The list of disabled passes. */
  tvm::Array<tvm::Expr> disabled_pass;

  PassContextNode() = default;

  /*!
   * \brief Expose the configurable fields to the attribute visitor.
   * \param v The visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("opt_level", &opt_level);
    v->Visit("fallback_device", &fallback_device);
    v->Visit("required_pass", &required_pass);
    v->Visit("disabled_pass", &disabled_pass);
    // err_reporter is not visited, so it is not exposed through VisitAttrs.
  }

  static constexpr const char* _type_key = "relay.PassContext";
  TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

115 116 117 118 119 120 121 122 123 124 125 126 127
/*!
 * \brief PassContext that is used to configure the pass behavior.
 *
 * \code
 *
 *  auto new_ctx = PassContext::Create();
 *  ctx->opt_level = 2;
 *  ctx->fallback_device = kDLCPU;
 *  With<PassContext> scope(ctx);
 *  // pass context in effect.
 *
 * \endcode
 */
128 129 130
class PassContext : public NodeRef {
 public:
  PassContext() {}
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
  explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
  /*!
   * \brief const accessor.
   * \return const access pointer.
   */
  const PassContextNode* operator->() const {
    CHECK(node_.get() != nullptr);
    return static_cast<const PassContextNode*>(node_.get());
  }
  /*!
   * \brief mutable accessor.
   * \return mutable access pointer.
   */
  PassContextNode* operator->() {
    CHECK(node_.get() != nullptr);
    return static_cast<PassContextNode*>(node_.get());
  }
  /*!
   * \brief Construct a PassContext containing the default configurations.
   * \return The new PassContext.
   */
  TVM_DLL static PassContext Create();
  /*!
   * \brief Get the default pass context in the current scope.
   * \return The pass context.
156 157 158
   */
  TVM_DLL static PassContext Current();

159
  // accessor.
160 161 162 163 164 165 166 167 168 169 170 171 172
  using ContainerType = PassContextNode;
  class Internal;

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class Internal;
  friend class tvm::With<PassContext>;
};
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222

/*!
 * \brief The meta data of a pass.
 *
 * PassInfo can be extended conveniently in the future if more meta information
 * is needed.
 */
class PassInfo;

/*!
 * \brief PassInfoNode contains meta data that will be used to help optimization
 * and analysis.
 */
class PassInfoNode : public RelayNode {
 public:
  /*!
   * \brief The minimal optimization level at which this pass is enabled.
   *
   * Brace-initialized so that a default-constructed node never exposes an
   * indeterminate value through VisitAttrs.
   */
  int opt_level{0};

  /*! \brief The name of an optimization/analysis pass. */
  std::string name;

  /*! \brief The passes that are required to perform the current pass. */
  tvm::Array<tvm::Expr> required;

  PassInfoNode() = default;

  /*!
   * \brief Expose the meta data fields to the attribute visitor.
   * \param v The visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("opt_level", &opt_level);
    v->Visit("name", &name);
    v->Visit("required", &required);
  }

  /*!
   * \brief Create a PassInfo.
   * \param opt_level The minimal optimization level at which the pass is enabled.
   * \param name The name of the pass.
   * \param required The passes that are required to perform the pass.
   * \return The created PassInfo.
   */
  TVM_DLL static PassInfo make(int opt_level, std::string name,
                               tvm::Array<tvm::Expr> required);

  static constexpr const char* _type_key = "relay.PassInfo";
  TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)

class Pass;

/*!
 * \brief PassNode is the base type of differnt types of optimization passes.
 * It is designed as a pure class and implemented by different pass subclasses
 * at different granularity of Relay nodes.
 */
class PassNode : public RelayNode {
 public:
223
  /*!
224 225 226 227
   * \brief Get the pass information/meta data. */
  virtual PassInfo Info() const = 0;

  /*!
228
   * \brief Transform mod using the default PassContext in the current scope.
229 230
   *
   * \param mod The module that an optimization pass runs on.
231
   *
232
   * \return The transformed module.
233
   */
234 235 236
  Module operator()(const Module& mod) const {
    return this->operator()(mod, PassContext::Current());
  }
237 238

  /*!
239
   * \brief Transform mod using a functor under a given pass context.
240 241
   *
   * \param mod The module that an optimization pass runs on.
242
   * \param pass_ctx The pass context that can provide information for the optimization.
243
   *
244
   * \return The transformed module.
245
   */
246 247
  virtual Module operator()(const Module& mod,
                            const PassContext& pass_ctx) const = 0;
248 249 250 251 252 253 254 255 256

  void VisitAttrs(tvm::AttrVisitor* v) override {}

  static constexpr const char* _type_key = "relay.Pass";
  TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};

class Pass : public NodeRef {
 public:
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
  /*!
   * \brief Transform mod using the default PassContext in the current scope.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return The transformed module.
   */
  Module operator()(const Module& mod) const {
    const PassNode* node = operator->();
    CHECK(node != nullptr);
    return node->operator()(mod);
  }
  /*!
   * \brief Transform mod using a functor under a given pass context.
   *
   * \param mod The module that an optimization pass runs on.
   * \param pass_ctx The pass context that can provide information for the optimization.
   *
   * \return The transformed module.
   */
  Module operator()(const Module& mod,
                    const PassContext& pass_ctx) const {
    const PassNode* node = operator->();
    CHECK(node != nullptr);
    return node->operator()(mod, pass_ctx);
282 283
  }

284
  TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
285 286 287 288 289 290 291 292
};

class SequentialNode;

/*!
 * \brief A pass that contains a list of passes to apply.
 */
class Sequential : public Pass {
 public:
  /*!
   * \brief The constructor of `Sequential`.
   *
   * \param passes The passes to apply.
   * \param pass_info The pass metadata.
   */
  TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

  /*!
   * \brief The constructor of `Sequential`.
   *
   * \param passes The passes to apply.
   * \param name The name of a sequential pass. It's defaulted to "sequential".
   *        This allows users to only provide a list of passes and execute them
   *        under a given context.
   */
  TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");

  Sequential() = default;
  explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}

  const SequentialNode* operator->() const;
  // ContainerType must name the node class (as PassContext does with
  // PassContextNode), not the reference class itself.
  using ContainerType = SequentialNode;
};

/*
 * \brief Create a module pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the module pass.
 * \param name The name of the module pass.
 * \param required The list of the passes that the module pass is dependent on.
 *
 * \return The created module pass.
 */
Pass CreateModulePass(
    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);

/*
 * \brief Create a function pass.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the function pass.
 * \param name The name of the function pass.
 * \param required The list of the passes that the function pass is dependent on.
 *
 * \return The created function pass.
 */
342
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
343
                                Function(Function, Module, PassContext)>& pass_func,
344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
                                int opt_level,
                                const std::string& name,
                                const tvm::Array<tvm::Expr>& required);

/*!
 * \brief Remove expressions that do not affect the program result.
 *
 * It will remove let bindings which are not referenced,
 * and inline let bindings that are only used once.
 *
 * For example, this pass should turn `let a = 1 in 2` into `2`,
 * as the value of the expression does not depend on a.
 *
 * As another example, `let a = 1 in a` will be optimized into 1.
 *
 * \return The pass.
 */
TVM_DLL Pass DeadCodeElimination();

/*!
 * \brief Fold constant expressions.
 *
 * \return The pass.
 */
TVM_DLL Pass FoldConstant();

/*!
 * \brief Fuse operations in an expr into separate functions.
 *
 * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
 *
 * \return The pass.
 */
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 *
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr is consumed by multiple callers.
 *
 * \return The pass.
 */
TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
                            std::function<NodeRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 *
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr is consumed by multiple callers.
 *
 * \return The pass.
 */
TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
                            std::function<NodeRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
 * \brief Rewrite the annotated program.
 *
 * \param fallback_device The fallback device which is the default device for
 *                        operators without annotation.
 *
 * \return The pass.
 */
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);

/*!
 * \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
 *
 * It will turn an expression that is in a graph form (with sharing implicit),
 * to an expression with explicit sharing (A-Normal Form).
 *
 * The scope of the root expression is the global scope.
 *
 * The scope of any non-root expression is the least common ancestor of all its scopes.
 *
 * Values are ordered by post-DFS order in each scope.
 *
 * \return The pass.
 */
TVM_DLL Pass ToANormalForm();

/*!
 * \brief Remove let bindings and directly share via pointer instead.
 *
 * It will remove all let bindings,
 * and turn all of the variables bound by let into direct pointer references.
 *
 * \return The pass.
 */
TVM_DLL Pass ToGraphNormalForm();

/*!
 * \brief Aggressive constant propagation/constant folding/inlining.
 *
 * It will do as much computation at compile time as possible.
 * It has two benefits: it removes runtime overhead, and it allows more
 * optimization (typically fusion).
 * As a side effect, code size will explode.
 *
 * \return The pass.
 */
TVM_DLL Pass PartialEval();

/*!
 * \brief Simplify certain operators during inference. For example, batch norm
 * will be unpacked into a number of simplified operators.
 *
 * \return The pass.
 */
TVM_DLL Pass SimplifyInference();

/*!
 * \brief Infer the type of an expression.
 *
 * The result of type checking is a new expression with unambiguous
 * type information filled in, as well as its checked type field
 * populated with the result type.
 *
 * \return The pass.
 */
TVM_DLL Pass InferType();

/*!
 * \brief Search and eliminate common subexpression. For example, if there are
 * two expressions evaluated to an identical value, a single variable is created
 * and these two expressions are replaced by this variable.
 *
 * \param fskip The callback argument that allows certain expressions to be skipped.
 *
 * \return The pass.
 */
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);

/*!
 * \brief Combine parallel 2d convolutions into a single convolution if the
 * number of branches of this conv2d operator is not less than
 * `min_num_branches`.
 *
 * \param min_num_branches The minimum number of branches.
 *
 * \return The pass.
 */
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
 * \brief Backward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass BackwardFoldScaleAxis();

/*!
 * \brief Forward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass ForwardFoldScaleAxis();

/*!
 * \brief A sequential pass that executes ForwardFoldScaleAxis and
 * BackwardFoldScaleAxis passes.
 *
 * \return The pass.
 */
TVM_DLL Pass FoldScaleAxis();

/*!
 * \brief Canonicalize some operators to the simplified operators. For example,
 * bias_add can be canonicalized to expand_dims and broadcast_add.
 *
 * \return The pass.
 */
TVM_DLL Pass CanonicalizeOps();

/*!
 * \brief Alter the layouts of operators or replace primitive operators
 * with other expressions.
 *
 * \return The pass.
 */
TVM_DLL Pass AlterOpLayout();

}  // namespace transform
}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_TRANSFORM_H_