/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_pass.h
 * \brief Collection of IR pass functions
 *
 *  Pass functions in this file that operate on Stmt can also be
 *  applied to an Expr by wrapping it: PassFunction(Evaluate(expr)).
 */
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>

#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"
#include "./lowered_func.h"

namespace tvm {
namespace ir {

25 26 27 28 29 30 31
inline Expr Simplify(Expr a) {
  return Halide::Internal::simplify(a);
}

inline Stmt Simplify(Stmt a) {
  return Halide::Internal::simplify(a);
}
32 33

/*!
34 35 36 37 38 39 40 41 42 43 44 45 46 47
 * \brief Simplify by applying canonical form.
 * \param stmt The statement to be canonically simplifed.
 * \return Canonicalized statement.
 */
Stmt CanonicalSimplify(Stmt stmt);

/*!
 * \brief Simplify by applying canonical form.
 * \param expr The statement to be canonically simplifed.
 * \return Canonicalized expression.
 */
Expr CanonicalSimplify(Expr expr);

/*!
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
 * \brief Deep compare lhs and rhs
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
bool Equal(const Expr& lhs, const Expr& rhs);

/*!
 * \brief Deep compare lhs and rhs
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
bool Equal(const Stmt& lhs, const Stmt& rhs);

/*!
64 65 66 67 68 69 70 71 72 73 74 75 76
 * \brief Deep compare lhs and rhs.
 *
 *  If you only want equality comparison, use Equal
 *  which will also tie definitions. The compare mode
 *  will give order of expression in total order.
 *
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
int Compare(const Expr& lhs, const Expr& rhs);

/*!
tqchen committed
77 78 79 80 81 82
 * \brief verifies whether the IR stmt or Expr is in SSA form.
 *  That is: each VarExpr is defined and assigned once(in Let/For)
 *
 * \param ir The root of the IR DAG.
 * \return Whether IR is in SSA form.
 * \note All the passes in this file uses SSA form and outputs SSA form.
tqchen committed
83
 */
tqchen committed
84
bool VerifySSA(const Stmt& ir);
tqchen committed
85 86

/*!
87 88 89 90 91 92
 * \brief Whether the expression have side effect.
 * \return whether expression have side effect
 */
bool HasSideEffect(const Expr& e);

/*!
93 94 95 96 97 98 99 100
 * \brief Whether e expression used var.
 * \param e The expression to be checked.
 * \param v The variable.
 * \return Whether e uses v.
 */
bool ExprUseVar(const Expr& e, const Var& v);

/*!
101 102 103 104 105 106 107 108
 * \brief Whether e expression used any var in variable set..
 * \param e The expression to be checked.
 * \param vset The variable set.
 * \return Whether e uses vset.
 */
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);

/*!
tqchen committed
109 110 111 112
 * \brief Convert a IR node to be SSA form.
 * \param stmt The source statement to be converted.
 * \return The converted form.
 */
tqchen committed
113
Stmt ConvertSSA(Stmt stmt);
tqchen committed
114 115

/*!
116 117 118 119 120
 * \brief Substitute the var specified in key->var to be value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
Stmt Substitute(Stmt stmt,
                const std::unordered_map<const Variable*, Expr>& value_map);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
Expr Substitute(Expr expr,
                const std::unordered_map<const Variable*, Expr>& value_map);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
139
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
140 141 142 143 144 145 146

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
ziheng committed
147
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
148 149

/*!
tqchen committed
150 151
 * \brief inline all calls of f in stmt.
 *
152
 * \param stmt The statement to apply inline optimization.
tqchen committed
153 154
 * \param f The function reference to be inlined
 * \param args The arguments variable of the function.
155
 * \param body The definition body of the function.
tqchen committed
156 157 158 159
 * \return The result stmt
 *
 * \note All the passes in this file uses SSA form and outputs SSA form.
 */
160 161
Stmt Inline(Stmt stmt,
            FunctionRef f,
tqchen committed
162
            Array<Var> args,
163 164 165 166 167 168 169 170 171
            Expr body);

/*!
 * \brief Flatten the multi-dimensional read/write
 *  to single dimensional Load/Store
 *
 * \param stmt The stmt to be trasnformed.
 * \param extern_buffer Map specifies external
 *    buffer assignment of input and outputs.
172
 * \param cache_line_size The size of CPU cache line.
173
 * \return Transformed stmt.
174 175
 */
Stmt StorageFlatten(Stmt stmt,
176 177
                    Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size);
tqchen committed
178

179
/*!
Tianqi Chen committed
180 181 182 183 184 185 186 187 188
 * \brief Remove No Op from the Stmt.
 * \param stmt The stmt to be trasnformed
 * \return Transformed stmt.
 */
Stmt RemoveNoOp(Stmt stmt);

/*!
 * \brief Split statement into pipeine stages.
 * \param stmt The stmt to be splitted
189 190 191 192 193 194 195 196
 * \param split_load Whether split load into its own stage.
 * \return Transformed stmt.
 */
Stmt SplitPipeline(Stmt stmt, bool split_load);

/*!
 * \brief Narrow channel access to smaller range.
 * \param stmt The stmt to do access rewriting.
Tianqi Chen committed
197 198
 * \return Transformed stmt.
 */
199
Stmt NarrowChannelAccess(Stmt stmt);
Tianqi Chen committed
200 201

/*!
202 203 204
 * \brief unroll the constant loop marked by unroll.
 * This pass also automatically attach pragma unroll tag to loops which meets the standard.
 *
205
 * \param stmt The statment to be unrolled.
206 207 208
 * \param auto_max_step The maximum step before stop attach automatic unroll
 * \param auto_min_depth The minimum depth before we can start automatic unroll
 * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
209
 * \return Transformed stmt.
210
 */
211
Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_min_depth, bool explicit_unroll);
212 213

/*!
214 215
 * \brief vectorize the constant loops
 * \param stmt The statment to be vectorized.
216
 * \return Transformed stmt.
217 218 219 220
 */
Stmt VectorizeLoop(Stmt stmt);

/*!
221 222 223 224 225 226 227
 * \brief Inject virtual thread loops into stmt.
 * \param stmt The statment to be transformed.
 * \return Transformed stmt.
 */
Stmt InjectVirtualThread(Stmt stmt);

/*!
228 229 230 231 232 233 234
 * \brief Inject prefetch instructions into stmt.
 * \param stmt The statment to be transformed.
 * \return Transformed stmt.
 */
Stmt InjectPrefetch(Stmt stmt);

/*!
235 236
 * \brief Inject double buffer into stmt.
 * \param stmt The statment to be transformed.
237
 * \param split_loop Loop splitting factor.
238 239
 * \return Transformed stmt.
 */
240
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
241 242

/*!
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
 * \brief Inject copy intrinsics with optional pad.
 *
 * \param stmt The statment to be transformed.
 * \param pragma_key The pragma key for hint of copy.
 * \param fintrin The function with signature
 *
 *   Stmt fintrin(Buffer src,
 *                Buffer dst,
 *                Array<Expr> pad_before,
 *                Array<Expr> pad_after,
 *                Expr pad_value)
 * \return Transformed stmt.
 */
Stmt InjectCopyIntrin(Stmt stmt,
                      const std::string& pragma_key,
                      const runtime::PackedFunc& fintrin);

/*!
261 262 263 264
 * \brief Rewrite storage allocation pattern.
 *  Moves the allocation to outer most possible scope.
 *  Trying to share space between allocations to make
 *  a static allocation plan when possible.
265 266 267 268
 *
 * \param stmt The stmt to be trasnformed
 * \return Transformed stmt.
 */
269
Stmt StorageRewrite(Stmt stmt);
270 271

/*!
272 273 274 275 276 277 278
 * \brief partition loops in the stmt
 * \param stmt The stmt to do loop partition
 * \return Transformed stmt.
 */
Stmt LoopPartition(Stmt stmt);

/*!
279 280 281 282 283 284 285 286
 * \brief Detect and insert sync points to co-processor.
 *
 * \param stmt The stmt to be trasnformed
 * \return Transformed stmt.
 */
Stmt CoProcSync(Stmt stmt);

/*!
287 288 289 290 291 292 293 294 295
 * \brief Lift common attrs with attr_key to outer scope.
 *
 * \param stmt The stmt to be trasnformed
 * \param attr_key The attribute key to be checked.
 * \return Transformed stmt.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
296 297 298 299 300 301 302
 * \brief Detect and rewrite unsafe select that contains memory access.
 * \param stmt The statment to be rewritten.
 * \return Transformed stmt.
 */
Stmt RewriteUnsafeSelect(Stmt stmt);

/*!
303 304 305 306 307 308 309 310 311
 * \brief Lower attached storage access information.
 * Do this pass after all storage access analysis finish.
 *
 * \param stmt The stmt to be trasnformed
 * \return Transformed stmt.
 */
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
312 313 314 315 316 317 318 319 320
 * \brief Make an user callable API LoweredFunc.
 *
 *  The main task of this function is to create code to :
 *   - Map the values in the api_args to of Var that is required by body.
 *   - Insert assertions to check type/value of the passed arguments.
 *
 * \param body The body of the function.
 * \param name The name of the function.
 * \param api_args Arguments to the function, can be either Var, or Buffer
321 322
 * \param num_unpacked_args Number of arguments that
 *         are processed in plain form instead of packed form.
323 324 325
 * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
 *  It is recommended to set to true for optimized code if such invariant holds.
 *
326 327 328 329 330
 * \return a LoweredFunc with the specified signiture.
 *
 * \note
 *  The function signiture have two cases
 *
331 332
 *  let num_packed_args = len(api_args) - num_unpacked_args;
 *
333 334 335 336 337 338 339 340 341 342 343 344 345 346
 *  if num_packed_args is zero:
 *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
 *
 *  if num_packed_args is not zero:
 *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
 *         api_arg_k, api_arg_k+1, ... api_arg_n)
 *
 *       where n == len(api_args), k == num_packed_args
 *
 *  There is no thread_axis in generated function.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
347 348
                    int num_unpacked_args,
                    bool is_restricted);
349 350

/*!
351 352 353 354 355 356 357 358
 * \brief Bind the device type of host function to be device_type.
 * \param func The function to be binded.
 * \param device_type The device type to be binded.
 * \return The binded function.
 */
LoweredFunc BindDeviceType(LoweredFunc func,
                           int device_type);
/*!
359 360 361 362
 * \brief Find undefined vars in the statment.
 * \param stmt The function to be checked.
 * \param defs The vars that is defined.
 * \return Array of undefined vars.
363
 */
364
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380

/*!
 * \brief Split the function into a host function and device functions.
 * \param func The function to be splitted.
 *
 * \return Array of functions, the first one is host function,
 *     the others are device functions.
 */
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
 * \brief Insert sync between parallel read/write of shared buffers.
 *
 * \param stmt The stmt to be trasnformed.
 * \param storage_scope The storage scope considered.
 */
381
LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
382

383 384 385 386 387 388 389
/*!
 * \brief Lower cross thread alleduce in the stmt.
 * \param f The device function to be lowered.
 * \param warp_size the size of warp where no sync is needed.
 * \return Transformed function.
 */
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
390 391

/*!
392 393 394 395
 * \brief Lower packed function call.
 * \param f The function to be lowered.
 * \return Transformed function.
 */
396
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
397 398

/*!
399 400 401 402 403 404 405
 * \brief Combine context function calls.
 * \param f The host function to be lowered.
 * \return Transformed function.
 */
LoweredFunc CombineContextCall(LoweredFunc f);

/*!
406 407 408 409 410 411
 * \brief Lower intrinsic function calls.
 * \param f The device function to be lowered.
 * \param target The target device.
 * \return Transformed function.
 */
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
tqchen committed
412 413 414 415
}  // namespace ir
}  // namespace tvm

#endif  // TVM_IR_PASS_H_