pass.h 9.21 KB
Newer Older
1 2 3 4 5 6 7 8
/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/pass.h
 * \brief The set of Relay passes written in C++.
 */
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

9
#include <tvm/relay/module.h>
10
#include <tvm/relay/expr.h>
11
#include <tvm/relay/op_attr_types.h>
12
#include <string>
13 14 15 16

namespace tvm {
namespace relay {

17 18
/*!
 * \brief Infer the type of an expression.
19 20 21 22 23
 *
 * The result of type checking is a new expression with unambigous
 * type information filled in, as well as it's checked type field
 * populated with the result type.
 *
24
 * \param expr The expression to type check.
25
 * \param mod The module used for referencing global functions, can be
26
 * None.
27 28 29
 *
 * \return A type checked expression with its checked_type field populated.
 */
30
Expr InferType(const Expr& expr, const Module& mod);
31

32
/*!
33
 * \brief Infer the type of a function as if it is mapped to var in the mod.
34 35
 *
 * \param f the function.
36
 * \param mod The module used for referencing global functions.
37 38 39
 * \param var The global variable corresponding to the function.
 *
 * \return A type checked Function with its checked_type field populated.
40
 * \note this function mutates mod and is not thread-safe.
41
 */
42
Function InferType(const Function& f, const Module& mod,
43
                   const GlobalVar& var);
44 45

/*!
46
 * \brief Check that types are well kinded by applying "kinding rules".
47 48 49 50 51 52 53 54 55 56
 *
 * This pass ensures we do not do things that violate the design of the
 * type system when writing down types.
 *
 * For example tensors are not allowed to contain functions in Relay.
 *
 * We check this by ensuring the `dtype` field of a Tensor always contains
 * a data type such as `int`, `float`, `uint`.
 *
 * \param t The type to check.
57
 * \param mod The global module.
58
 *
59 60
 * \return true if the rules are satisified otherwise false
 */
61
bool KindCheck(const Type& t, const Module& mod);
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86

/*!
 * \brief Compare two expressions for structural equivalence.
 *
 * Scoping is respected, and the comparison ignores which names were
 * chosen for variables (alpha-equivalence). For example,
 * `let x = 1 in x` is equal to `let y = 1 in y`.
 *
 * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 * for more details.
 *
 * \param e1 The left hand expression.
 * \param e2 The right hand expression.
 * \return true if equal, otherwise false.
 */
bool AlphaEqual(const Expr& e1, const Expr& e2);

/*! \brief Compare two types for structural equivalence.
 *
 * This comparison operator respects scoping and compares
 * types without regard to type variable choice.
 *
 * For example: `forall s, Tensor[f32, s]` is equal to
 * `forall w, Tensor[f32, w]`.
 *
 * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
 * for more details.
 *
 * \param t1 The left hand type.
 * \param t2 The right hand type.
 *
 * \return true if equal, otherwise false.
 */
bool AlphaEqual(const Type& t1, const Type& t2);

98
/*! \brief Check that each Var is only bound once.
99 100 101
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
 *
102 103
 * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
 * although x is not shadowed.
104
 *
105
  * \param expr the expression to check.
106
 *
107
  * \return true iff all Var in expr is bound at most once.
108
 */
109
bool WellFormed(const Expr& expr);
110

111 112 113 114 115 116 117 118 119 120 121
/*! \brief Get all bound variables from expression expr.
 *
 * Bound variables are all variables that are declared in the expr.
 * They only have meaning inside that expr, and can only be used in it.
 *
 * \param expr the expression.
 *
 * \return List of bound vars, in the PostDFS order in the expression.
 */
tvm::Array<Var> BoundVars(const Expr& expr);

122
/*! \brief Get free type parameters from expression expr.
123
 *
124 125
 * Free variables are variables that are not bound by a
 * let or a function parameter in the context.
126
 *
127
 * \param expr the expression.
128
 *
129
 * \return List of free vars, in the PostDFS order in the expression.
130
 */
131
tvm::Array<Var> FreeVars(const Expr& expr);
132

133 134 135 136 137 138 139 140
/*! \brief Get all variables from expression expr.
 *
 * \param expr the expression.
 *
 * \return List of all vars, in the PostDFS order in the expression.
 */
tvm::Array<Var> AllVars(const Expr& expr);

141
/*! \brief Get free TypeVars from expression expr.
142
 *
143 144
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
145
 *
146
 * \param expr the expression.
147
 *
148
 * \return List of free vars, in the PostDFS order visited by expr.
149
 */
150
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
151

152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
/*! \brief Get free TypeVars from type t.
 *
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
 *
 * \param t the type.
 *
 * \return List of free type vars, in the PostDFS order visited by type.
 */
tvm::Array<TypeVar> FreeTypeVars(const Type& t);

/*!
 * \brief Collect the bound type variables of an expression.
 *
 * A type variable counts as bound when it is declared inside \p expr;
 * it is meaningful, and may be used, only within \p expr.
 *
 * \param expr The expression to inspect.
 * \return The bound type vars, in the PostDFS order in the expression.
 */
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);

/*!
 * \brief Collect the bound type variables of a type.
 *
 * A type variable counts as bound when it is declared inside \p t;
 * it is meaningful, and may be used, only within \p t.
 *
 * \param t The type to inspect.
 * \return The bound type vars, in the PostDFS order visited by type.
 */
tvm::Array<TypeVar> BoundTypeVars(const Type& t);

/*!
 * \brief Collect all type variables in an expression.
 *
 * \param expr The expression to inspect.
 * \return The type vars of \p expr, in the PostDFS order in the expression.
 */
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);

/*!
 * \brief Collect all type variables in a type.
 *
 * \param t The type to inspect.
 * \return The type vars of \p t, in the PostDFS order visited by type.
 */
tvm::Array<TypeVar> AllTypeVars(const Type& t);

201 202
/*! \brief Remove expressions which does not effect the program result.
 *
203 204
 * It will remove let bindings which are not referenced, and branches that will
 * not be entered.
205
 *
206 207 208
 * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
 * the expression does not depend on a. Another example is `if (true) then 1
 * else 2` will be optimized into 1.
209 210 211 212 213 214
 *
 * \param e the expression to optimize.
 *
 * \return the optimized expression.
 */
Expr DeadCodeElimination(const Expr& e);
215

216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
/*!
 * \brief Fold constant expressions.
 * \param expr the expression to be optimized.
 * \return The optimized expression.
 */
Expr FoldConstant(const Expr& expr);

/*!
 * \brief Fuse operations in expr into separate functions.
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
 * \return The optimized expression.
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level);

231 232 233 234 235 236
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
237 238
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
239 240 241 242
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_attr_name,
243 244
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
245

246 247 248 249 250 251 252 253 254 255 256 257 258 259
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
/*!
 * \brief Rewrite the annotated program.
 * \param expr The expression.
 * \param fallback_device The fallback device which is the default device for
 *                        operators without annotation.
 * \return The updated program.
 */
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);

/*!
 * \brief Collect the device mapping information of each expression.
 * \param expr The expression.
 * \return The device mapping.
 */
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

276 277 278 279 280 281 282 283 284 285 286
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
  /*! \brief Hash a Relay type.
   *
   * Implements structural hashing of a Relay type.
   *
   *  \param type the type to hash.
   *
   *  \return the hash value.
   */
  size_t operator()(const Type& type) const;
287

288 289 290 291 292 293 294 295 296 297
  /*! \brief Hash a Relay expression.
   *
   * Implements structural hashing of a Relay expression.
   *
   * \param expr the expression to hash.
   *
   * \return the hash value.
   */
  size_t operator()(const Expr& expr) const;
};
298

299 300
}  // namespace relay
}  // namespace tvm
301

302
#endif  // TVM_RELAY_PASS_H_