/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/pass.h
 * \brief The set of Relay passes written in C++.
 */
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <functional>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

42 43
/*!
 * \brief Infer the type of an expression.
44 45 46 47 48
 *
 * The result of type checking is a new expression with unambigous
 * type information filled in, as well as it's checked type field
 * populated with the result type.
 *
49
 * \param expr The expression to type check.
50
 * \param mod The module used for referencing global functions, can be
51
 * None.
52 53 54
 *
 * \return A type checked expression with its checked_type field populated.
 */
55
TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
56

57
/*!
58
 * \brief Infer the type of a function as if it is mapped to var in the mod.
59 60
 *
 * \param f the function.
61
 * \param mod The module used for referencing global functions.
62 63 64
 * \param var The global variable corresponding to the function.
 *
 * \return A type checked Function with its checked_type field populated.
65
 * \note this function mutates mod and is not thread-safe.
66
 */
67 68
TVM_DLL Function InferType(const Function& f, const Module& mod,
                           const GlobalVar& var);
69 70

/*!
71
 * \brief Check that types are well kinded by applying "kinding rules".
72 73 74 75 76 77 78 79 80 81
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always contains
 * a data type such as `int`, `float`, `uint`.
 *
 * \param t The type to check.
82
 * \param mod The global module.
83
 *
84
 * \return The kind of the passed type.
85
 */
86
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
87

88 89
/*!
 * \brief Compare two expressions for structural equivalence.
90 91 92 93 94 95 96 97 98 99 100 101 102 103
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
 *
 *   See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 *   for more details.
 *
 *   \param e1 The left hand expression.
 *   \param e2 The right hand expression.
 *
 *   \return true if equal, otherwise false
 */
104
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
105

106 107
/*!
 * \brief Compare two types for structural equivalence.
108 109 110 111 112 113
 *
 * This comparison operator respects scoping and compares
 * expressions without regard to variable choice.
 *
 * For example: `forall s, Tensor[f32, s]` is equal to
 * `forall w, Tensor[f32, w]`.
114
 *
115 116 117 118 119 120 121 122
 * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 * for more details.
 *
 * \param t1 The left hand type.
 * \param t2 The right hand type.
 *
 * \return true if equal, otherwise false
 */
123
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
124

125 126
/*!
 * \brief Add abstraction over a function
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
 *
 * For example: `square` is transformed to
 * `fun x -> square x`.
 *
 * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
 * for more details.
 *
 * \param e The original function.
 * \param mod The module used for referencing global functions, can be
 * None.
 *
 * \return the new function with abstraction
 */
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);

142 143
/*!
 * \brief Check that each Var is only bound once.
144 145 146
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
 *
147 148
 * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
 * although x is not shadowed.
149
 *
150
  * \param expr the expression to check.
151
 *
152
  * \return true iff all Var in expr is bound at most once.
153
 */
154
TVM_DLL bool WellFormed(const Expr& expr);
155

156 157
/*!
 * \brief Get all bound variables from expression expr.
158 159 160 161 162 163 164 165
 *
 * Bound variables are all variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
166
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
167

168 169
/*!
 * \brief Get all bound variables from pattern pat.
雾雨魔理沙 committed
170 171 172 173 174 175 176 177 178 179
 *
 * Bound variables are all variables that got bound by the pat.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param pat the Pattern.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

180 181
/*!
 * \brief Get free type parameters from expression expr.
182
 *
183 184
 * Free variables are variables that are not bound by a
 * let or a function parameter in the context.
185
 *
186
 * \param expr the expression.
187
 *
188
 * \return List of free vars, in the PostDFS order in the expression.
189
 */
190
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
191

192 193
/*!
 * \brief Get all variables from expression expr.
194 195 196 197 198
 *
 * \param expr the expression.
 *
 * \return List of all vars, in the PostDFS order in the expression.
 */
199
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
200

201 202
/*!
 * \brief Get free TypeVars from expression expr.
203
 *
204 205
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
206
 *
207
 * \param expr the expression.
208
 * \param mod the module.
209
 *
210
 * \return List of free vars, in the PostDFS order visited by expr.
211
 */
212
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
213

214 215
/*!
 * \brief Get free TypeVars from type t.
216 217 218 219 220
 *
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
 *
 * \param t the type.
221
 * \param mod the module.
222 223 224
 *
 * \return List of free type vars, in the PostDFS order visited by type.
 */
225
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
226

227 228
/*!
 * \brief Get all bound type variables from expression expr.
229 230 231 232 233
 *
 * Bound variables are all type variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
234
 * \param mod the module.
235 236 237
 *
 * \return List of bound type vars, in the PostDFS order in the expression.
 */
238
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
239

240 241
/*!
 * \brief Get all bound type variables from type t.
242 243 244 245 246
 *
 * Bound variables are all type variables that are declared in the type.
 * They only have meaning inside that type, and can only be used in it.
 *
 * \param t the type
247
 * \param mod the module.
248 249 250
 *
 * \return List of bound type vars, in the PostDFS order visited by type.
 */
251
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
252

253 254
/*!
 * \brief Get all type variables in expression expr.
255 256
 *
 * \param expr the expression.
257
 * \param mod the module.
258 259 260
 *
 * \return List of type vars, in the PostDFS order in the expression.
 */
261
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
262

263 264
/*!
 * \brief Get all type variables in type t.
265 266
 *
 * \param t the type.
267
 * \param mod the module.
268 269 270
 *
 * \return List of type vars, in the PostDFS order visited by type.
 */
271
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
272

273 274
/*! \brief Remove expressions which does not effect the program result.
 *
雾雨魔理沙 committed
275 276
 * It will remove let bindings which are not referenced,
 * and inline let bindings that are only used once.
277
 *
雾雨魔理沙 committed
278 279 280 281
 * For example, this pass should turn `let a = 1 in 2` into `2`,
 * as the value of the expression does not depend on a.
 *
 * As another example, `let a = 1 in a` will be optimized into 1.
282 283 284 285 286
 *
 * \param e the expression to optimize.
 *
 * \return the optimized expression.
 */
287
TVM_DLL Expr DeadCodeElimination(const Expr& e);
288

289 290
/*!
 * \brief Fold constant expressions.
291
 *
292
 * \param expr the expression to be optimized.
293
 *
294 295
 * \return The optimized expression.
 */
296
TVM_DLL Expr FoldConstant(const Expr& expr);
297 298 299

/*!
 * \brief Fuse operations into expr into seperate functions.
300
 *
301 302
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
303
 * \param mod the module.
304
 *
305 306
 * \return The optimized expression.
 */
307
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
308

309 310
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
311
 *
312 313 314 315
 * \param expr The expression.
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
316 317
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
318 319
 * \return The rewritten expression.
 */
320
TVM_DLL Expr ForwardRewrite(const Expr& expr,
321 322 323
                            const std::string& rewrite_map_attr_name,
                            std::function<NodeRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
324

325 326
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
327
 *
328 329 330 331 332
 * \param expr The expression.
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
333
 *
334 335
 * \return The rewritten expression.
 */
336
TVM_DLL Expr ForwardRewrite(const Expr& expr,
337 338 339
                            const FForwardRewrite& rewrite_func,
                            std::function<NodeRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
340

341 342
/*!
 * \brief Rewrite the annotated program.
343
 *
344 345 346
 * \param expr The expression.
 * \param fallback_device The fallback device which is the default device for
 *                        operators without annotation.
347
 *
348 349
 * \return The updated program.
 */
350
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
351 352 353

/*!
 * \brief Collect the device mapping information of each expression.
354
 *
355
 * \param expr The expression.
356
 *
357 358
 * \return The device mapping.
 */
359
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
360

361
/*!
362 363 364 365 366 367 368 369 370
 * \brief Collect the device anntation operators.
 *
 * \param expr The expression.
 *
 * \return The annotated expression to device type mapping for annotation ops.
 */
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
371
 * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
372 373 374 375 376
 *
 * It will turn an expression that is in a graph form (with sharing implicit),
 * to an expression with explicit sharing (A-Normal Form).
 *
 * The scope of the root expression is the global scope.
377
 *
378 379 380 381
 * The scope of any non root expression is the least common ancestor of all it's scope.
 *
 * Values are ordered by post-DFS order in each scope.
 *
382
 * \param e the expression to observably share.
383 384 385
 * \param mod The module used for referencing global functions, can be
 * None.
 *
386
 * \return expression in A-Normal Form.
387
 */
雾雨魔理沙 committed
388
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
雾雨魔理沙 committed
389

390 391
/*!
 * \brief Remove let binding and directly share via pointer instead.
雾雨魔理沙 committed
392 393 394 395 396 397 398 399
 *
 * It will remove all let binding,
 * and turn all of the variable bound by let into direct pointer reference.
 *
 * \param e the expression.
 *
 * \return the expression in graph normal form.
 */
雾雨魔理沙 committed
400
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
401

402 403 404
/*!
 * \brief Aggressive constant propagation/constant folding/inlining.
 *
雾雨魔理沙 committed
405 406 407
 * It will do as much computation in compile time as possible.
 * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
 * As a side effect, code size will explode.
408 409 410 411
 *
 * \param e the expression,
 *
 * \return the optimized expression.
雾雨魔理沙 committed
412
 */
413 414
TVM_DLL Expr PartialEval(const Expr& e);

415 416 417 418 419 420 421 422 423 424 425
/*!
 * \brief Bind the free variables to a Relay expression.
 *
 * \param expr The expression.
 * \param bind_map The variable to expression map that will be used to help the
 *        binding.
 *
 * \return The updated expression.
 */
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);

426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
  /*! \brief Hash a Relay type.
   *
   * Implements structural hashing of a Relay type.
   *
   * \param type the type to hash.
   *
   * \return the hash value.
   */
  size_t operator()(const Type& type) const;

  /*! \brief Hash a Relay expression.
   *
   * Implements structural hashing of a Relay expression.
   *
   * \param expr the expression to hash.
   *
   * \return the hash value.
   */
  size_t operator()(const Expr& expr) const;
};
448 449 450

namespace vm {

451 452
/*!
 * \brief Compile a module, and construct the virtual machine.
453 454
 *
 * \param mod The module to compile.
455
 *
456 457 458 459 460 461
 * \return The constructed virtual machine.
 */
runtime::vm::VirtualMachine CompileModule(const Module& mod);

}  // namespace vm

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_PASS_H_