/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/ir_pass.h
 * \brief Collection of IR pass functions.
 *
 *  When a pass function in this file operates on Stmt,
 *  PassFunction(Evaluate(expr)) can be used to apply it to an Expr.
 */
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "expr.h"
#include "buffer.h"
#include "schedule.h"
#include "lowered_func.h"

namespace tvm {
namespace ir {

/*!
 * \brief Simplify the expression.
 * \param expr The expression to be simplified.
 * \param vrange The range information about the variables.
 * \return Simplified expression.
 */
TVM_DLL Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());

/*!
 * \brief Simplify the statement.
 * \param stmt The statement to be simplified.
 * \param vrange The range information about the variables.
 * \return Simplified statement.
 */
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
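//
// Illustrative use (a minimal sketch; the Var/Range construction below assumes
// the usual TVM C++ expression API and is not tied to a specific version):
//
//   Var n("n");
//   Map<Var, Range> vrange;
//   vrange.Set(n, Range::make_by_min_extent(0, 16));   // assume 0 <= n < 16
//   Expr e = Simplify(n + 0, vrange);                   // folds to n
//   Stmt s = Simplify(Evaluate::make(n * 1), vrange);   // Stmt overload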

/*!
 * \brief Simplify by applying canonical form.
 * \param stmt The statement to be canonically simplified.
 * \param vrange The range information about the variables.
 * \return Canonicalized statement.
 */
Stmt CanonicalSimplify(Stmt stmt,
                       Map<Var, Range> vrange = Map<Var, Range>());

/*!
 * \brief Simplify by applying canonical form.
 * \param expr The expression to be canonically simplified.
 * \param vrange The range information about the variables.
 * \return Canonicalized expression.
 */
TVM_DLL Expr CanonicalSimplify(Expr expr,
                               Map<Var, Range> vrange = Map<Var, Range>());

/*!
 * \brief Deep compare lhs and rhs
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
TVM_DLL bool Equal(const Expr& lhs, const Expr& rhs);

/*!
 * \brief Deep compare lhs and rhs
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
bool Equal(const Stmt& lhs, const Stmt& rhs);

/*!
 * \brief Deep compare lhs and rhs.
 *
 *  If you only want an equality comparison, use Equal,
 *  which will also tie definitions. Compare gives
 *  expressions a total order.
 *
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The comparison result.
 */
int Compare(const Expr& lhs, const Expr& rhs);
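//
// Illustrative comparison (a sketch; assumes the usual Var/Expr constructors):
//
//   Var x("x"), y("y");
//   bool same  = Equal(x + 1, x + 1);    // structural equality, ties definitions
//   int  order = Compare(x + 1, y + 1);  // <0, 0 or >0: a total order over exprs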

/*!
 * \brief Verify whether the IR stmt or Expr is in SSA form.
 *  That is: each VarExpr is defined and assigned once (in Let/For).
 *
 * \param ir The root of the IR DAG.
 * \return Whether IR is in SSA form.
 * \note All the passes in this file use SSA form and output SSA form.
 */
TVM_DLL bool VerifySSA(const Stmt& ir);

/*!
 * \brief Whether the expression has side effect.
 * \param e The expression to be checked.
 * \return Whether the expression has side effect.
 */
TVM_DLL bool HasSideEffect(const Expr& e);

/*!
 * \brief Whether the expression e uses the variable v.
 * \param e The expression to be checked.
 * \param v The variable.
 * \return Whether e uses v.
 */
bool ExprUseVar(const Expr& e, const Var& v);

/*!
 * \brief Whether the expression e uses any variable in the variable set.
 * \param e The expression to be checked.
 * \param vset The variable set.
 * \return Whether e uses any variable in vset.
 */
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
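//
// Illustrative use (a sketch; assumes Var construction and Var::get()):
//
//   Var i("i"), j("j");
//   bool uses_i = ExprUseVar(i + 1, i);                     // true
//   std::unordered_set<const Variable*> vset = {j.get()};
//   bool uses_j = ExprUseVar(i + 1, vset);                  // false: j not used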

/*!
 * \brief Convert an IR node to be SSA form.
 * \param stmt The source statement to be converted.
 * \return The converted form.
 */
TVM_DLL Stmt ConvertSSA(Stmt stmt);

/*!
 * \brief Substitute the vars specified in value_map with the corresponding values.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
Stmt Substitute(Stmt stmt,
                const std::unordered_map<const Variable*, Expr>& value_map);

/*!
 * \brief Substitute the vars specified in value_map with the corresponding values.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
Expr Substitute(Expr expr,
                const std::unordered_map<const Variable*, Expr>& value_map);

/*!
 * \brief Substitute the vars specified in value_map with the corresponding values.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);

/*!
 * \brief Substitute the vars specified in value_map with the corresponding values.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
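//
// Illustrative substitution (a sketch; make_const/Int are assumed from the
// usual TVM expression helpers):
//
//   Var x("x");
//   Map<Var, Expr> vmap;
//   vmap.Set(x, make_const(Int(32), 4));
//   Expr e = Substitute(x * 2, vmap);     // -> 4 * 2; Simplify folds it to 8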

/*!
 * \brief Inline all calls of f in stmt.
 *
 * \param stmt The statement to apply inline optimization.
 * \param f The function reference to be inlined
 * \param args The argument variables of the function.
 * \param body The definition body of the function.
 * \return The result stmt
 *
 * \note All the passes in this file use SSA form and output SSA form.
 */
Stmt Inline(Stmt stmt,
            FunctionRef f,
            Array<Var> args,
            Expr body);

/*!
 * \brief Flatten the multi-dimensional read/write
 *  to single dimensional Load/Store
 *
 * \param stmt The stmt to be transformed.
 * \param extern_buffer Map specifying the external
 *    buffer assignment of inputs and outputs.
 * \param cache_line_size The size of the CPU cache line.
 * \param create_bound_attribute Whether to create bound attributes.
 * \return Transformed stmt.
 */
Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size,
                    bool create_bound_attribute = false);
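//
// Illustrative use (a sketch; placeholder/decl_buffer are assumed from the
// usual TVM C++ API, and 64 is just an example cache line size):
//
//   Tensor A = placeholder({16, 16}, Float(32), "A");
//   Buffer A_buf = decl_buffer(A->shape, A->dtype, "A_buf");
//   Map<Tensor, Buffer> binds;
//   binds.Set(A, A_buf);
//   Stmt flat = StorageFlatten(body, binds, 64);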

/*!
 * \brief Remove no-op statements from the Stmt.
 * \param stmt The stmt to be transformed
 * \return Transformed stmt.
 */
Stmt RemoveNoOp(Stmt stmt);

/*!
 * \brief Split the statement into pipeline stages.
 * \param stmt The stmt to be split
 * \param split_load Whether to split load into its own stage.
 * \return Transformed stmt.
 */
Stmt SplitPipeline(Stmt stmt, bool split_load);

/*!
 * \brief Narrow channel access to smaller range.
 * \param stmt The stmt to do access rewriting.
 * \return Transformed stmt.
 */
Stmt NarrowChannelAccess(Stmt stmt);

/*!
 * \brief unroll the constant loop marked by unroll.
 * This pass also automatically attaches a pragma unroll tag to loops that meet the standard.
 *
 * \param stmt The statement to be unrolled.
 * \param auto_max_step The maximum number of steps before stopping automatic unrolling.
 * \param auto_max_depth The maximum nesting depth before stopping automatic unrolling.
 * \param auto_max_extent The maximum extent of the loop we can unroll;
 *                     this is a legacy option that does not take the total loop steps into account.
 * \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
 * \return Transformed stmt.
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll);
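//
// Illustrative call (a sketch; the thresholds are arbitrary example values):
//
//   // Example thresholds: auto_max_step = 16, auto_max_depth = 8,
//   // auto_max_extent = 0 (legacy option), explicit_unroll = true.
//   Stmt unrolled = UnrollLoop(stmt, 16, 8, 0, true);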

/*!
 * \brief vectorize the constant loops
 * \param stmt The statement to be vectorized.
 * \return Transformed stmt.
 */
Stmt VectorizeLoop(Stmt stmt);

/*!
 * \brief convert vectorized loops into serialized loops
 * \param stmt The statement to skip vectorization on.
 * \return Transformed stmt.
 */
Stmt SkipVectorize(Stmt stmt);

/*!
 * \brief instruments bound checkers.
 * \param stmt The statement to be instrumented.
 * \return Instrumented stmt.
 */
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
 * \brief Inject virtual thread loops into stmt.
 * \param stmt The statement to be transformed.
 * \return Transformed stmt.
 */
Stmt InjectVirtualThread(Stmt stmt);

/*!
 * \brief Inject prefetch instructions into stmt.
 * \param stmt The statement to be transformed.
 * \return Transformed stmt.
 */
Stmt InjectPrefetch(Stmt stmt);

/*!
 * \brief Inject double buffer into stmt.
 * \param stmt The statement to be transformed.
 * \param split_loop Loop splitting factor.
 * \return Transformed stmt.
 */
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);

/*!
 * \brief Inject copy intrinsics with optional pad.
 *
 * \param stmt The statement to be transformed.
 * \param pragma_key The pragma key for hint of copy.
 * \param fintrin The function with signature
 *
 *   Stmt fintrin(Buffer src,
 *                Buffer dst,
 *                Array<Expr> pad_before,
 *                Array<Expr> pad_after,
 *                Expr pad_value)
 * \return Transformed stmt.
 */
Stmt InjectCopyIntrin(Stmt stmt,
                      const std::string& pragma_key,
                      const runtime::PackedFunc& fintrin);

/*!
 * \brief Rewrite storage allocation pattern.
 *  Moves the allocation to the outermost possible scope.
 *  Tries to share space between allocations to make
 *  a static allocation plan when possible.
 *
 * \param stmt The stmt to be transformed
 * \return Transformed stmt.
 */
Stmt StorageRewrite(Stmt stmt);

/*!
 * \brief partition loops in the stmt
 * \param stmt The stmt to do loop partition
 * \param split_const_loop flag to enable partition for const loop
 * \return Transformed stmt.
 */
Stmt LoopPartition(Stmt stmt, bool split_const_loop);

/*!
 * \brief Detect and insert sync points to co-processor.
 *
 * \param stmt The stmt to be transformed
 * \return Transformed stmt.
 */
Stmt CoProcSync(Stmt stmt);

/*!
 * \brief Lift common attrs with attr_key to outer scope.
 *
 * \param stmt The stmt to be transformed
 * \param attr_key The attribute key to be checked.
 * \return Transformed stmt.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
 * \brief Detect and rewrite unsafe select that contains memory access.
 * \param stmt The statement to be rewritten.
 * \return Transformed stmt.
 */
Stmt RewriteUnsafeSelect(Stmt stmt);

/*!
 * \brief Lower attached storage access information.
 *  Do this pass after all storage access analyses finish.
 *
 * \param stmt The stmt to be transformed
 * \return Transformed stmt.
 */
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
 * \brief Decorate the stmt with a device scope; this is helpful for
 * hardware accelerators without thread blocks.
 *
 * \param stmt The stmt to be transformed
 * \return Transformed stmt.
 */
Stmt DecorateDeviceScope(Stmt stmt);

/*!
 * \brief Make a user-callable API LoweredFunc.
 *
 *  The main task of this function is to create code to:
 *   - Map the values in api_args to the Vars that are required by body.
 *   - Insert assertions to check the type/value of the passed arguments.
 *
 * \param body The body of the function.
 * \param name The name of the function.
 * \param api_args Arguments to the function, can be either Var, or Buffer
 * \param num_unpacked_args Number of arguments that
 *         are processed in plain form instead of packed form.
 * \param is_restricted Whether the caller can guarantee that each buffer argument does not overlap.
 *  It is recommended to set this to true for optimized code if such an invariant holds.
 *
 * \return a LoweredFunc with the specified signature.
 *
 * \note
 *  The function signature has two cases
 *
 *  let num_packed_args = len(api_args) - num_unpacked_args;
 *
 *  if num_packed_args is zero:
 *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
 *
 *  if num_packed_args is not zero:
 *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
 *         api_arg_k, api_arg_k+1, ... api_arg_n)
 *
 *       where n == len(api_args), k == num_packed_args
 *
 *  There is no thread_axis in generated function.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted);
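//
// Worked instance of the signature rule above (illustrative only; the argument
// names are made up):
//
//   // api_args = {A, B, n} (len == 3) and num_unpacked_args == 1
//   // => num_packed_args = 3 - 1 = 2, so the generated entry point is roughly
//   //    f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, n)
//   // with A and B passed through the packed arguments and n passed plainly.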

/*!
 * \brief Bind the device type of the host function to be device_type.
 * \param func The function to be bound.
 * \param device_type The device type to be bound.
 * \return The bound function.
 */
LoweredFunc BindDeviceType(LoweredFunc func,
                           int device_type);
/*!
 * \brief Find undefined vars in the statement.
 * \param stmt The statement to be checked.
 * \param defs The vars that are defined.
 * \return Array of undefined vars.
 */
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
 * \brief Split the function into a host function and device functions.
 * \param func The function to be split.
 *
 * \return Array of functions, the first one is host function,
 *     the others are device functions.
 */
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
 * \brief Insert sync between parallel read/write of shared buffers.
 *
 * \param stmt The stmt to be transformed.
 * \param storage_scope The storage scope considered.
 * \return Transformed function.
 */
LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
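//
// Illustrative call (a sketch; "shared" is the typical GPU shared-memory scope):
//
//   LoweredFunc synced = ThreadSync(device_func, "shared");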

/*!
 * \brief Lower cross thread allreduce in the stmt.
 * \param f The device function to be lowered.
 * \param warp_size the size of warp where no sync is needed.
 * \return Transformed function.
 */
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);

/*!
 * \brief Lower warp memory in stmt.
 * \param f The device function to be lowered.
 * \param warp_size the size of warp where no sync is needed.
 *        This function only takes effect if warp_size is bigger than one.
 * \return Transformed function.
 */
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);

/*!
 * \brief Remap the thread axis
 *
 *  This can be used to get an equivalent program which uses
 *  threadIdx.y in place of threadIdx.x by passing
 *  {"threadIdx.x": thread_axis("threadIdx.y")}
 *
 * \param f The device function to be lowered.
 * \param axis_map The map from StringImm -> IterVar
 * \return Transformed function.
 */
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
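//
// Illustrative remap mirroring the {"threadIdx.x": thread_axis("threadIdx.y")}
// example above (a sketch; StringImm::make and thread_axis are assumed from
// the usual TVM C++ API):
//
//   Map<Expr, IterVar> axis_map;
//   axis_map.Set(StringImm::make("threadIdx.x"),
//                thread_axis(Range(), "threadIdx.y"));
//   LoweredFunc remapped = RemapThreadAxis(f, axis_map);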

/*!
 * \brief Lower packed function call.
 * \param f The function to be lowered.
 * \return Transformed function.
 */
LoweredFunc LowerTVMBuiltin(LoweredFunc f);

/*!
 * \brief Combine context function calls.
 * \param f The host function to be lowered.
 * \return Transformed function.
 */
LoweredFunc CombineContextCall(LoweredFunc f);

/*!
 * \brief Rewrite the pointer content type of arguments,
 *  as well as Alloc internal to the function to use
 *  the most frequently accessed type for load/store
 *  to avoid pointer casting in backend when possible.
 *
 * \note implemented in storage_rewrite.cc
 * \param f The function to be transformed
 * \return Transformed function.
 */
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);

/*!
 * \brief Lower intrinsic function calls.
 * \param f The device function to be lowered.
 * \param target The target device.
 * \return Transformed function.
 */
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);

/*!
 * \brief Lower custom datatypes.
 *
 * See tvm::datatypes::Registry for more information on adding custom datatypes.
 *
 * \param f The device function to be lowered.
 * \param target The target device.
 * \return Transformed function.
 */
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);

/*!
 * \brief Verify if memory accesses are legal for a specific target device type.
 *
 *  In the case that tgt is cuda, if not all workload is bound with
 *  threads, CPU code is generated that tries to access GPU memory,
 *  which is illegal. This pass performs verification for this case.
 *
 * \param func The function to be verified.
 * \param device_type The target device type.
 * \return Success of memory verification.
 */
bool VerifyMemory(LoweredFunc func, int device_type);
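//
// Illustrative call (a sketch; kDLGPU is the DLPack device-type code for GPU):
//
//   bool ok = VerifyMemory(func, kDLGPU);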


/*!
 * \brief Verify the correctness of a GPU code.
 *        It will check whether the amount of memory usage or the number of threads
 *        in a block exceeds the limit.
 * \param stmt The statement to be checked
 * \param constraints The dict to specify constraints to check.
 *        Possible keys are
 *
 *        "max_local_memory_per_block": Total amount of local memory per block (in bytes).
 *        "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
 *        "max_threads_per_block": Maximum number of threads per block.
 *        "max_thread_x": Maximum length of threadIdx.x.
 *        "max_thread_y": Maximum length of threadIdx.y.
 *        "max_thread_z": Maximum length of threadIdx.z.
 *
 *        If one key is missing in this argument, the pass won't check for that item.
 * \return valid Whether it is a valid GPU code
 *
 */
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints);
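//
// Illustrative constraint map (a sketch; the limits are example values only,
// and make_const/Int are assumed from the usual TVM expression helpers):
//
//   Map<std::string, Expr> constraints;
//   constraints.Set("max_shared_memory_per_block", make_const(Int(32), 48 * 1024));
//   constraints.Set("max_threads_per_block", make_const(Int(32), 1024));
//   bool valid = VerifyGPUCode(stmt, constraints);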


}  // namespace ir
}  // namespace tvm

#endif  // TVM_IR_PASS_H_